diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 000ca51a2..6026d834e 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do. import json import logging +from collections import defaultdict, deque from collections.abc import Awaitable, Callable from typing import override @@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): This normalizes model-bound causal order before provider serialization while preserving already-valid transcripts unchanged. """ - tool_messages_by_id: dict[str, ToolMessage] = {} + tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque) for msg in messages: if isinstance(msg, ToolMessage): - tool_messages_by_id.setdefault(msg.tool_call_id, msg) + tool_messages_by_id[msg.tool_call_id].append(msg) tool_call_ids: set[str] = set() for msg in messages: @@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): tool_call_ids.add(tc_id) patched: list = [] - consumed_tool_msg_ids: set[str] = set() patch_count = 0 for msg in messages: if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: @@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if not tc_id or tc_id in consumed_tool_msg_ids: + if not tc_id: continue - existing_tool_msg = tool_messages_by_id.get(tc_id) + tool_msg_queue = tool_messages_by_id.get(tc_id) + existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None if existing_tool_msg is not None: patched.append(existing_tool_msg) - consumed_tool_msg_ids.add(tc_id) else: patched.append( ToolMessage( @@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): status="error", ) ) - consumed_tool_msg_ids.add(tc_id) patch_count += 1 if patched == messages: diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index 5ecded924..34f1ac035 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching: assert mw._build_patched_messages(msgs) is None + def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self): + mw = DanglingToolCallMiddleware() + msgs = [ + HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}), + _ai_with_tool_calls( + [ + _tc("web_search", "web_search:11"), + _tc("web_search", "web_search:12"), + _tc("web_search", "web_search:13"), + ] + ), + _tool_msg("web_search:11", "web_search"), + _tool_msg("web_search:12", "web_search"), + _tool_msg("web_search:13", "web_search"), + _ai_with_tool_calls( + [ + _tc("web_search", "web_search:9"), + _tc("web_search", "web_search:10"), + _tc("web_search", "web_search:11"), + ] + ), + _tool_msg("web_search:9", "web_search"), + _tool_msg("web_search:10", "web_search"), + _tool_msg("web_search:11", "web_search"), + ] + + assert mw._build_patched_messages(msgs) is None + + def test_reused_tool_call_id_patches_second_dangling_occurrence(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + _tool_msg("web_search:11", "web_search"), + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "web_search:11" + assert patched[1].status == "success" + assert isinstance(patched[3], ToolMessage) + assert patched[3].tool_call_id == "web_search:11" + assert patched[3].status == "error" + + def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self): + mw = DanglingToolCallMiddleware() + result = _tool_msg("web_search:11", "web_search") + msgs = [ + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + result, + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert patched[1] is result + assert patched[1].status == "success" + assert isinstance(patched[3], ToolMessage) + assert patched[3].tool_call_id == "web_search:11" + assert patched[3].status == "error" + def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): mw = DanglingToolCallMiddleware() msgs = [