fix(middleware): sync raw tool call metadata (#2757)

This commit is contained in:
Eilen Shin
2026-05-08 10:08:53 +08:00
committed by GitHub
parent 530bda7107
commit 5fd0e6ac89
5 changed files with 165 additions and 5 deletions
@@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"):
return {"name": name, "id": call_id, "args": {}}
def _raw_tool_call(call_id: str, name: str = "task") -> dict:
return {
"id": call_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
class TestClampSubagentLimit:
def test_below_min_clamped_to_min(self):
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
@@ -117,6 +125,23 @@ class TestTruncateTaskCalls:
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
assert len(task_calls) == 2
def test_truncation_syncs_raw_provider_tool_calls(self):
    """Truncating excess task calls must also trim the raw provider calls.

    Four ``task`` calls go through a middleware limited to two concurrent
    subagents. The first two ids are expected to survive both in the parsed
    ``tool_calls`` and in ``additional_kwargs["tool_calls"]``, while
    ``finish_reason`` stays ``"tool_calls"`` because calls remain.
    """
    mw = SubagentLimitMiddleware(max_concurrent=2)
    msg = AIMessage(
        content="",
        tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")],
        # Raw provider payload mirrors the parsed tool calls one-to-one.
        additional_kwargs={"tool_calls": [_raw_tool_call("t1"), _raw_tool_call("t2"), _raw_tool_call("t3"), _raw_tool_call("t4")]},
        response_metadata={"finish_reason": "tool_calls"},
    )

    result = mw._truncate_task_calls({"messages": [msg]})

    assert result is not None
    updated_msg = result["messages"][0]
    assert [tc["id"] for tc in updated_msg.tool_calls] == ["t1", "t2"]
    # The raw payload must be truncated in lockstep with the parsed calls.
    assert [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]] == ["t1", "t2"]
    assert updated_msg.response_metadata["finish_reason"] == "tool_calls"
def test_only_non_task_calls_returns_none(self):
mw = SubagentLimitMiddleware()
msg = AIMessage(
@@ -75,6 +75,14 @@ def _skill_conversation() -> list:
]
def _raw_tool_call(tool_id: str, name: str = "read_file") -> dict:
return {
"id": tool_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
def test_before_summarization_hook_receives_messages_before_compression() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
@@ -413,6 +421,47 @@ def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls(
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
def test_skill_rescue_syncs_raw_provider_tool_calls_on_split_ai_messages() -> None:
    """Splitting a mixed AI message must split its raw provider calls too.

    One AI message carries a skill read (``skill-1``) and a plain file read
    (``file-1``). The preserved message is expected to keep only ``skill-1``
    and the summarized one only ``file-1``, in both the parsed ``tool_calls``
    and ``additional_kwargs["tool_calls"]``.
    """
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    messages = [
        HumanMessage(content="u1"),
        AIMessage(
            content="reading skill and notes",
            tool_calls=[
                _skill_read_call("skill-1", "alpha"),
                {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
            ],
            # Raw provider payload mirrors the parsed tool calls one-to-one.
            additional_kwargs={"tool_calls": [_raw_tool_call("skill-1"), _raw_tool_call("file-1")]},
        ),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        ToolMessage(content="user notes", tool_call_id="file-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    # Fail with a readable message rather than an IndexError on captured[0].
    assert captured, "before_summarization hook was never invoked"
    preserved = captured[0].preserved_messages
    summarized = captured[0].messages_to_summarize
    preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
    summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)

    assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
    assert [tc["id"] for tc in preserved_ai.additional_kwargs["tool_calls"]] == ["skill-1"]
    assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
    assert [tc["id"] for tc in summarized_ai.additional_kwargs["tool_calls"]] == ["file-1"]
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
@@ -451,6 +500,42 @@ def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
assert summarized_ai.content == "reading skill and notes"
def test_skill_rescue_removes_raw_provider_tool_calls_from_content_only_summary_clone() -> None:
    """A content-only summary clone must carry no tool-call metadata.

    The AI message's only tool call is a skill read. The summarized clone is
    expected to keep the text content but have empty ``tool_calls``, no
    ``tool_calls`` or ``function_call`` entries in ``additional_kwargs``, and
    a ``finish_reason`` of ``"stop"`` instead of ``"tool_calls"``.
    """
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    messages = [
        HumanMessage(content="u1"),
        AIMessage(
            content="reading skill",
            tool_calls=[_skill_read_call("skill-1", "alpha")],
            # Both raw keys are present so the test can check each is removed.
            additional_kwargs={"tool_calls": [_raw_tool_call("skill-1")], "function_call": {"name": "read_file"}},
            response_metadata={"finish_reason": "tool_calls"},
        ),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    # Fail with a readable message rather than an IndexError on captured[0].
    assert captured, "before_summarization hook was never invoked"
    summarized = captured[0].messages_to_summarize
    summarized_ai = next(m for m in summarized if isinstance(m, AIMessage))

    assert summarized_ai.content == "reading skill"
    assert summarized_ai.tool_calls == []
    assert "tool_calls" not in summarized_ai.additional_kwargs
    assert "function_call" not in summarized_ai.additional_kwargs
    assert summarized_ai.response_metadata["finish_reason"] == "stop"
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(