mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
This commit is contained in:
@@ -14,6 +14,10 @@ def _ai_with_tool_calls(tool_calls):
|
||||
return AIMessage(content="", tool_calls=tool_calls)
|
||||
|
||||
|
||||
def _ai_with_invalid_tool_calls(invalid_tool_calls):
|
||||
return AIMessage(content="", tool_calls=[], invalid_tool_calls=invalid_tool_calls)
|
||||
|
||||
|
||||
def _tool_msg(tool_call_id, name="test_tool"):
|
||||
return ToolMessage(content="result", tool_call_id=tool_call_id, name=name)
|
||||
|
||||
@@ -22,6 +26,16 @@ def _tc(name="bash", tc_id="call_1"):
|
||||
return {"name": name, "id": tc_id, "args": {}}
|
||||
|
||||
|
||||
def _invalid_tc(name="write_file", tc_id="write_file:36", error="Failed to parse tool arguments: malformed JSON"):
|
||||
return {
|
||||
"type": "invalid_tool_call",
|
||||
"name": name,
|
||||
"id": tc_id,
|
||||
"args": '{"description":"write report","path":"/mnt/user-data/outputs/report.md","content":"bad {"json"}"}',
|
||||
"error": error,
|
||||
}
|
||||
|
||||
|
||||
class TestBuildPatchedMessagesNoPatch:
|
||||
def test_empty_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
@@ -144,6 +158,42 @@ class TestBuildPatchedMessagesPatching:
|
||||
assert patched[1].name == "bash"
|
||||
assert patched[1].status == "error"
|
||||
|
||||
def test_invalid_tool_call_is_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
assert len(patched) == 2
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "write_file:36"
|
||||
assert patched[1].name == "write_file"
|
||||
assert patched[1].status == "error"
|
||||
assert "arguments were invalid" in patched[1].content
|
||||
assert "Failed to parse tool arguments" in patched[1].content
|
||||
|
||||
def test_valid_and_invalid_tool_calls_are_both_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[_tc("bash", "call_1")],
|
||||
invalid_tool_calls=[_invalid_tc()],
|
||||
)
|
||||
]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
tool_msgs = [m for m in patched if isinstance(m, ToolMessage)]
|
||||
assert len(tool_msgs) == 2
|
||||
assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "write_file:36"}
|
||||
|
||||
def test_invalid_tool_call_already_responded_is_not_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_invalid_tool_calls([_invalid_tc()]),
|
||||
_tool_msg("write_file:36", "write_file"),
|
||||
]
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
|
||||
class TestWrapModelCall:
|
||||
def test_no_patch_passthrough(self):
|
||||
|
||||
Reference in New Issue
Block a user