diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 5bb54f3e5..000ca51a2 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): return "[Tool call was interrupted and did not return a result.]" def _build_patched_messages(self, messages: list) -> list | None: - """Return a new message list with patches inserted at the correct positions. + """Return messages with tool results grouped after their tool-call AIMessage. - For each AIMessage with dangling tool_calls (no corresponding ToolMessage), - a synthetic ToolMessage is inserted immediately after that AIMessage. - Returns None if no patches are needed. + This normalizes model-bound causal order before provider serialization while + preserving already-valid transcripts unchanged. """ - # Collect IDs of all existing ToolMessages - existing_tool_msg_ids: set[str] = set() + tool_messages_by_id: dict[str, ToolMessage] = {} for msg in messages: if isinstance(msg, ToolMessage): - existing_tool_msg_ids.add(msg.tool_call_id) + tool_messages_by_id.setdefault(msg.tool_call_id, msg) - # Check if any patching is needed - needs_patch = False + tool_call_ids: set[str] = set() for msg in messages: if getattr(msg, "type", None) != "ai": continue for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if tc_id and tc_id not in existing_tool_msg_ids: - needs_patch = True - break - if needs_patch: - break + if tc_id: + tool_call_ids.add(tc_id) - if not needs_patch: - return None - - # Build new list with patches inserted right after each dangling AIMessage patched: list = [] - patched_ids: set[str] = set() + consumed_tool_msg_ids: set[str] = set() patch_count = 0 for msg in messages: + if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: + continue + patched.append(msg) if getattr(msg, "type", None) != "ai": continue + for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: + if not tc_id or tc_id in consumed_tool_msg_ids: + continue + + existing_tool_msg = tool_messages_by_id.get(tc_id) + if existing_tool_msg is not None: + patched.append(existing_tool_msg) + consumed_tool_msg_ids.add(tc_id) + else: patched.append( ToolMessage( content=self._synthetic_tool_message_content(tc), @@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): status="error", ) ) - patched_ids.add(tc_id) + consumed_tool_msg_ids.add(tc_id) patch_count += 1 - logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") + if patched == messages: + return None + + if patch_count: + logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") return patched @override diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index b1d5c476a..f9f47369d 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching: assert patched[1].name == "bash" assert patched[1].status == "error" + def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self): + middleware = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + HumanMessage(content="interruption"), + _tool_msg("call_1", "bash"), + ] + patched = middleware._build_patched_messages(msgs) + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert isinstance(patched[2], HumanMessage) + + def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]), + HumanMessage(content="interruption"), + _tool_msg("call_2", "read"), + _tool_msg("call_1", "bash"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert isinstance(patched[2], ToolMessage) + assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"] + assert isinstance(patched[3], HumanMessage) + + def test_valid_adjacent_tool_results_are_unchanged(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + _tool_msg("call_1", "bash"), + HumanMessage(content="next"), + ] + + assert mw._build_patched_messages(msgs) is None + + def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + HumanMessage(content="interruption"), + _ai_with_tool_calls([_tc("read", "call_2")]), + _tool_msg("call_1", "bash"), + _tool_msg("call_2", "read"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert isinstance(patched[2], HumanMessage) + assert isinstance(patched[3], AIMessage) + assert isinstance(patched[4], ToolMessage) + assert patched[4].tool_call_id == "call_2" + + def test_orphan_tool_message_is_preserved_during_grouping(self): + mw = DanglingToolCallMiddleware() + orphan = _tool_msg("orphan_call", "orphan") + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + orphan, + HumanMessage(content="interruption"), + _tool_msg("call_1", "bash"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert orphan in patched + assert patched.count(orphan) == 1 + def test_invalid_tool_call_is_patched(self): mw = DanglingToolCallMiddleware() msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]