mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
fix(middleware): handle repeated tool call ids (#3143)
* fix(middleware): handle repeated tool call ids * add tests * refactor(middleware): rely on tool result queues
This commit is contained in:
+6
-7
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict, deque
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
This normalizes model-bound causal order before provider serialization while
|
This normalizes model-bound causal order before provider serialization while
|
||||||
preserving already-valid transcripts unchanged.
|
preserving already-valid transcripts unchanged.
|
||||||
"""
|
"""
|
||||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
tool_messages_by_id[msg.tool_call_id].append(msg)
|
||||||
|
|
||||||
tool_call_ids: set[str] = set()
|
tool_call_ids: set[str] = set()
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
tool_call_ids.add(tc_id)
|
tool_call_ids.add(tc_id)
|
||||||
|
|
||||||
patched: list = []
|
patched: list = []
|
||||||
consumed_tool_msg_ids: set[str] = set()
|
|
||||||
patch_count = 0
|
patch_count = 0
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||||
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
for tc in self._message_tool_calls(msg):
|
for tc in self._message_tool_calls(msg):
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
if not tc_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
tool_msg_queue = tool_messages_by_id.get(tc_id)
|
||||||
|
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
|
||||||
if existing_tool_msg is not None:
|
if existing_tool_msg is not None:
|
||||||
patched.append(existing_tool_msg)
|
patched.append(existing_tool_msg)
|
||||||
consumed_tool_msg_ids.add(tc_id)
|
|
||||||
else:
|
else:
|
||||||
patched.append(
|
patched.append(
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
consumed_tool_msg_ids.add(tc_id)
|
|
||||||
patch_count += 1
|
patch_count += 1
|
||||||
|
|
||||||
if patched == messages:
|
if patched == messages:
|
||||||
|
|||||||
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
|
|||||||
|
|
||||||
assert mw._build_patched_messages(msgs) is None
|
assert mw._build_patched_messages(msgs) is None
|
||||||
|
|
||||||
|
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
|
||||||
|
mw = DanglingToolCallMiddleware()
|
||||||
|
msgs = [
|
||||||
|
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
|
||||||
|
_ai_with_tool_calls(
|
||||||
|
[
|
||||||
|
_tc("web_search", "web_search:11"),
|
||||||
|
_tc("web_search", "web_search:12"),
|
||||||
|
_tc("web_search", "web_search:13"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_msg("web_search:11", "web_search"),
|
||||||
|
_tool_msg("web_search:12", "web_search"),
|
||||||
|
_tool_msg("web_search:13", "web_search"),
|
||||||
|
_ai_with_tool_calls(
|
||||||
|
[
|
||||||
|
_tc("web_search", "web_search:9"),
|
||||||
|
_tc("web_search", "web_search:10"),
|
||||||
|
_tc("web_search", "web_search:11"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_msg("web_search:9", "web_search"),
|
||||||
|
_tool_msg("web_search:10", "web_search"),
|
||||||
|
_tool_msg("web_search:11", "web_search"),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mw._build_patched_messages(msgs) is None
|
||||||
|
|
||||||
|
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
|
||||||
|
mw = DanglingToolCallMiddleware()
|
||||||
|
msgs = [
|
||||||
|
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||||
|
_tool_msg("web_search:11", "web_search"),
|
||||||
|
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||||
|
]
|
||||||
|
|
||||||
|
patched = mw._build_patched_messages(msgs)
|
||||||
|
|
||||||
|
assert patched is not None
|
||||||
|
assert isinstance(patched[1], ToolMessage)
|
||||||
|
assert patched[1].tool_call_id == "web_search:11"
|
||||||
|
assert patched[1].status == "success"
|
||||||
|
assert isinstance(patched[3], ToolMessage)
|
||||||
|
assert patched[3].tool_call_id == "web_search:11"
|
||||||
|
assert patched[3].status == "error"
|
||||||
|
|
||||||
|
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
|
||||||
|
mw = DanglingToolCallMiddleware()
|
||||||
|
result = _tool_msg("web_search:11", "web_search")
|
||||||
|
msgs = [
|
||||||
|
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||||
|
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||||
|
result,
|
||||||
|
]
|
||||||
|
|
||||||
|
patched = mw._build_patched_messages(msgs)
|
||||||
|
|
||||||
|
assert patched is not None
|
||||||
|
assert patched[1] is result
|
||||||
|
assert patched[1].status == "success"
|
||||||
|
assert isinstance(patched[3], ToolMessage)
|
||||||
|
assert patched[3].tool_call_id == "web_search:11"
|
||||||
|
assert patched[3].status == "error"
|
||||||
|
|
||||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||||
mw = DanglingToolCallMiddleware()
|
mw = DanglingToolCallMiddleware()
|
||||||
msgs = [
|
msgs = [
|
||||||
|
|||||||
Reference in New Issue
Block a user