mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(middleware): sync raw tool call metadata (#2757)
This commit is contained in:
@@ -7,6 +7,7 @@ from langchain.agents import AgentState
|
|||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
|||||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||||
|
|
||||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
|
||||||
return {"messages": [updated_msg]}
|
return {"messages": [updated_msg]}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from langgraph.config import get_config
|
|||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -78,10 +80,7 @@ def _clone_ai_message(
|
|||||||
content: Any | None = None,
|
content: Any | None = None,
|
||||||
) -> AIMessage:
|
) -> AIMessage:
|
||||||
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
||||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
||||||
if content is not None:
|
|
||||||
update["content"] = content
|
|
||||||
return message.model_copy(update=update)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"""Helpers for keeping AIMessage tool-call metadata consistent."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
    """Return the non-empty string ``id`` of a raw provider tool call.

    Args:
        raw_tool_call: One entry of ``additional_kwargs["tool_calls"]``.

    Returns:
        The ``id`` value when the entry is a dict whose ``id`` is a non-empty
        string; ``None`` for anything else (non-dict entry, missing id,
        empty id, or non-string id).
    """
    if not isinstance(raw_tool_call, dict):
        return None

    raw_id = raw_tool_call.get("id")
    return raw_id if isinstance(raw_id, str) and raw_id else None


# ``finish_reason`` values announcing that the model stopped in order to invoke
# a tool ("tool_calls") or a legacy function ("function_call"). A tuple keeps
# the membership test based on ``==`` so unhashable metadata values cannot raise.
_PENDING_CALL_FINISH_REASONS = ("tool_calls", "function_call")


def clone_ai_message_with_tool_calls(
    message: AIMessage,
    tool_calls: list[dict[str, Any]],
    *,
    content: Any | None = None,
) -> AIMessage:
    """Clone an AIMessage while keeping raw provider tool-call metadata in sync.

    The clone is built with ``message.model_copy(update=...)``. Besides the
    structured ``tool_calls`` it rewrites the two metadata dicts so they do not
    contradict the new tool-call list:

    * ``additional_kwargs["tool_calls"]`` keeps only the raw entries whose
      ``id`` matches one of the kept structured tool calls; the key is removed
      when nothing matches.
    * ``additional_kwargs["function_call"]`` is removed when no tool calls
      remain.
    * ``response_metadata["finish_reason"]`` is reset to ``"stop"`` when no
      tool calls remain and it was ``"tool_calls"`` or ``"function_call"``.

    Args:
        message: Source message. Its ``additional_kwargs`` and
            ``response_metadata`` dicts are copied, never mutated.
        tool_calls: Structured tool calls the clone should carry. Only entries
            with a non-empty string ``id`` take part in raw-metadata matching.
        content: Replacement content. ``None`` keeps the original content.

    Returns:
        The value returned by ``message.model_copy`` with the updates applied.
    """
    kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}

    update: dict[str, Any] = {"tool_calls": tool_calls}
    if content is not None:
        update["content"] = content

    # Shallow-copy so the source message's metadata dict is left untouched.
    additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
    raw_tool_calls = additional_kwargs.get("tool_calls")
    if isinstance(raw_tool_calls, list):
        synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
        if synced_raw_tool_calls:
            additional_kwargs["tool_calls"] = synced_raw_tool_calls
        else:
            # Drop the key entirely rather than leaving an empty raw list behind.
            additional_kwargs.pop("tool_calls", None)

    if not tool_calls:
        additional_kwargs.pop("function_call", None)

    update["additional_kwargs"] = additional_kwargs

    response_metadata = dict(getattr(message, "response_metadata", {}) or {})
    # A message without any call left must not keep advertising a pending call.
    # The legacy "function_call" reason is covered too, because the matching
    # ``function_call`` kwarg is stripped just above.
    if not tool_calls and response_metadata.get("finish_reason") in _PENDING_CALL_FINISH_REASONS:
        response_metadata["finish_reason"] = "stop"
    update["response_metadata"] = response_metadata

    return message.model_copy(update=update)
|
||||||
@@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"):
|
|||||||
return {"name": name, "id": call_id, "args": {}}
|
return {"name": name, "id": call_id, "args": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_tool_call(call_id: str, name: str = "task") -> dict:
|
||||||
|
return {
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": name, "arguments": "{}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestClampSubagentLimit:
|
class TestClampSubagentLimit:
|
||||||
def test_below_min_clamped_to_min(self):
|
def test_below_min_clamped_to_min(self):
|
||||||
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
|
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
|
||||||
@@ -117,6 +125,23 @@ class TestTruncateTaskCalls:
|
|||||||
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
|
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
|
||||||
assert len(task_calls) == 2
|
assert len(task_calls) == 2
|
||||||
|
|
||||||
|
def test_truncation_syncs_raw_provider_tool_calls(self):
    """Truncating excess task calls must trim the raw provider tool calls too."""
    all_ids = ["t1", "t2", "t3", "t4"]
    expected_ids = ["t1", "t2"]

    mw = SubagentLimitMiddleware(max_concurrent=2)
    msg = AIMessage(
        content="",
        tool_calls=[_task_call(call_id) for call_id in all_ids],
        additional_kwargs={"tool_calls": [_raw_tool_call(call_id) for call_id in all_ids]},
        response_metadata={"finish_reason": "tool_calls"},
    )

    result = mw._truncate_task_calls({"messages": [msg]})

    assert result is not None
    updated_msg = result["messages"][0]
    structured_ids = [tc["id"] for tc in updated_msg.tool_calls]
    raw_ids = [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]]
    assert structured_ids == expected_ids
    assert raw_ids == expected_ids
    # Tool calls remain on the message, so the finish reason stays as it was.
    assert updated_msg.response_metadata["finish_reason"] == "tool_calls"
|
||||||
|
|
||||||
def test_only_non_task_calls_returns_none(self):
|
def test_only_non_task_calls_returns_none(self):
|
||||||
mw = SubagentLimitMiddleware()
|
mw = SubagentLimitMiddleware()
|
||||||
msg = AIMessage(
|
msg = AIMessage(
|
||||||
|
|||||||
@@ -75,6 +75,14 @@ def _skill_conversation() -> list:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_tool_call(tool_id: str, name: str = "read_file") -> dict:
|
||||||
|
return {
|
||||||
|
"id": tool_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": name, "arguments": "{}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_before_summarization_hook_receives_messages_before_compression() -> None:
|
def test_before_summarization_hook_receives_messages_before_compression() -> None:
|
||||||
captured: list[SummarizationEvent] = []
|
captured: list[SummarizationEvent] = []
|
||||||
middleware = _middleware(before_summarization=[captured.append])
|
middleware = _middleware(before_summarization=[captured.append])
|
||||||
@@ -413,6 +421,47 @@ def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls(
|
|||||||
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
|
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
|
||||||
|
|
||||||
|
|
||||||
|
def test_skill_rescue_syncs_raw_provider_tool_calls_on_split_ai_messages() -> None:
    """Each half of a split AI message keeps only its own raw provider tool calls."""
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    # One AI turn mixing a skill read with an ordinary file read, with raw
    # provider entries mirroring both structured calls.
    mixed_ai = AIMessage(
        content="reading skill and notes",
        tool_calls=[
            _skill_read_call("skill-1", "alpha"),
            {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
        ],
        additional_kwargs={"tool_calls": [_raw_tool_call("skill-1"), _raw_tool_call("file-1")]},
    )
    messages = [
        HumanMessage(content="u1"),
        mixed_ai,
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        ToolMessage(content="user notes", tool_call_id="file-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    event = captured[0]
    preserved_ai = next(m for m in event.preserved_messages if isinstance(m, AIMessage) and m.tool_calls)
    summarized_ai = next(m for m in event.messages_to_summarize if isinstance(m, AIMessage) and m.tool_calls)

    def _ids(calls) -> list:
        return [tc["id"] for tc in calls]

    assert _ids(preserved_ai.tool_calls) == ["skill-1"]
    assert _ids(preserved_ai.additional_kwargs["tool_calls"]) == ["skill-1"]
    assert _ids(summarized_ai.tool_calls) == ["file-1"]
    assert _ids(summarized_ai.additional_kwargs["tool_calls"]) == ["file-1"]
|
||||||
|
|
||||||
|
|
||||||
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
|
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
|
||||||
captured: list[SummarizationEvent] = []
|
captured: list[SummarizationEvent] = []
|
||||||
middleware = _middleware(
|
middleware = _middleware(
|
||||||
@@ -451,6 +500,42 @@ def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
|
|||||||
assert summarized_ai.content == "reading skill and notes"
|
assert summarized_ai.content == "reading skill and notes"
|
||||||
|
|
||||||
|
|
||||||
|
def test_skill_rescue_removes_raw_provider_tool_calls_from_content_only_summary_clone() -> None:
    """A summarized clone left with no tool calls carries no stale call metadata."""
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    # The only tool call on this turn is a skill read; the metadata below is
    # what the test expects to be cleaned from the summarized clone.
    skill_ai = AIMessage(
        content="reading skill",
        tool_calls=[_skill_read_call("skill-1", "alpha")],
        additional_kwargs={"tool_calls": [_raw_tool_call("skill-1")], "function_call": {"name": "read_file"}},
        response_metadata={"finish_reason": "tool_calls"},
    )
    messages = [
        HumanMessage(content="u1"),
        skill_ai,
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    summarized_ai = next(m for m in captured[0].messages_to_summarize if isinstance(m, AIMessage))
    leftover_kwargs = summarized_ai.additional_kwargs

    assert summarized_ai.content == "reading skill"
    assert summarized_ai.tool_calls == []
    assert "tool_calls" not in leftover_kwargs
    assert "function_call" not in leftover_kwargs
    assert summarized_ai.response_metadata["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
|
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
|
||||||
captured: list[SummarizationEvent] = []
|
captured: list[SummarizationEvent] = []
|
||||||
middleware = _middleware(
|
middleware = _middleware(
|
||||||
|
|||||||
Reference in New Issue
Block a user