mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
fix(middleware): repair dangling tool-call history after loop interru… (#2035)
* fix(middleware): repair dangling tool-call history after loop interruption (#2029) * docs(backend): fix middleware chain ordering --------- Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
This commit is contained in:
+41
-2
@@ -13,6 +13,7 @@ at the correct positions (immediately after each dangling AIMessage), not append
|
||||
to the end of the message list as before_model + add_messages reducer would do.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
@@ -33,6 +34,44 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
offending AIMessage so the LLM receives a well-formed conversation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _message_tool_calls(msg) -> list[dict]:
|
||||
"""Return normalized tool calls from structured fields or raw provider payloads."""
|
||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||
if tool_calls:
|
||||
return list(tool_calls)
|
||||
|
||||
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
|
||||
normalized: list[dict] = []
|
||||
for raw_tc in raw_tool_calls:
|
||||
if not isinstance(raw_tc, dict):
|
||||
continue
|
||||
|
||||
function = raw_tc.get("function")
|
||||
name = raw_tc.get("name")
|
||||
if not name and isinstance(function, dict):
|
||||
name = function.get("name")
|
||||
|
||||
args = raw_tc.get("args", {})
|
||||
if not args and isinstance(function, dict):
|
||||
raw_args = function.get("arguments")
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
parsed_args = {}
|
||||
args = parsed_args if isinstance(parsed_args, dict) else {}
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
"id": raw_tc.get("id"),
|
||||
"name": name or "unknown",
|
||||
"args": args if isinstance(args, dict) else {},
|
||||
}
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
|
||||
@@ -51,7 +90,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in getattr(msg, "tool_calls", None) or []:
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
@@ -70,7 +109,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in getattr(msg, "tool_calls", None) or []:
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
patched.append(
|
||||
|
||||
@@ -17,6 +17,7 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
@@ -323,6 +324,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
@staticmethod
|
||||
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
|
||||
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
|
||||
update = {
|
||||
"tool_calls": [],
|
||||
"content": content,
|
||||
}
|
||||
|
||||
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
|
||||
for key in ("tool_calls", "function_call"):
|
||||
additional_kwargs.pop(key, None)
|
||||
update["additional_kwargs"] = additional_kwargs
|
||||
|
||||
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
|
||||
if response_metadata.get("finish_reason") == "tool_calls":
|
||||
response_metadata["finish_reason"] = "stop"
|
||||
update["response_metadata"] = response_metadata
|
||||
|
||||
return update
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
@@ -330,12 +351,8 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
# Strip tool_calls from the last AIMessage to force text output
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": self._append_text(last_msg.content, warning),
|
||||
}
|
||||
)
|
||||
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
|
||||
stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
|
||||
Reference in New Issue
Block a user