mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(middleware): sync raw tool call metadata (#2757)
This commit is contained in:
@@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"):
|
||||
return {"name": name, "id": call_id, "args": {}}
|
||||
|
||||
|
||||
def _raw_tool_call(call_id: str, name: str = "task") -> dict:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": "{}"},
|
||||
}
|
||||
|
||||
|
||||
class TestClampSubagentLimit:
|
||||
def test_below_min_clamped_to_min(self):
|
||||
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
|
||||
@@ -117,6 +125,23 @@ class TestTruncateTaskCalls:
|
||||
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
|
||||
assert len(task_calls) == 2
|
||||
|
||||
def test_truncation_syncs_raw_provider_tool_calls(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=2)
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")],
|
||||
additional_kwargs={"tool_calls": [_raw_tool_call("t1"), _raw_tool_call("t2"), _raw_tool_call("t3"), _raw_tool_call("t4")]},
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
|
||||
result = mw._truncate_task_calls({"messages": [msg]})
|
||||
|
||||
assert result is not None
|
||||
updated_msg = result["messages"][0]
|
||||
assert [tc["id"] for tc in updated_msg.tool_calls] == ["t1", "t2"]
|
||||
assert [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]] == ["t1", "t2"]
|
||||
assert updated_msg.response_metadata["finish_reason"] == "tool_calls"
|
||||
|
||||
def test_only_non_task_calls_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
msg = AIMessage(
|
||||
|
||||
Reference in New Issue
Block a user