fix(middleware): Handle invalid tool calls in dangling pairing middleware (#2890) (#2891)

This commit is contained in:
Nan Gao
2026-05-12 04:55:13 +02:00
committed by GitHub
parent 0009655454
commit 20d2d2b373
2 changed files with 107 additions and 26 deletions
@@ -14,6 +14,10 @@ def _ai_with_tool_calls(tool_calls):
return AIMessage(content="", tool_calls=tool_calls)
def _ai_with_invalid_tool_calls(invalid_tool_calls):
return AIMessage(content="", tool_calls=[], invalid_tool_calls=invalid_tool_calls)
def _tool_msg(tool_call_id, name="test_tool"):
return ToolMessage(content="result", tool_call_id=tool_call_id, name=name)
@@ -22,6 +26,16 @@ def _tc(name="bash", tc_id="call_1"):
return {"name": name, "id": tc_id, "args": {}}
def _invalid_tc(name="write_file", tc_id="write_file:36", error="Failed to parse tool arguments: malformed JSON"):
return {
"type": "invalid_tool_call",
"name": name,
"id": tc_id,
"args": '{"description":"write report","path":"/mnt/user-data/outputs/report.md","content":"bad {"json"}"}',
"error": error,
}
class TestBuildPatchedMessagesNoPatch:
def test_empty_messages(self):
mw = DanglingToolCallMiddleware()
@@ -144,6 +158,42 @@ class TestBuildPatchedMessagesPatching:
assert patched[1].name == "bash"
assert patched[1].status == "error"
def test_invalid_tool_call_is_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert len(patched) == 2
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "write_file:36"
assert patched[1].name == "write_file"
assert patched[1].status == "error"
assert "arguments were invalid" in patched[1].content
assert "Failed to parse tool arguments" in patched[1].content
def test_valid_and_invalid_tool_calls_are_both_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [
AIMessage(
content="",
tool_calls=[_tc("bash", "call_1")],
invalid_tool_calls=[_invalid_tc()],
)
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
tool_msgs = [m for m in patched if isinstance(m, ToolMessage)]
assert len(tool_msgs) == 2
assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "write_file:36"}
def test_invalid_tool_call_already_responded_is_not_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_invalid_tool_calls([_invalid_tc()]),
_tool_msg("write_file:36", "write_file"),
]
assert mw._build_patched_messages(msgs) is None
class TestWrapModelCall:
def test_no_patch_passthrough(self):