mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
feat(loop-detection): defer warning injection (#2752)
* fix(loop-detection): defer warn injection to wrap_model_call The warn branch in LoopDetectionMiddleware injected a HumanMessage into state from after_model. The tools node had not yet produced ToolMessage responses to the previous AIMessage(tool_calls=...), so the new HumanMessage landed *between* the assistant's tool_calls and their responses. OpenAI/Moonshot reject the next request with "tool_call_ids did not have response messages" because their validators require tool_calls to be followed immediately by tool messages. Detection now runs in after_model as before, but only enqueues the warning into a per-thread list. Injection happens in wrap_model_call, where every prior ToolMessage is already present in request.messages. The warning is appended at the end as HumanMessage(name="loop_warning") — pairing intact, AIMessage semantics untouched, no SystemMessage issues for Anthropic. Closes #2029, addresses #2255 #2293 #2304 #2511. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(channels): remove loop warning display filter * feat(loop-detection): scope pending warnings by run * docs(loop-detection): update docs * test(loop-detection): assert deferred warnings are queued * fix(loop-detection): cap transient warning state * docs: update docs * add async awrap_model_call test coverage * docs(loop-detection): document transient warnings --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+201
-28
@@ -6,10 +6,36 @@ arguments indefinitely until the recursion limit kills the run.
|
||||
Detection strategy:
|
||||
1. After each model response, hash the tool calls (name + args).
|
||||
2. Track recent hashes in a sliding window.
|
||||
3. If the same hash appears >= warn_threshold times, inject a
|
||||
"you are repeating yourself — wrap up" system message (once per hash).
|
||||
3. If the same hash appears >= warn_threshold times, queue a
|
||||
"you are repeating yourself — wrap up" warning for the current
|
||||
thread/run. The warning is **injected at the next model call** (in
|
||||
``wrap_model_call``) as a ``HumanMessage`` appended to the message
|
||||
list, *after* all ToolMessage responses to the previous
|
||||
AIMessage(tool_calls).
|
||||
4. If it appears >= hard_limit times, strip all tool_calls from the
|
||||
response so the agent is forced to produce a final text answer.
|
||||
|
||||
Why the warning is injected at ``wrap_model_call`` instead of
|
||||
``after_model``:
|
||||
|
||||
``after_model`` fires immediately after the model emits an
|
||||
``AIMessage`` that may carry ``tool_calls``. The tools node has not
|
||||
run yet, so no matching ``ToolMessage`` exists in the history. Any
|
||||
message we add here lands *between* the assistant's tool_calls and
|
||||
their responses. OpenAI/Moonshot reject the next request with
|
||||
``"tool_call_ids did not have response messages"`` because their
|
||||
validators require the assistant's tool_calls to be followed
|
||||
immediately by tool messages. Anthropic also disallows mid-stream
|
||||
``SystemMessage``. By deferring the warning to ``wrap_model_call``,
|
||||
every prior ToolMessage is already present in the request's message
|
||||
list and the warning is appended at the end — pairing intact, no
|
||||
``AIMessage`` semantics are mutated.
|
||||
|
||||
Queued warnings are intentionally transient. If a run ends before the
|
||||
next model request drains a queued warning, ``after_agent`` drops it
|
||||
instead of carrying it into a later invocation for the same thread. The
|
||||
hard-stop path still forces termination when the configured safety limit
|
||||
is reached.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -19,11 +45,14 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -38,6 +67,7 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
|
||||
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
|
||||
_MAX_PENDING_WARNINGS_PER_RUN = 4
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
@@ -195,6 +225,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread/run queue of warnings to inject at the next model call.
|
||||
# Populated by ``after_model`` (detection) and drained by
|
||||
# ``wrap_model_call`` (injection); see module docstring.
|
||||
self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list)
|
||||
self._pending_warning_touch_order: OrderedDict[tuple[str, str], None] = OrderedDict()
|
||||
self._max_pending_warning_keys = max(1, self.max_tracked_threads * 2)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
|
||||
@@ -213,9 +249,20 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return str(thread_id)
|
||||
return "default"
|
||||
|
||||
def _get_run_id(self, runtime: Runtime) -> str:
|
||||
"""Extract run_id from runtime context for per-run warning scoping."""
|
||||
run_id = runtime.context.get("run_id") if runtime.context else None
|
||||
if run_id:
|
||||
return str(run_id)
|
||||
return "default"
|
||||
|
||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
||||
"""Return the pending-warning key for the current thread/run."""
|
||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
|
||||
@@ -226,8 +273,52 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._warned.pop(evicted_id, None)
|
||||
self._tool_freq.pop(evicted_id, None)
|
||||
self._tool_freq_warned.pop(evicted_id, None)
|
||||
for key in list(self._pending_warnings):
|
||||
if key[0] == evicted_id:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||
|
||||
def _drop_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
|
||||
"""Drop all pending-warning bookkeeping for one thread/run key.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
self._pending_warnings.pop(key, None)
|
||||
self._pending_warning_touch_order.pop(key, None)
|
||||
|
||||
def _touch_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
|
||||
"""Mark a pending-warning key as recently used.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
self._pending_warning_touch_order[key] = None
|
||||
self._pending_warning_touch_order.move_to_end(key)
|
||||
|
||||
def _prune_pending_warning_state_locked(self, protected_key: tuple[str, str]) -> None:
|
||||
"""Cap pending-warning state across abnormal or concurrent runs.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
overflow = len(self._pending_warning_touch_order) - self._max_pending_warning_keys
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
candidates = [key for key in self._pending_warning_touch_order if key != protected_key]
|
||||
for key in candidates[:overflow]:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
|
||||
def _queue_pending_warning(self, runtime: Runtime, warning: str) -> None:
|
||||
"""Queue one transient warning for the current thread/run with caps."""
|
||||
pending_key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
warnings = self._pending_warnings[pending_key]
|
||||
if warning not in warnings:
|
||||
warnings.append(warning)
|
||||
if len(warnings) > _MAX_PENDING_WARNINGS_PER_RUN:
|
||||
del warnings[: len(warnings) - _MAX_PENDING_WARNINGS_PER_RUN]
|
||||
self._touch_pending_warning_key_locked(pending_key)
|
||||
self._prune_pending_warning_state_locked(protected_key=pending_key)
|
||||
|
||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||
"""Track tool calls and check for loops.
|
||||
|
||||
@@ -268,6 +359,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
if len(history) > self.window_size:
|
||||
history[:] = history[-self.window_size :]
|
||||
|
||||
warned_hashes = self._warned.get(thread_id)
|
||||
if warned_hashes is not None:
|
||||
warned_hashes.intersection_update(history)
|
||||
if not warned_hashes:
|
||||
self._warned.pop(thread_id, None)
|
||||
|
||||
count = history.count(call_hash)
|
||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||
|
||||
@@ -381,7 +478,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
if hard_stop:
|
||||
# Strip tool_calls from the last AIMessage to force text output
|
||||
# Strip tool_calls from the last AIMessage to force text output.
|
||||
# Once tool_calls are stripped, the AIMessage no longer requires
|
||||
# matching ToolMessage responses, so mutating it in place here
|
||||
# is safe for OpenAI/Moonshot pairing validators.
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
|
||||
@@ -389,33 +489,48 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# WORKAROUND for v2.0-m1 — see #2724.
|
||||
#
|
||||
# Append the warning to the AIMessage content instead of
|
||||
# injecting a separate HumanMessage. Inserting any non-tool
|
||||
# message between an AIMessage(tool_calls=...) and its
|
||||
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
|
||||
# validation ("tool_call_ids did not have response messages")
|
||||
# because the tools node has not run yet at after_model time.
|
||||
# tool_calls are preserved so the tools node still executes.
|
||||
#
|
||||
# This is a temporary mitigation: mutating an existing
|
||||
# AIMessage to carry framework-authored text leaks loop-warning
|
||||
# text into downstream consumers (MemoryMiddleware fact
|
||||
# extraction, TitleMiddleware, telemetry, model replay) as if
|
||||
# the model said it. The proper fix is to defer warning
|
||||
# injection from after_model to wrap_model_call so every prior
|
||||
# ToolMessage is already in the request — see RFC #2517 (which
|
||||
# lists "loop intervention does not leave invalid
|
||||
# tool-call/tool-message state" as acceptance criteria) and
|
||||
# the prototype on `fix/loop-detection-tool-call-pairing`.
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
|
||||
return {"messages": [patched_msg]}
|
||||
# Defer injection to the next model call. We must NOT alter the
|
||||
# AIMessage(tool_calls=...) here (would put framework words in
|
||||
# the model's mouth, polluting downstream consumers like
|
||||
# MemoryMiddleware), nor insert a separate non-tool message
|
||||
# (would break OpenAI/Moonshot tool-call pairing because the
|
||||
# tools node has not produced ToolMessage responses yet). The
|
||||
# warning is delivered via ``wrap_model_call`` below.
|
||||
self._queue_pending_warning(runtime, warning)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None:
|
||||
"""Drop stale pending warnings for previous runs in this thread."""
|
||||
thread_id, current_run_id = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
for key in list(self._pending_warnings):
|
||||
if key[0] == thread_id and key[1] != current_run_id:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
|
||||
def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None:
|
||||
"""Drop pending warnings owned by the current thread/run."""
|
||||
pending_key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._drop_pending_warning_key_locked(pending_key)
|
||||
|
||||
@staticmethod
|
||||
def _format_warning_message(warnings: list[str]) -> str:
|
||||
"""Merge pending warnings into one prompt message."""
|
||||
deduped = list(dict.fromkeys(warnings))
|
||||
return "\n\n".join(deduped)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
self._clear_other_run_pending_warnings(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
self._clear_other_run_pending_warnings(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
@@ -424,6 +539,59 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
self._clear_current_run_pending_warnings(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
self._clear_current_run_pending_warnings(runtime)
|
||||
return None
|
||||
|
||||
def _drain_pending_warnings(self, runtime: Runtime) -> list[str]:
|
||||
"""Pop and return all queued warnings for *runtime*'s thread/run."""
|
||||
pending_key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
warnings = self._pending_warnings.pop(pending_key, [])
|
||||
self._pending_warning_touch_order.pop(pending_key, None)
|
||||
return warnings
|
||||
|
||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
||||
"""Append queued loop warnings (if any) to the outgoing message list.
|
||||
|
||||
The warning is placed *after* every existing message, including the
|
||||
ToolMessage responses to the previous AIMessage(tool_calls). This
|
||||
keeps ``assistant tool_calls -> tool_messages`` pairing intact for
|
||||
OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage
|
||||
restriction (we use HumanMessage), and never mutates an existing
|
||||
AIMessage.
|
||||
"""
|
||||
warnings = self._drain_pending_warnings(request.runtime)
|
||||
if not warnings:
|
||||
return request
|
||||
new_messages = [
|
||||
*request.messages,
|
||||
HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"),
|
||||
]
|
||||
return request.override(messages=new_messages)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(self._augment_request(request))
|
||||
|
||||
def reset(self, thread_id: str | None = None) -> None:
|
||||
"""Clear tracking state. If thread_id given, clear only that thread."""
|
||||
with self._lock:
|
||||
@@ -432,8 +600,13 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._warned.pop(thread_id, None)
|
||||
self._tool_freq.pop(thread_id, None)
|
||||
self._tool_freq_warned.pop(thread_id, None)
|
||||
for key in list(self._pending_warnings):
|
||||
if key[0] == thread_id:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
else:
|
||||
self._history.clear()
|
||||
self._warned.clear()
|
||||
self._tool_freq.clear()
|
||||
self._tool_freq_warned.clear()
|
||||
self._pending_warnings.clear()
|
||||
self._pending_warning_touch_order.clear()
|
||||
|
||||
Reference in New Issue
Block a user