fix(middleware): handle repeated tool call ids (#3143)

* fix(middleware): handle repeated tool call ids

* add tests

* refactor(middleware): rely on tool result queues
This commit is contained in:
Nan Gao
2026-05-22 15:44:05 +02:00
committed by GitHub
parent 2eeb597985
commit f0bae28636
2 changed files with 70 additions and 7 deletions
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
import json
import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from typing import override
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged.
"""
tool_messages_by_id: dict[str, ToolMessage] = {}
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
for msg in messages:
if isinstance(msg, ToolMessage):
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
tool_messages_by_id[msg.tool_call_id].append(msg)
tool_call_ids: set[str] = set()
for msg in messages:
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
tool_call_ids.add(tc_id)
patched: list = []
consumed_tool_msg_ids: set[str] = set()
patch_count = 0
for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if not tc_id or tc_id in consumed_tool_msg_ids:
if not tc_id:
continue
existing_tool_msg = tool_messages_by_id.get(tc_id)
tool_msg_queue = tool_messages_by_id.get(tc_id)
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else:
patched.append(
ToolMessage(
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error",
)
)
consumed_tool_msg_ids.add(tc_id)
patch_count += 1
if patched == messages:
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [