mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
fix(middleware): normalize tool result adjacency before model calls (#2939)
* normalizing tool-call transcripts before invocation * test(middleware): cover tool result regrouping edge cases
This commit is contained in:
+27
-22
@@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
return "[Tool call was interrupted and did not return a result.]"
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
"""Return messages with tool results grouped after their tool-call AIMessage.
|
||||
|
||||
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||
Returns None if no patches are needed.
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
"""
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
|
||||
# Check if any patching is needed
|
||||
needs_patch = False
|
||||
tool_call_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
break
|
||||
if needs_patch:
|
||||
break
|
||||
if tc_id:
|
||||
tool_call_ids.add(tc_id)
|
||||
|
||||
if not needs_patch:
|
||||
return None
|
||||
|
||||
# Build new list with patches inserted right after each dangling AIMessage
|
||||
patched: list = []
|
||||
patched_ids: set[str] = set()
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
continue
|
||||
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content=self._synthetic_tool_message_content(tc),
|
||||
@@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
patched_ids.add(tc_id)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
if patched == messages:
|
||||
return None
|
||||
|
||||
if patch_count:
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return patched
|
||||
|
||||
@override
|
||||
|
||||
Reference in New Issue
Block a user