fix(middleware): sync raw tool call metadata (#2757)

This commit is contained in:
Eilen Shin
2026-05-08 10:08:53 +08:00
committed by GitHub
parent 530bda7107
commit 5fd0e6ac89
5 changed files with 165 additions and 5 deletions
@@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"):
return {"name": name, "id": call_id, "args": {}}
def _raw_tool_call(call_id: str, name: str = "task") -> dict:
return {
"id": call_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
class TestClampSubagentLimit:
def test_below_min_clamped_to_min(self):
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
@@ -117,6 +125,23 @@ class TestTruncateTaskCalls:
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
assert len(task_calls) == 2
def test_truncation_syncs_raw_provider_tool_calls(self):
mw = SubagentLimitMiddleware(max_concurrent=2)
msg = AIMessage(
content="",
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")],
additional_kwargs={"tool_calls": [_raw_tool_call("t1"), _raw_tool_call("t2"), _raw_tool_call("t3"), _raw_tool_call("t4")]},
response_metadata={"finish_reason": "tool_calls"},
)
result = mw._truncate_task_calls({"messages": [msg]})
assert result is not None
updated_msg = result["messages"][0]
assert [tc["id"] for tc in updated_msg.tool_calls] == ["t1", "t2"]
assert [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]] == ["t1", "t2"]
assert updated_msg.response_metadata["finish_reason"] == "tool_calls"
def test_only_non_task_calls_returns_none(self):
mw = SubagentLimitMiddleware()
msg = AIMessage(