fix(middleware): normalize tool result adjacency before model calls (#2939)

* normalizing tool-call transcripts before invocation

* test(middleware): cover tool result regrouping edge cases
This commit is contained in:
LawranceLiao
2026-05-15 22:09:04 +08:00
committed by GitHub
parent 45060a9ffc
commit 181d836541
2 changed files with 109 additions and 22 deletions
@@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
return "[Tool call was interrupted and did not return a result.]" return "[Tool call was interrupted and did not return a result.]"
def _build_patched_messages(self, messages: list) -> list | None: def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions. """Return messages with tool results grouped after their tool-call AIMessage.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage), This normalizes model-bound causal order before provider serialization while
a synthetic ToolMessage is inserted immediately after that AIMessage. preserving already-valid transcripts unchanged.
Returns None if no patches are needed.
""" """
# Collect IDs of all existing ToolMessages tool_messages_by_id: dict[str, ToolMessage] = {}
existing_tool_msg_ids: set[str] = set()
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage): if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id) tool_messages_by_id.setdefault(msg.tool_call_id, msg)
# Check if any patching is needed tool_call_ids: set[str] = set()
needs_patch = False
for msg in messages: for msg in messages:
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
for tc in self._message_tool_calls(msg): for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids: if tc_id:
needs_patch = True tool_call_ids.add(tc_id)
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = [] patched: list = []
patched_ids: set[str] = set() consumed_tool_msg_ids: set[str] = set()
patch_count = 0 patch_count = 0
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
continue
patched.append(msg) patched.append(msg)
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
for tc in self._message_tool_calls(msg): for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: if not tc_id or tc_id in consumed_tool_msg_ids:
continue
existing_tool_msg = tool_messages_by_id.get(tc_id)
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else:
patched.append( patched.append(
ToolMessage( ToolMessage(
content=self._synthetic_tool_message_content(tc), content=self._synthetic_tool_message_content(tc),
@@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error", status="error",
) )
) )
patched_ids.add(tc_id) consumed_tool_msg_ids.add(tc_id)
patch_count += 1 patch_count += 1
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") if patched == messages:
return None
if patch_count:
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched return patched
@override @override
@@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching:
assert patched[1].name == "bash" assert patched[1].name == "bash"
assert patched[1].status == "error" assert patched[1].status == "error"
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
middleware = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = middleware._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
HumanMessage(content="interruption"),
_tool_msg("call_2", "read"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert isinstance(patched[2], ToolMessage)
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
assert isinstance(patched[3], HumanMessage)
def test_valid_adjacent_tool_results_are_unchanged(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
_tool_msg("call_1", "bash"),
HumanMessage(content="next"),
]
assert mw._build_patched_messages(msgs) is None
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_ai_with_tool_calls([_tc("read", "call_2")]),
_tool_msg("call_1", "bash"),
_tool_msg("call_2", "read"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
assert isinstance(patched[3], AIMessage)
assert isinstance(patched[4], ToolMessage)
assert patched[4].tool_call_id == "call_2"
def test_orphan_tool_message_is_preserved_during_grouping(self):
mw = DanglingToolCallMiddleware()
orphan = _tool_msg("orphan_call", "orphan")
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
orphan,
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert orphan in patched
assert patched.count(orphan) == 1
def test_invalid_tool_call_is_patched(self): def test_invalid_tool_call_is_patched(self):
mw = DanglingToolCallMiddleware() mw = DanglingToolCallMiddleware()
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])] msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]