fix(middleware): sync raw tool call metadata (#2757)
This commit is contained in:
@@ -7,6 +7,7 @@ from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||
|
||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
|
||||
@@ -14,6 +14,8 @@ from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -78,10 +80,7 @@ def _clone_ai_message(
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
return message.model_copy(update=update)
|
||||
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Helpers for keeping AIMessage tool-call metadata consistent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
return None
|
||||
|
||||
raw_id = raw_tool_call.get("id")
|
||||
return raw_id if isinstance(raw_id, str) and raw_id else None
|
||||
|
||||
|
||||
def clone_ai_message_with_tool_calls(
|
||||
message: AIMessage,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
*,
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
|
||||
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
|
||||
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
|
||||
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
|
||||
raw_tool_calls = additional_kwargs.get("tool_calls")
|
||||
if isinstance(raw_tool_calls, list):
|
||||
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
|
||||
if synced_raw_tool_calls:
|
||||
additional_kwargs["tool_calls"] = synced_raw_tool_calls
|
||||
else:
|
||||
additional_kwargs.pop("tool_calls", None)
|
||||
|
||||
if not tool_calls:
|
||||
additional_kwargs.pop("function_call", None)
|
||||
|
||||
update["additional_kwargs"] = additional_kwargs
|
||||
|
||||
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
|
||||
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
|
||||
response_metadata["finish_reason"] = "stop"
|
||||
update["response_metadata"] = response_metadata
|
||||
|
||||
return message.model_copy(update=update)
|
||||
Reference in New Issue
Block a user