mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
5b633449f8
* fix(middleware): add per-tool-type frequency detection to LoopDetectionMiddleware The existing hash-based loop detection only catches identical tool call sets. When the agent calls the same tool type (e.g. read_file) on many different files, each call produces a unique hash and bypasses detection. This causes the agent to exhaust recursion_limit, consuming 150K-225K tokens per failed run. Add a second detection layer that tracks cumulative call counts per tool type per thread. Warns at 30 calls (configurable) and forces stop at 50. The hard stop message now uses the actual returned message instead of a hardcoded constant, so both hash-based and frequency-based stops produce accurate diagnostics. Also fix _apply() to use the warning message returned by _track_and_check() for hard stops, instead of always using _HARD_STOP_MSG. Closes #1987 * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(lint): remove unused imports and fix line length - Remove unused _TOOL_FREQ_HARD_STOP_MSG and _TOOL_FREQ_WARNING_MSG imports from test file (F401) - Break long _TOOL_FREQ_WARNING_MSG string to fit within 240 char limit (E501) * style: apply ruff format * test: add LRU eviction and per-thread reset coverage for frequency state Address review feedback from @WillemJiang: - Verify _tool_freq and _tool_freq_warned are cleaned on LRU eviction - Add test for reset(thread_id=...) clearing only the target thread's frequency state while leaving others intact * fix(makefile): route Windows shell-script targets through Git Bash (#2060) --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Asish Kumar <87874775+officialasishkumar@users.noreply.github.com>
373 lines
15 KiB
Python
373 lines
15 KiB
Python
"""Middleware to detect and break repetitive tool call loops.
|
|
|
|
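P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.

Detection strategy:
  1. After each model response, hash the tool calls (tool name plus a
     stable key derived from the salient arguments).
  2. Track recent hashes per thread in a sliding window.
  3. If the same hash appears >= warn_threshold times, inject a
     "you are repeating yourself — wrap up" warning message (once per
     hash). The warning is sent as a human message, not a system message.
  4. If it appears >= hard_limit times, strip all tool_calls from the
     response so the agent is forced to produce a final text answer.
  5. Independently of the hashes, count calls per tool name per thread:
     warn once a tool reaches tool_freq_warn calls and force a stop at
     tool_freq_hard_limit calls, which catches loops over the same tool
     with varying arguments.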
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import threading
|
|
from collections import OrderedDict, defaultdict
|
|
from typing import override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain_core.messages import HumanMessage
|
|
from langgraph.runtime import Runtime
|
|
|
|
logger = logging.getLogger(__name__)

# Defaults — each one can be overridden via the LoopDetectionMiddleware constructor.
_DEFAULT_WARN_THRESHOLD = 3  # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5  # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20  # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100  # LRU eviction limit
_DEFAULT_TOOL_FREQ_WARN = 30  # warn after 30 calls to the same tool type
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50  # force-stop after 50 calls to the same tool type
|
|
|
|
|
|
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
|
"""Normalize tool call args to a dict plus an optional fallback key.
|
|
|
|
Some providers serialize ``args`` as a JSON string instead of a dict.
|
|
We defensively parse those cases so loop detection does not crash while
|
|
still preserving a stable fallback key for non-dict payloads.
|
|
"""
|
|
if isinstance(raw_args, dict):
|
|
return raw_args, None
|
|
|
|
if isinstance(raw_args, str):
|
|
try:
|
|
parsed = json.loads(raw_args)
|
|
except (TypeError, ValueError, json.JSONDecodeError):
|
|
return {}, raw_args
|
|
|
|
if isinstance(parsed, dict):
|
|
return parsed, None
|
|
return {}, json.dumps(parsed, sort_keys=True, default=str)
|
|
|
|
if raw_args is None:
|
|
return {}, None
|
|
|
|
return {}, json.dumps(raw_args, sort_keys=True, default=str)
|
|
|
|
|
|
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
|
|
"""Derive a stable key from salient args without overfitting to noise."""
|
|
if name == "read_file" and fallback_key is None:
|
|
path = args.get("path") or ""
|
|
start_line = args.get("start_line")
|
|
end_line = args.get("end_line")
|
|
|
|
bucket_size = 200
|
|
try:
|
|
start_line = int(start_line) if start_line is not None else 1
|
|
except (TypeError, ValueError):
|
|
start_line = 1
|
|
try:
|
|
end_line = int(end_line) if end_line is not None else start_line
|
|
except (TypeError, ValueError):
|
|
end_line = start_line
|
|
|
|
start_line, end_line = sorted((start_line, end_line))
|
|
bucket_start = max(start_line, 1)
|
|
bucket_end = max(end_line, 1)
|
|
bucket_start = (bucket_start - 1) // bucket_size
|
|
bucket_end = (bucket_end - 1) // bucket_size
|
|
return f"{path}:{bucket_start}-{bucket_end}"
|
|
|
|
# write_file / str_replace are content-sensitive: same path may be updated
|
|
# with different payloads during iteration. Using only salient fields (path)
|
|
# can collapse distinct calls, so we hash full args to reduce false positives.
|
|
if name in {"write_file", "str_replace"}:
|
|
if fallback_key is not None:
|
|
return fallback_key
|
|
return json.dumps(args, sort_keys=True, default=str)
|
|
|
|
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
|
|
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
|
|
if stable_args:
|
|
return json.dumps(stable_args, sort_keys=True, default=str)
|
|
|
|
if fallback_key is not None:
|
|
return fallback_key
|
|
|
|
return json.dumps(args, sort_keys=True, default=str)
|
|
|
|
|
|
def _hash_tool_calls(tool_calls: list[dict]) -> str:
    """Deterministic hash of a set of tool calls (name + stable key).

    This is intended to be order-independent: the same multiset of tool calls
    should always produce the same hash, regardless of their input order.

    Args:
        tool_calls: Tool call dicts; the ``name`` and ``args`` entries are read.

    Returns:
        The first 12 hex characters of the MD5 digest of the normalized calls.
    """
    # Normalize each tool call to a stable "name:key" string.
    normalized: list[str] = []
    for tc in tool_calls:
        name = tc.get("name", "")
        args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
        normalized.append(f"{name}:{_stable_tool_key(name, args, fallback_key)}")

    # Sort so permutations of the same multiset of calls yield the same ordering.
    normalized.sort()
    blob = json.dumps(normalized, sort_keys=True, default=str)
    # MD5 is only a short fingerprint here, not a security primitive; saying so
    # keeps the call usable on builds that restrict MD5 for security purposes.
    return hashlib.md5(blob.encode(), usedforsecurity=False).hexdigest()[:12]
|
|
|
|
|
|
# Injected once per repeated call hash when the warn threshold is reached.
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."

# Template for the per-tool frequency warning; formatted with ``tool_name`` and ``count``.
_TOOL_FREQ_WARNING_MSG = (
    "[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)

# Appended to the final AI message when identical call sets hit the hard limit.
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."

# Template for the per-tool frequency hard stop; formatted with ``tool_name`` and ``count``.
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
|
|
|
|
|
|
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    Tracking state is kept per thread id on the middleware instance and is
    mutated only while holding ``self._lock``. Once more than
    ``max_tracked_threads`` threads are tracked, the least recently used
    ones are evicted.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5.
        window_size: Size of the sliding window for tracking calls.
            Default: 20.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Default: 100.
        tool_freq_warn: Number of calls to the same tool *type* (regardless
            of arguments) before injecting a frequency warning. Catches
            cross-file read loops that hash-based detection misses.
            Default: 30.
        tool_freq_hard_limit: Number of calls to the same tool type before
            forcing a stop. Default: 50.
    """

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
        tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
        tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        self.tool_freq_warn = tool_freq_warn
        self.tool_freq_hard_limit = tool_freq_hard_limit
        self._lock = threading.Lock()
        # Per-thread tracking using OrderedDict for LRU eviction:
        # thread_id -> recent call hashes (oldest first).
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        # thread_id -> hashes that already triggered a warning.
        self._warned: dict[str, set[str]] = defaultdict(set)
        # Per-thread, per-tool-type cumulative call counts
        self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
        # thread_id -> tool names that already triggered a frequency warning.
        self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Extract thread_id from runtime context for per-thread tracking.

        Falls back to ``"default"`` when the context is empty or carries no
        truthy ``thread_id``.
        """
        thread_id = runtime.context.get("thread_id") if runtime.context else None
        if thread_id:
            return thread_id
        return "default"

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads if over the limit.

        Must be called while holding self._lock.
        """
        while len(self._history) > self.max_tracked_threads:
            # popitem(last=False) removes the oldest (least recently touched) entry.
            evicted_id, _ = self._history.popitem(last=False)
            self._warned.pop(evicted_id, None)
            self._tool_freq.pop(evicted_id, None)
            self._tool_freq_warned.pop(evicted_id, None)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Track tool calls and check for loops.

        Two detection layers:
          1. **Hash-based**: catches identical tool call sets.
          2. **Frequency-based**: catches the same *tool type* being
             called many times with varying arguments (e.g. ``read_file``
             on 40 different files).

        Only the last message is inspected, and only when it is an AI
        message that carries tool calls.

        Returns:
            (warning_message_or_none, should_hard_stop)
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create entry (move to end for LRU)
            if thread_id in self._history:
                self._history.move_to_end(thread_id)
            else:
                self._history[thread_id] = []
                self._evict_if_needed()

            history = self._history[thread_id]
            history.append(call_hash)
            if len(history) > self.window_size:
                # Trim in place so the list stored in self._history stays the same object.
                history[:] = history[-self.window_size :]

            count = history.count(call_hash)
            tool_names = [tc.get("name", "?") for tc in tool_calls]

            # --- Layer 1: hash-based (identical call sets) ---
            if count >= self.hard_limit:
                logger.error(
                    "Loop hard limit reached — forcing stop",
                    extra={
                        "thread_id": thread_id,
                        "call_hash": call_hash,
                        "count": count,
                        "tools": tool_names,
                    },
                )
                return _HARD_STOP_MSG, True

            if count >= self.warn_threshold:
                warned = self._warned[thread_id]
                if call_hash not in warned:
                    warned.add(call_hash)
                    logger.warning(
                        "Repetitive tool calls detected — injecting warning",
                        extra={
                            "thread_id": thread_id,
                            "call_hash": call_hash,
                            "count": count,
                            "tools": tool_names,
                        },
                    )
                    return _WARNING_MSG, False
                # Already warned for this hash: fall through to the frequency layer.

            # --- Layer 2: per-tool-type frequency ---
            # Counts are only updated when layer 1 did not return above, and the
            # early returns below leave later calls of this response uncounted.
            freq = self._tool_freq[thread_id]
            for tc in tool_calls:
                name = tc.get("name", "")
                if not name:
                    continue
                freq[name] += 1
                tc_count = freq[name]

                if tc_count >= self.tool_freq_hard_limit:
                    logger.error(
                        "Tool frequency hard limit reached — forcing stop",
                        extra={
                            "thread_id": thread_id,
                            "tool_name": name,
                            "count": tc_count,
                        },
                    )
                    return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True

                if tc_count >= self.tool_freq_warn:
                    warned = self._tool_freq_warned[thread_id]
                    if name not in warned:
                        warned.add(name)
                        logger.warning(
                            "Tool frequency warning — too many calls to same tool type",
                            extra={
                                "thread_id": thread_id,
                                "tool_name": name,
                                "count": tc_count,
                            },
                        )
                        return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False

        return None, False

    @staticmethod
    def _append_text(content: str | list | None, text: str) -> str | list:
        """Append *text* to AIMessage content, handling str, list, and None.

        When content is a list of content blocks (e.g. Anthropic thinking mode),
        we append a new ``{"type": "text", ...}`` block instead of concatenating
        a string to a list, which would raise ``TypeError``.
        """
        if content is None:
            return text
        if isinstance(content, list):
            return [*content, {"type": "text", "text": f"\n\n{text}"}]
        if isinstance(content, str):
            return content + f"\n\n{text}"
        # Fallback: coerce unexpected types to str to avoid TypeError
        return str(content) + f"\n\n{text}"

    @staticmethod
    def _build_hard_stop_update(last_msg, content: str | list) -> dict:
        """Build the ``model_copy`` update that turns *last_msg* into plain text.

        Besides emptying ``tool_calls``, drop the raw ``tool_calls`` /
        ``function_call`` entries from ``additional_kwargs`` and rewrite a
        ``finish_reason`` of ``"tool_calls"`` to ``"stop"``, so the message
        carries no leftover tool-call metadata. Both are only included in the
        update when the message actually has them.

        NOTE(review): the motivation is that OpenAI-style adapters appear to
        fall back to the raw ``additional_kwargs`` entries when ``tool_calls``
        is empty — confirm against the provider integrations in use.
        """
        update: dict = {"tool_calls": [], "content": content}

        extra_kwargs = getattr(last_msg, "additional_kwargs", None)
        if isinstance(extra_kwargs, dict) and extra_kwargs:
            update["additional_kwargs"] = {key: value for key, value in extra_kwargs.items() if key not in ("tool_calls", "function_call")}

        metadata = getattr(last_msg, "response_metadata", None)
        if isinstance(metadata, dict) and metadata.get("finish_reason") == "tool_calls":
            update["response_metadata"] = {**metadata, "finish_reason": "stop"}

        return update

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Run loop detection and return the state update to apply, if any.

        Returns:
            ``None`` when nothing was detected; otherwise a ``{"messages": [...]}``
            update holding either the last AI message with its tool calls
            stripped (hard stop) or a ``HumanMessage`` carrying the warning.
        """
        warning, hard_stop = self._track_and_check(state, runtime)

        if hard_stop:
            # Strip tool_calls from the last AIMessage to force text output.
            # The stop reason returned by _track_and_check is appended to it.
            messages = state.get("messages", [])
            last_msg = messages[-1]
            stripped_msg = last_msg.model_copy(
                update=self._build_hard_stop_update(
                    last_msg,
                    self._append_text(last_msg.content, warning),
                )
            )
            return {"messages": [stripped_msg]}

        if warning:
            # Inject as HumanMessage instead of SystemMessage to avoid
            # Anthropic's "multiple non-consecutive system messages" error.
            # Anthropic models require system messages only at the start of
            # the conversation; injecting one mid-conversation crashes
            # langchain_anthropic's _format_messages(). HumanMessage works
            # with all providers. See #1299.
            return {"messages": [HumanMessage(content=warning)]}

        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Synchronous hook: delegate to :meth:`_apply`."""
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async hook: delegate to the same synchronous :meth:`_apply`."""
        return self._apply(state, runtime)

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state. If thread_id given, clear only that thread.

        A falsy ``thread_id`` (``None`` or ``""``) clears every thread.
        """
        with self._lock:
            if thread_id:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
                self._tool_freq.pop(thread_id, None)
                self._tool_freq_warned.pop(thread_id, None)
            else:
                self._history.clear()
                self._warned.clear()
                self._tool_freq.clear()
                self._tool_freq_warned.clear()
|