mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
Merge branch 'main' into fix-2804
This commit is contained in:
+27
-22
@@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
return "[Tool call was interrupted and did not return a result.]"
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
"""Return messages with tool results grouped after their tool-call AIMessage.
|
||||
|
||||
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||
Returns None if no patches are needed.
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
"""
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
|
||||
# Check if any patching is needed
|
||||
needs_patch = False
|
||||
tool_call_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
break
|
||||
if needs_patch:
|
||||
break
|
||||
if tc_id:
|
||||
tool_call_ids.add(tc_id)
|
||||
|
||||
if not needs_patch:
|
||||
return None
|
||||
|
||||
# Build new list with patches inserted right after each dangling AIMessage
|
||||
patched: list = []
|
||||
patched_ids: set[str] = set()
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
continue
|
||||
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content=self._synthetic_tool_message_content(tc),
|
||||
@@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
patched_ids.add(tc_id)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
if patched == messages:
|
||||
return None
|
||||
|
||||
if patch_count:
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return patched
|
||||
|
||||
@override
|
||||
|
||||
@@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list.
|
||||
|
||||
Additionally, this middleware prevents the agent from exiting the loop while
|
||||
there are still incomplete todo items. When the model produces a final response
|
||||
(no tool calls) but todos are not yet complete, the middleware injects a reminder
|
||||
and jumps back to the model node to force continued engagement.
|
||||
(no tool calls) but todos are not yet complete, the middleware queues a reminder
|
||||
for the next model request and jumps back to the model node to force continued
|
||||
engagement. The completion reminder is injected via ``wrap_model_call`` instead
|
||||
of being persisted into graph state as a normal user-visible message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain.agents.middleware.types import hook_config
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
@@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str:
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_completion_reminder(todos: list[Todo]) -> str:
|
||||
"""Format a completion reminder for incomplete todo items."""
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
return (
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
)
|
||||
|
||||
|
||||
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
|
||||
|
||||
|
||||
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
|
||||
"""Return True when an AIMessage is not a clean final answer.
|
||||
|
||||
Todo completion reminders should only fire when the model has produced a
|
||||
plain final response. Provider/tool parsing details have moved across
|
||||
LangChain versions and integrations, so keep all tool-intent/error signals
|
||||
behind this helper instead of checking one concrete field at the call site.
|
||||
"""
|
||||
if message.tool_calls:
|
||||
return True
|
||||
|
||||
if getattr(message, "invalid_tool_calls", None):
|
||||
return True
|
||||
|
||||
# Backward/provider compatibility: some integrations preserve raw or legacy
|
||||
# tool-call intent in additional_kwargs even when structured tool_calls is
|
||||
# empty. If this helper changes, update the matching sentinel test
|
||||
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
|
||||
# if that test fails after a LangChain upgrade, review this helper so new
|
||||
# tool-call/error fields are not silently treated as clean final answers.
|
||||
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
|
||||
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
|
||||
return True
|
||||
|
||||
response_metadata = getattr(message, "response_metadata", {}) or {}
|
||||
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
|
||||
|
||||
|
||||
class TodoMiddleware(TodoListMiddleware):
|
||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||
|
||||
@@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
formatted = _format_todos(todos)
|
||||
reminder = HumanMessage(
|
||||
name="todo_reminder",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"Your todo list from earlier is no longer visible in the current context window, "
|
||||
@@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
# Maximum number of completion reminders before allowing the agent to exit.
|
||||
# This prevents infinite loops when the agent cannot make further progress.
|
||||
_MAX_COMPLETION_REMINDERS = 2
|
||||
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
|
||||
_MAX_COMPLETION_REMINDER_KEYS = 4096
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._lock = threading.Lock()
|
||||
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
|
||||
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
|
||||
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
|
||||
self._completion_reminder_next_order = 0
|
||||
|
||||
@staticmethod
|
||||
def _get_thread_id(runtime: Runtime) -> str:
|
||||
context = getattr(runtime, "context", None)
|
||||
thread_id = context.get("thread_id") if context else None
|
||||
return str(thread_id) if thread_id else "default"
|
||||
|
||||
@staticmethod
|
||||
def _get_run_id(runtime: Runtime) -> str:
|
||||
context = getattr(runtime, "context", None)
|
||||
run_id = context.get("run_id") if context else None
|
||||
return str(run_id) if run_id else "default"
|
||||
|
||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
||||
|
||||
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||
self._completion_reminder_next_order += 1
|
||||
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
|
||||
|
||||
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
|
||||
keys = set(self._pending_completion_reminders)
|
||||
keys.update(self._completion_reminder_counts)
|
||||
keys.update(self._completion_reminder_touch_order)
|
||||
return keys
|
||||
|
||||
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||
self._pending_completion_reminders.pop(key, None)
|
||||
self._completion_reminder_counts.pop(key, None)
|
||||
self._completion_reminder_touch_order.pop(key, None)
|
||||
|
||||
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
|
||||
keys = self._completion_reminder_keys_locked()
|
||||
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
candidates = [key for key in keys if key != protected_key]
|
||||
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
|
||||
for key in candidates[:overflow]:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._pending_completion_reminders.setdefault(key, []).append(reminder)
|
||||
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
|
||||
self._touch_completion_reminder_key_locked(key)
|
||||
self._prune_completion_reminder_state_locked(protected_key=key)
|
||||
|
||||
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
return self._completion_reminder_counts.get(key, 0)
|
||||
|
||||
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
reminders = self._pending_completion_reminders.pop(key, [])
|
||||
if reminders or key in self._completion_reminder_counts:
|
||||
self._touch_completion_reminder_key_locked(key)
|
||||
return reminders
|
||||
|
||||
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||
thread_id, current_run_id = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
for key in self._completion_reminder_keys_locked():
|
||||
if key[0] == thread_id and key[1] != current_run_id:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@override
|
||||
@@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
if base_result is not None:
|
||||
return base_result
|
||||
|
||||
# 2. Only intervene when the agent wants to exit (no tool calls).
|
||||
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
|
||||
# intent or tool-call parse errors should be handled by the tool path
|
||||
# instead of being masked by todo reminders.
|
||||
messages = state.get("messages") or []
|
||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
||||
if not last_ai or last_ai.tool_calls:
|
||||
if not last_ai or _has_tool_call_intent_or_error(last_ai):
|
||||
return None
|
||||
|
||||
# 3. Allow exit when all todos are completed or there are no todos.
|
||||
@@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
return None
|
||||
|
||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
||||
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
|
||||
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
|
||||
return None
|
||||
|
||||
# 5. Inject a reminder and force the agent back to the model.
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
reminder = HumanMessage(
|
||||
name="todo_completion_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"jump_to": "model", "messages": [reminder]}
|
||||
# 5. Queue a reminder for the next model request and jump back. We must
|
||||
# not persist this control prompt as a normal HumanMessage, otherwise it
|
||||
# can leak into user-visible message streams and saved transcripts.
|
||||
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
|
||||
return {"jump_to": "model"}
|
||||
|
||||
@override
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of after_model."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@staticmethod
|
||||
def _format_pending_completion_reminders(reminders: list[str]) -> str:
|
||||
return "\n\n".join(dict.fromkeys(reminders))
|
||||
|
||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
||||
reminders = self._drain_completion_reminders(request.runtime)
|
||||
if not reminders:
|
||||
return request
|
||||
new_messages = [
|
||||
*request.messages,
|
||||
HumanMessage(
|
||||
content=self._format_pending_completion_reminders(reminders),
|
||||
name="todo_completion_reminder",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
),
|
||||
]
|
||||
return request.override(messages=new_messages)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, override
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
|
||||
"""Return True if the AIMessage contains a tool_call with the given id."""
|
||||
for tc in message.tool_calls or []:
|
||||
if isinstance(tc, dict):
|
||||
if tc.get("id") == tool_call_id:
|
||||
return True
|
||||
elif hasattr(tc, "id") and tc.id == tool_call_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Annotate subagent token usage onto the AIMessage that dispatched it.
|
||||
# When a task tool completes, its usage is cached by tool_call_id. Detect
|
||||
# the ToolMessage → search backward for the corresponding AIMessage → merge.
|
||||
# Walk backward through consecutive ToolMessages before the new AIMessage
|
||||
# so that multiple concurrent task tool calls all get their subagent tokens
|
||||
# written back to the same dispatch message (merging into one update).
|
||||
state_updates: dict[int, AIMessage] = {}
|
||||
if len(messages) >= 2:
|
||||
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
|
||||
|
||||
idx = len(messages) - 2
|
||||
while idx >= 0:
|
||||
tool_msg = messages[idx]
|
||||
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
|
||||
break
|
||||
|
||||
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
|
||||
if subagent_usage:
|
||||
# Search backward from the ToolMessage to find the AIMessage
|
||||
# that dispatched it. A single model response can dispatch
|
||||
# multiple task tool calls, so we can't assume a fixed offset.
|
||||
dispatch_idx = idx - 1
|
||||
while dispatch_idx >= 0:
|
||||
candidate = messages[dispatch_idx]
|
||||
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
|
||||
# Accumulate into an existing update for the same
|
||||
# AIMessage (multiple task calls in one response),
|
||||
# or merge fresh from the original message.
|
||||
existing_update = state_updates.get(dispatch_idx)
|
||||
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
|
||||
merged = {
|
||||
**prev,
|
||||
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
|
||||
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
|
||||
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
|
||||
}
|
||||
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
|
||||
break
|
||||
dispatch_idx -= 1
|
||||
idx -= 1
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
if state_updates:
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
state_updates[len(messages) - 1] = updated_msg
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
|
||||
Reference in New Issue
Block a user