mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
dcc6f1e678
* fix(loop-detection): defer warn injection to wrap_model_call The warn branch in LoopDetectionMiddleware injected a HumanMessage into state from after_model. The tools node had not yet produced ToolMessage responses to the previous AIMessage(tool_calls=...), so the new HumanMessage landed *between* the assistant's tool_calls and their responses. OpenAI/Moonshot reject the next request with "tool_call_ids did not have response messages" because their validators require tool_calls to be followed immediately by tool messages. Detection now runs in after_model as before, but only enqueues the warning into a per-thread list. Injection happens in wrap_model_call, where every prior ToolMessage is already present in request.messages. The warning is appended at the end as HumanMessage(name="loop_warning") — pairing intact, AIMessage semantics untouched, no SystemMessage issues for Anthropic. Closes #2029, addresses #2255 #2293 #2304 #2511. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(channels): remove loop warning display filter * feat(loop-detection): scope pending warnings by run * docs(loop-detection): update docs * test(loop-detection): assert deferred warnings are queued * fix(loop-detection): cap transient warning state * docs: update docs * add async awrap_model_call test coverage * docs(loop-detection): document transient warnings --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
613 lines
26 KiB
Python
613 lines
26 KiB
Python
"""Middleware to detect and break repetitive tool call loops.
|
|
|
|
P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.

Detection strategy:
  1. After each model response, hash the tool calls (name + args).
  2. Track recent hashes in a sliding window.
  3. If the same hash appears >= warn_threshold times, queue a
     "you are repeating yourself — wrap up" warning for the current
     thread/run. The warning is **injected at the next model call** (in
     ``wrap_model_call``) as a ``HumanMessage`` appended to the message
     list, *after* all ToolMessage responses to the previous
     AIMessage(tool_calls).
  4. If it appears >= hard_limit times, strip all tool_calls from the
     response so the agent is forced to produce a final text answer.

Why the warning is injected at ``wrap_model_call`` instead of
``after_model``:

  ``after_model`` fires immediately after the model emits an
  ``AIMessage`` that may carry ``tool_calls``. The tools node has not
  run yet, so no matching ``ToolMessage`` exists in the history. Any
  message we add here lands *between* the assistant's tool_calls and
  their responses. OpenAI/Moonshot reject the next request with
  ``"tool_call_ids did not have response messages"`` because their
  validators require the assistant's tool_calls to be followed
  immediately by tool messages. Anthropic also disallows mid-stream
  ``SystemMessage``. By deferring the warning to ``wrap_model_call``,
  every prior ToolMessage is already present in the request's message
  list and the warning is appended at the end — pairing intact, no
  ``AIMessage`` semantics are mutated.

Queued warnings are intentionally transient. If a run ends before the
next model request drains a queued warning, ``after_agent`` drops it
instead of carrying it into a later invocation for the same thread. The
hard-stop path still forces termination when the configured safety limit
is reached.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import threading
|
|
from collections import OrderedDict, defaultdict
|
|
from collections.abc import Awaitable, Callable
|
|
from copy import deepcopy
|
|
from typing import TYPE_CHECKING, override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
|
from langchain_core.messages import HumanMessage
|
|
from langgraph.runtime import Runtime
|
|
|
|
if TYPE_CHECKING:
|
|
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Defaults — each one can be overridden via the LoopDetectionMiddleware
# constructor.

# Queue a warning once the same tool-call set has been seen this many times.
_DEFAULT_WARN_THRESHOLD = 3
# Force-stop once the same tool-call set has been seen this many times.
_DEFAULT_HARD_LIMIT = 5
# Number of most recent tool-call hashes kept per thread.
_DEFAULT_WINDOW_SIZE = 20
# Maximum number of threads tracked before LRU eviction kicks in.
_DEFAULT_MAX_TRACKED_THREADS = 100
# Queue a warning after this many calls to the same tool type.
_DEFAULT_TOOL_FREQ_WARN = 30
# Force-stop after this many calls to the same tool type.
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50

# Upper bound on distinct warnings queued for a single thread/run key; not
# configurable through the constructor.
_MAX_PENDING_WARNINGS_PER_RUN = 4
|
|
|
|
|
|
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
|
"""Normalize tool call args to a dict plus an optional fallback key.
|
|
|
|
Some providers serialize ``args`` as a JSON string instead of a dict.
|
|
We defensively parse those cases so loop detection does not crash while
|
|
still preserving a stable fallback key for non-dict payloads.
|
|
"""
|
|
if isinstance(raw_args, dict):
|
|
return raw_args, None
|
|
|
|
if isinstance(raw_args, str):
|
|
try:
|
|
parsed = json.loads(raw_args)
|
|
except (TypeError, ValueError, json.JSONDecodeError):
|
|
return {}, raw_args
|
|
|
|
if isinstance(parsed, dict):
|
|
return parsed, None
|
|
return {}, json.dumps(parsed, sort_keys=True, default=str)
|
|
|
|
if raw_args is None:
|
|
return {}, None
|
|
|
|
return {}, json.dumps(raw_args, sort_keys=True, default=str)
|
|
|
|
|
|
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
|
|
"""Derive a stable key from salient args without overfitting to noise."""
|
|
if name == "read_file" and fallback_key is None:
|
|
path = args.get("path") or ""
|
|
start_line = args.get("start_line")
|
|
end_line = args.get("end_line")
|
|
|
|
bucket_size = 200
|
|
try:
|
|
start_line = int(start_line) if start_line is not None else 1
|
|
except (TypeError, ValueError):
|
|
start_line = 1
|
|
try:
|
|
end_line = int(end_line) if end_line is not None else start_line
|
|
except (TypeError, ValueError):
|
|
end_line = start_line
|
|
|
|
start_line, end_line = sorted((start_line, end_line))
|
|
bucket_start = max(start_line, 1)
|
|
bucket_end = max(end_line, 1)
|
|
bucket_start = (bucket_start - 1) // bucket_size
|
|
bucket_end = (bucket_end - 1) // bucket_size
|
|
return f"{path}:{bucket_start}-{bucket_end}"
|
|
|
|
# write_file / str_replace are content-sensitive: same path may be updated
|
|
# with different payloads during iteration. Using only salient fields (path)
|
|
# can collapse distinct calls, so we hash full args to reduce false positives.
|
|
if name in {"write_file", "str_replace"}:
|
|
if fallback_key is not None:
|
|
return fallback_key
|
|
return json.dumps(args, sort_keys=True, default=str)
|
|
|
|
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
|
|
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
|
|
if stable_args:
|
|
return json.dumps(stable_args, sort_keys=True, default=str)
|
|
|
|
if fallback_key is not None:
|
|
return fallback_key
|
|
|
|
return json.dumps(args, sort_keys=True, default=str)
|
|
|
|
|
|
def _hash_tool_calls(tool_calls: list[dict]) -> str:
    """Deterministic hash of a set of tool calls (name + stable key).

    This is intended to be order-independent: the same multiset of tool calls
    should always produce the same hash, regardless of their input order.

    Args:
        tool_calls: Tool call dicts; ``name`` and ``args`` are read with
            ``.get`` so missing keys do not raise.

    Returns:
        The first 12 hex characters of an MD5 digest over the sorted,
        normalized ``"<name>:<stable key>"`` entries.
    """
    # Normalize each tool call to a stable "<name>:<key>" string.
    normalized: list[str] = []
    for tc in tool_calls:
        name = tc.get("name", "")
        args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
        key = _stable_tool_key(name, args, fallback_key)
        normalized.append(f"{name}:{key}")

    # Sort so permutations of the same multiset of calls yield the same ordering.
    normalized.sort()
    blob = json.dumps(normalized, sort_keys=True, default=str)
    # MD5 is only a cheap fingerprint here, not a security primitive. Declaring
    # that keeps the call usable on builds that restrict MD5 (e.g. FIPS mode);
    # the resulting digest is unchanged.
    return hashlib.md5(blob.encode(), usedforsecurity=False).hexdigest()[:12]
|
|
|
|
|
|
# Text queued when the same tool-call set keeps recurring (hash-based layer).
_WARNING_MSG = (
    "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)

# Template queued when one tool type is called too often (frequency layer).
# Placeholders: {tool_name}, {count}.
_TOOL_FREQ_WARNING_MSG = (
    "[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)

# Text appended to the final AIMessage when the hash-based hard limit trips.
_HARD_STOP_MSG = (
    "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
)

# Template appended when the per-tool frequency hard limit trips.
# Placeholders: {tool_name}, {count}.
_TOOL_FREQ_HARD_STOP_MSG = (
    "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
)
|
|
|
|
|
|
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
    construct via :meth:`from_config` to ensure values pass Pydantic validation.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5.
        window_size: Size of the sliding window for tracking calls.
            Default: 20.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Values below 1 are treated
            as 1 during eviction so the thread currently being processed is
            never evicted. Default: 100.
        tool_freq_warn: Number of calls to the same tool *type* (regardless
            of arguments) before injecting a frequency warning. Catches
            cross-file read loops that hash-based detection misses.
            Default: 30.
        tool_freq_hard_limit: Number of calls to the same tool type before
            forcing a stop. Default: 50.
        tool_freq_overrides: Per-tool overrides for frequency thresholds,
            keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
            that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
            that specific tool. Tools not listed here fall back to the global
            thresholds. Useful for raising limits on intentionally
            high-frequency tools (e.g. ``bash`` in batch pipelines) without
            weakening protection on all other tools. Default: ``None``
            (no overrides).
    """

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
        tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
        tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
        tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        self.tool_freq_warn = tool_freq_warn
        self.tool_freq_hard_limit = tool_freq_hard_limit
        self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
        # Guards every piece of mutable tracking state below.
        self._lock = threading.Lock()
        # thread_id -> recent call hashes; insertion order doubles as LRU order.
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        # thread_id -> call hashes that already produced a warning.
        self._warned: dict[str, set[str]] = defaultdict(set)
        # thread_id -> tool name -> number of calls seen.
        self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
        # thread_id -> tool names that already produced a frequency warning.
        self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
        # Per-thread/run queue of warnings to inject at the next model call.
        # Populated by ``after_model`` (detection) and drained by
        # ``wrap_model_call`` (injection); see module docstring.
        self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list)
        # Recency order of pending-warning keys, used to cap the queue map.
        self._pending_warning_touch_order: OrderedDict[tuple[str, str], None] = OrderedDict()
        self._max_pending_warning_keys = max(1, self.max_tracked_threads * 2)

    @classmethod
    def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
        """Construct from a Pydantic-validated config, trusting its validation."""
        return cls(
            warn_threshold=config.warn_threshold,
            hard_limit=config.hard_limit,
            window_size=config.window_size,
            max_tracked_threads=config.max_tracked_threads,
            tool_freq_warn=config.tool_freq_warn,
            tool_freq_hard_limit=config.tool_freq_hard_limit,
            tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
        )

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Extract thread_id from runtime context for per-thread tracking.

        Falls back to ``"default"`` when the context or the id is missing/falsy.
        """
        thread_id = runtime.context.get("thread_id") if runtime.context else None
        if thread_id:
            return str(thread_id)
        return "default"

    def _get_run_id(self, runtime: Runtime) -> str:
        """Extract run_id from runtime context for per-run warning scoping.

        Falls back to ``"default"`` when the context or the id is missing/falsy.
        """
        run_id = runtime.context.get("run_id") if runtime.context else None
        if run_id:
            return str(run_id)
        return "default"

    def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
        """Return the pending-warning key ``(thread_id, run_id)`` for *runtime*."""
        return self._get_thread_id(runtime), self._get_run_id(runtime)

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads if over the limit.

        Must be called while holding self._lock.
        """
        # Clamp to at least one tracked thread. Callers insert the current
        # thread and then index ``self._history[thread_id]``; with a limit
        # below 1 the loop would evict that fresh entry (or pop from an empty
        # mapping for negative values) and raise KeyError.
        limit = max(1, self.max_tracked_threads)
        while len(self._history) > limit:
            evicted_id, _ = self._history.popitem(last=False)
            self._warned.pop(evicted_id, None)
            self._tool_freq.pop(evicted_id, None)
            self._tool_freq_warned.pop(evicted_id, None)
            # Pending warnings are keyed by (thread, run); drop every run of
            # the evicted thread.
            for key in list(self._pending_warnings):
                if key[0] == evicted_id:
                    self._drop_pending_warning_key_locked(key)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _drop_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
        """Drop all pending-warning bookkeeping for one thread/run key.

        Must be called while holding self._lock.
        """
        self._pending_warnings.pop(key, None)
        self._pending_warning_touch_order.pop(key, None)

    def _touch_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
        """Mark a pending-warning key as recently used.

        Must be called while holding self._lock.
        """
        self._pending_warning_touch_order[key] = None
        self._pending_warning_touch_order.move_to_end(key)

    def _prune_pending_warning_state_locked(self, protected_key: tuple[str, str]) -> None:
        """Cap pending-warning state across abnormal or concurrent runs.

        Drops the least recently touched keys, never *protected_key*, until
        the number of keys is back within ``_max_pending_warning_keys``.

        Must be called while holding self._lock.
        """
        overflow = len(self._pending_warning_touch_order) - self._max_pending_warning_keys
        if overflow <= 0:
            return

        candidates = [key for key in self._pending_warning_touch_order if key != protected_key]
        for key in candidates[:overflow]:
            self._drop_pending_warning_key_locked(key)

    def _queue_pending_warning(self, runtime: Runtime, warning: str) -> None:
        """Queue one transient warning for the current thread/run with caps."""
        pending_key = self._pending_key(runtime)
        with self._lock:
            warnings = self._pending_warnings[pending_key]
            if warning not in warnings:
                warnings.append(warning)
            # Keep only the newest entries when the per-run cap is exceeded.
            if len(warnings) > _MAX_PENDING_WARNINGS_PER_RUN:
                del warnings[: len(warnings) - _MAX_PENDING_WARNINGS_PER_RUN]
            self._touch_pending_warning_key_locked(pending_key)
            self._prune_pending_warning_state_locked(protected_key=pending_key)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Track tool calls and check for loops.

        Two detection layers:
          1. **Hash-based**: catches identical tool call sets.
          2. **Frequency-based**: catches the same *tool type* being
             called many times with varying arguments (e.g. ``read_file``
             on 40 different files).

        Only the last message is inspected, and only when it is an AI message
        carrying tool calls. A warning is returned at most once per call hash
        (while that hash stays in the window) and once per tool name.

        Returns:
            (warning_message_or_none, should_hard_stop)
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create entry (move to end for LRU)
            if thread_id in self._history:
                self._history.move_to_end(thread_id)
            else:
                self._history[thread_id] = []
                self._evict_if_needed()

            history = self._history[thread_id]
            history.append(call_hash)
            if len(history) > self.window_size:
                history[:] = history[-self.window_size :]

            # Forget "already warned" marks for hashes that slid out of the
            # window so a recurrence later can warn again.
            warned_hashes = self._warned.get(thread_id)
            if warned_hashes is not None:
                warned_hashes.intersection_update(history)
                if not warned_hashes:
                    self._warned.pop(thread_id, None)

            count = history.count(call_hash)
            tool_names = [tc.get("name", "?") for tc in tool_calls]

            # --- Layer 1: hash-based (identical call sets) ---
            if count >= self.hard_limit:
                logger.error(
                    "Loop hard limit reached — forcing stop",
                    extra={
                        "thread_id": thread_id,
                        "call_hash": call_hash,
                        "count": count,
                        "tools": tool_names,
                    },
                )
                return _HARD_STOP_MSG, True

            if count >= self.warn_threshold:
                warned = self._warned[thread_id]
                if call_hash not in warned:
                    warned.add(call_hash)
                    logger.warning(
                        "Repetitive tool calls detected — injecting warning",
                        extra={
                            "thread_id": thread_id,
                            "call_hash": call_hash,
                            "count": count,
                            "tools": tool_names,
                        },
                    )
                    return _WARNING_MSG, False
                # Already warned for this hash: fall through to layer 2.

            # --- Layer 2: per-tool-type frequency ---
            freq = self._tool_freq[thread_id]
            for tc in tool_calls:
                name = tc.get("name", "")
                if not name:
                    continue
                freq[name] += 1
                tc_count = freq[name]

                if name in self._tool_freq_overrides:
                    eff_warn, eff_hard = self._tool_freq_overrides[name]
                else:
                    eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit

                if tc_count >= eff_hard:
                    logger.error(
                        "Tool frequency hard limit reached — forcing stop",
                        extra={
                            "thread_id": thread_id,
                            "tool_name": name,
                            "count": tc_count,
                        },
                    )
                    return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True

                if tc_count >= eff_warn:
                    warned = self._tool_freq_warned[thread_id]
                    if name not in warned:
                        warned.add(name)
                        logger.warning(
                            "Tool frequency warning — too many calls to same tool type",
                            extra={
                                "thread_id": thread_id,
                                "tool_name": name,
                                "count": tc_count,
                            },
                        )
                        return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False

        return None, False

    @staticmethod
    def _append_text(content: str | list | None, text: str) -> str | list:
        """Append *text* to AIMessage content, handling str, list, and None.

        When content is a list of content blocks (e.g. Anthropic thinking mode),
        we append a new ``{"type": "text", ...}`` block instead of concatenating
        a string to a list, which would raise ``TypeError``.
        """
        if content is None:
            return text
        if isinstance(content, list):
            return [*content, {"type": "text", "text": f"\n\n{text}"}]
        if isinstance(content, str):
            return content + f"\n\n{text}"
        # Fallback: coerce unexpected types to str to avoid TypeError
        return str(content) + f"\n\n{text}"

    @staticmethod
    def _build_hard_stop_update(last_msg, content: str | list) -> dict:
        """Clear tool-call metadata so forced-stop messages serialize as plain assistant text.

        Returns a dict of field updates: empty ``tool_calls``, the new
        ``content``, ``additional_kwargs`` without ``tool_calls`` /
        ``function_call``, and ``response_metadata`` whose ``finish_reason``
        is rewritten from ``"tool_calls"`` to ``"stop"``. *last_msg* itself is
        not mutated.
        """
        update = {
            "tool_calls": [],
            "content": content,
        }

        additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
        for key in ("tool_calls", "function_call"):
            additional_kwargs.pop(key, None)
        update["additional_kwargs"] = additional_kwargs

        response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
        if response_metadata.get("finish_reason") == "tool_calls":
            response_metadata["finish_reason"] = "stop"
        update["response_metadata"] = response_metadata

        return update

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Run detection and either hard-stop, queue a warning, or do nothing.

        Returns:
            A state update replacing the last message on hard stop, otherwise
            ``None``.
        """
        warning, hard_stop = self._track_and_check(state, runtime)

        if hard_stop:
            # Strip tool_calls from the last AIMessage to force text output.
            # Once tool_calls are stripped, the AIMessage no longer requires
            # matching ToolMessage responses, so mutating it in place here
            # is safe for OpenAI/Moonshot pairing validators.
            messages = state.get("messages", [])
            last_msg = messages[-1]
            content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
            stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
            return {"messages": [stripped_msg]}

        if warning:
            # Defer injection to the next model call. We must NOT alter the
            # AIMessage(tool_calls=...) here (would put framework words in
            # the model's mouth, polluting downstream consumers like
            # MemoryMiddleware), nor insert a separate non-tool message
            # (would break OpenAI/Moonshot tool-call pairing because the
            # tools node has not produced ToolMessage responses yet). The
            # warning is delivered via ``wrap_model_call`` below.
            self._queue_pending_warning(runtime, warning)
            return None

        return None

    def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None:
        """Drop stale pending warnings for previous runs in this thread."""
        thread_id, current_run_id = self._pending_key(runtime)
        with self._lock:
            for key in list(self._pending_warnings):
                if key[0] == thread_id and key[1] != current_run_id:
                    self._drop_pending_warning_key_locked(key)

    def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None:
        """Drop pending warnings owned by the current thread/run."""
        pending_key = self._pending_key(runtime)
        with self._lock:
            self._drop_pending_warning_key_locked(pending_key)

    @staticmethod
    def _format_warning_message(warnings: list[str]) -> str:
        """Merge pending warnings into one prompt message.

        Duplicates are removed while preserving first-seen order; entries are
        joined with a blank line.
        """
        deduped = list(dict.fromkeys(warnings))
        return "\n\n".join(deduped)

    @override
    def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Discard warnings left behind by other runs of this thread."""
        self._clear_other_run_pending_warnings(runtime)
        return None

    @override
    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async variant of :meth:`before_agent`."""
        self._clear_other_run_pending_warnings(runtime)
        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Run loop detection on the model's latest response."""
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async variant of :meth:`after_model`."""
        return self._apply(state, runtime)

    @override
    def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Drop any warnings still queued for this thread/run."""
        self._clear_current_run_pending_warnings(runtime)
        return None

    @override
    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async variant of :meth:`after_agent`."""
        self._clear_current_run_pending_warnings(runtime)
        return None

    def _drain_pending_warnings(self, runtime: Runtime) -> list[str]:
        """Pop and return all queued warnings for *runtime*'s thread/run."""
        pending_key = self._pending_key(runtime)
        with self._lock:
            warnings = self._pending_warnings.pop(pending_key, [])
            self._pending_warning_touch_order.pop(pending_key, None)
        return warnings

    def _augment_request(self, request: ModelRequest) -> ModelRequest:
        """Append queued loop warnings (if any) to the outgoing message list.

        The warning is placed *after* every existing message, including the
        ToolMessage responses to the previous AIMessage(tool_calls). This
        keeps ``assistant tool_calls -> tool_messages`` pairing intact for
        OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage
        restriction (we use HumanMessage), and never mutates an existing
        AIMessage.

        Returns:
            *request* unchanged when nothing is queued, otherwise a copy whose
            messages end with one ``HumanMessage(name="loop_warning")``.
        """
        warnings = self._drain_pending_warnings(request.runtime)
        if not warnings:
            return request
        new_messages = [
            *request.messages,
            HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"),
        ]
        return request.override(messages=new_messages)

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Inject queued warnings into *request*, then delegate to *handler*."""
        return handler(self._augment_request(request))

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Async variant of :meth:`wrap_model_call`."""
        return await handler(self._augment_request(request))

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state. If thread_id given, clear only that thread.

        Args:
            thread_id: Thread whose state should be cleared. ``None`` (the
                default) clears the state of every thread. Any string,
                including the empty string, is treated as a specific thread id
                and never triggers a global clear.
        """
        with self._lock:
            # ``is not None`` rather than truthiness: a falsy id such as ""
            # must not silently wipe every other thread's state.
            if thread_id is not None:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
                self._tool_freq.pop(thread_id, None)
                self._tool_freq_warned.pop(thread_id, None)
                for key in list(self._pending_warnings):
                    if key[0] == thread_id:
                        self._drop_pending_warning_key_locked(key)
            else:
                self._history.clear()
                self._warned.clear()
                self._tool_freq.clear()
                self._tool_freq_warned.clear()
                self._pending_warnings.clear()
                self._pending_warning_touch_order.clear()
|