mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
fix(middleware): normalize tool result adjacency before model calls (#2939)
* normalizing tool-call transcripts before invocation * test(middleware): cover tool result regrouping edge cases
This commit is contained in:
@@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching:
|
||||
assert patched[1].name == "bash"
|
||||
assert patched[1].status == "error"
|
||||
|
||||
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
|
||||
middleware = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
HumanMessage(content="interruption"),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
patched = middleware._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert isinstance(patched[2], HumanMessage)
|
||||
|
||||
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
|
||||
HumanMessage(content="interruption"),
|
||||
_tool_msg("call_2", "read"),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert isinstance(patched[2], ToolMessage)
|
||||
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
|
||||
assert isinstance(patched[3], HumanMessage)
|
||||
|
||||
def test_valid_adjacent_tool_results_are_unchanged(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
_tool_msg("call_1", "bash"),
|
||||
HumanMessage(content="next"),
|
||||
]
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
HumanMessage(content="interruption"),
|
||||
_ai_with_tool_calls([_tc("read", "call_2")]),
|
||||
_tool_msg("call_1", "bash"),
|
||||
_tool_msg("call_2", "read"),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert isinstance(patched[2], HumanMessage)
|
||||
assert isinstance(patched[3], AIMessage)
|
||||
assert isinstance(patched[4], ToolMessage)
|
||||
assert patched[4].tool_call_id == "call_2"
|
||||
|
||||
def test_orphan_tool_message_is_preserved_during_grouping(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
orphan = _tool_msg("orphan_call", "orphan")
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
orphan,
|
||||
HumanMessage(content="interruption"),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert orphan in patched
|
||||
assert patched.count(orphan) == 1
|
||||
|
||||
def test_invalid_tool_call_is_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
|
||||
|
||||
Reference in New Issue
Block a user