mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
fix(middleware): handle repeated tool call ids (#3143)
* fix(middleware): handle repeated tool call ids * add tests * refactor(middleware): rely on tool result queues
This commit is contained in:
+6
-7
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
"""
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
tool_messages_by_id[msg.tool_call_id].append(msg)
|
||||
|
||||
tool_call_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
tool_call_ids.add(tc_id)
|
||||
|
||||
patched: list = []
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
if not tc_id:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
tool_msg_queue = tool_messages_by_id.get(tc_id)
|
||||
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
if patched == messages:
|
||||
|
||||
Reference in New Issue
Block a user