"""Middleware to detect and break repetitive tool call loops. P0 safety: prevents the agent from calling the same tool with the same arguments indefinitely until the recursion limit kills the run. Detection strategy: 1. After each model response, hash the tool calls (name + args). 2. Track recent hashes in a sliding window. 3. If the same hash appears >= warn_threshold times, inject a "you are repeating yourself — wrap up" system message (once per hash). 4. If it appears >= hard_limit times, strip all tool_calls from the response so the agent is forced to produce a final text answer. """ from __future__ import annotations import hashlib import json import logging import threading from collections import OrderedDict, defaultdict from copy import deepcopy from typing import TYPE_CHECKING, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langgraph.runtime import Runtime if TYPE_CHECKING: from deerflow.config.loop_detection_config import LoopDetectionConfig logger = logging.getLogger(__name__) # Defaults — can be overridden via constructor _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls _DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls _DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit _DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type _DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]: """Normalize tool call args to a dict plus an optional fallback key. Some providers serialize ``args`` as a JSON string instead of a dict. We defensively parse those cases so loop detection does not crash while still preserving a stable fallback key for non-dict payloads. """ if isinstance(raw_args, dict): return raw_args, None if isinstance(raw_args, str): try: parsed = json.loads(raw_args) except (TypeError, ValueError, json.JSONDecodeError): return {}, raw_args if isinstance(parsed, dict): return parsed, None return {}, json.dumps(parsed, sort_keys=True, default=str) if raw_args is None: return {}, None return {}, json.dumps(raw_args, sort_keys=True, default=str) def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str: """Derive a stable key from salient args without overfitting to noise.""" if name == "read_file" and fallback_key is None: path = args.get("path") or "" start_line = args.get("start_line") end_line = args.get("end_line") bucket_size = 200 try: start_line = int(start_line) if start_line is not None else 1 except (TypeError, ValueError): start_line = 1 try: end_line = int(end_line) if end_line is not None else start_line except (TypeError, ValueError): end_line = start_line start_line, end_line = sorted((start_line, end_line)) bucket_start = max(start_line, 1) bucket_end = max(end_line, 1) bucket_start = (bucket_start - 1) // bucket_size bucket_end = (bucket_end - 1) // bucket_size return f"{path}:{bucket_start}-{bucket_end}" # write_file / str_replace are content-sensitive: same path may be updated # with different payloads during iteration. Using only salient fields (path) # can collapse distinct calls, so we hash full args to reduce false positives. if name in {"write_file", "str_replace"}: if fallback_key is not None: return fallback_key return json.dumps(args, sort_keys=True, default=str) salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd") stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None} if stable_args: return json.dumps(stable_args, sort_keys=True, default=str) if fallback_key is not None: return fallback_key return json.dumps(args, sort_keys=True, default=str) def _hash_tool_calls(tool_calls: list[dict]) -> str: """Deterministic hash of a set of tool calls (name + stable key). This is intended to be order-independent: the same multiset of tool calls should always produce the same hash, regardless of their input order. """ # Normalize each tool call to a stable (name, key) structure. normalized: list[str] = [] for tc in tool_calls: name = tc.get("name", "") args, fallback_key = _normalize_tool_call_args(tc.get("args", {})) key = _stable_tool_key(name, args, fallback_key) normalized.append(f"{name}:{key}") # Sort so permutations of the same multiset of calls yield the same ordering. normalized.sort() blob = json.dumps(normalized, sort_keys=True, default=str) return hashlib.md5(blob.encode()).hexdigest()[:12] _WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." _TOOL_FREQ_WARNING_MSG = ( "[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." ) _HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far." _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far." class LoopDetectionMiddleware(AgentMiddleware[AgentState]): """Detects and breaks repetitive tool call loops. Threshold parameters are validated upstream by :class:`LoopDetectionConfig`; construct via :meth:`from_config` to ensure values pass Pydantic validation. Args: warn_threshold: Number of identical tool call sets before injecting a warning message. Default: 3. hard_limit: Number of identical tool call sets before stripping tool_calls entirely. Default: 5. window_size: Size of the sliding window for tracking calls. Default: 20. max_tracked_threads: Maximum number of threads to track before evicting the least recently used. Default: 100. tool_freq_warn: Number of calls to the same tool *type* (regardless of arguments) before injecting a frequency warning. Catches cross-file read loops that hash-based detection misses. Default: 30. tool_freq_hard_limit: Number of calls to the same tool type before forcing a stop. Default: 50. tool_freq_overrides: Per-tool overrides for frequency thresholds, keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for that specific tool. Tools not listed here fall back to the global thresholds. Useful for raising limits on intentionally high-frequency tools (e.g. ``bash`` in batch pipelines) without weakening protection on all other tools. Default: ``None`` (no overrides). """ def __init__( self, warn_threshold: int = _DEFAULT_WARN_THRESHOLD, hard_limit: int = _DEFAULT_HARD_LIMIT, window_size: int = _DEFAULT_WINDOW_SIZE, max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS, tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN, tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT, tool_freq_overrides: dict[str, tuple[int, int]] | None = None, ): super().__init__() self.warn_threshold = warn_threshold self.hard_limit = hard_limit self.window_size = window_size self.max_tracked_threads = max_tracked_threads self.tool_freq_warn = tool_freq_warn self.tool_freq_hard_limit = tool_freq_hard_limit self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {} self._lock = threading.Lock() self._history: OrderedDict[str, list[str]] = OrderedDict() self._warned: dict[str, set[str]] = defaultdict(set) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) @classmethod def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware: """Construct from a Pydantic-validated config, trusting its validation.""" return cls( warn_threshold=config.warn_threshold, hard_limit=config.hard_limit, window_size=config.window_size, max_tracked_threads=config.max_tracked_threads, tool_freq_warn=config.tool_freq_warn, tool_freq_hard_limit=config.tool_freq_hard_limit, tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()}, ) def _get_thread_id(self, runtime: Runtime) -> str: """Extract thread_id from runtime context for per-thread tracking.""" thread_id = runtime.context.get("thread_id") if runtime.context else None if thread_id: return thread_id return "default" def _evict_if_needed(self) -> None: """Evict least recently used threads if over the limit. Must be called while holding self._lock. """ while len(self._history) > self.max_tracked_threads: evicted_id, _ = self._history.popitem(last=False) self._warned.pop(evicted_id, None) self._tool_freq.pop(evicted_id, None) self._tool_freq_warned.pop(evicted_id, None) logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id) def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]: """Track tool calls and check for loops. Two detection layers: 1. **Hash-based** (existing): catches identical tool call sets. 2. **Frequency-based** (new): catches the same *tool type* being called many times with varying arguments (e.g. ``read_file`` on 40 different files). Returns: (warning_message_or_none, should_hard_stop) """ messages = state.get("messages", []) if not messages: return None, False last_msg = messages[-1] if getattr(last_msg, "type", None) != "ai": return None, False tool_calls = getattr(last_msg, "tool_calls", None) if not tool_calls: return None, False thread_id = self._get_thread_id(runtime) call_hash = _hash_tool_calls(tool_calls) with self._lock: # Touch / create entry (move to end for LRU) if thread_id in self._history: self._history.move_to_end(thread_id) else: self._history[thread_id] = [] self._evict_if_needed() history = self._history[thread_id] history.append(call_hash) if len(history) > self.window_size: history[:] = history[-self.window_size :] count = history.count(call_hash) tool_names = [tc.get("name", "?") for tc in tool_calls] # --- Layer 1: hash-based (identical call sets) --- if count >= self.hard_limit: logger.error( "Loop hard limit reached — forcing stop", extra={ "thread_id": thread_id, "call_hash": call_hash, "count": count, "tools": tool_names, }, ) return _HARD_STOP_MSG, True if count >= self.warn_threshold: warned = self._warned[thread_id] if call_hash not in warned: warned.add(call_hash) logger.warning( "Repetitive tool calls detected — injecting warning", extra={ "thread_id": thread_id, "call_hash": call_hash, "count": count, "tools": tool_names, }, ) return _WARNING_MSG, False # --- Layer 2: per-tool-type frequency --- freq = self._tool_freq[thread_id] for tc in tool_calls: name = tc.get("name", "") if not name: continue freq[name] += 1 tc_count = freq[name] if name in self._tool_freq_overrides: eff_warn, eff_hard = self._tool_freq_overrides[name] else: eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit if tc_count >= eff_hard: logger.error( "Tool frequency hard limit reached — forcing stop", extra={ "thread_id": thread_id, "tool_name": name, "count": tc_count, }, ) return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True if tc_count >= eff_warn: warned = self._tool_freq_warned[thread_id] if name not in warned: warned.add(name) logger.warning( "Tool frequency warning — too many calls to same tool type", extra={ "thread_id": thread_id, "tool_name": name, "count": tc_count, }, ) return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False return None, False @staticmethod def _append_text(content: str | list | None, text: str) -> str | list: """Append *text* to AIMessage content, handling str, list, and None. When content is a list of content blocks (e.g. Anthropic thinking mode), we append a new ``{"type": "text", ...}`` block instead of concatenating a string to a list, which would raise ``TypeError``. """ if content is None: return text if isinstance(content, list): return [*content, {"type": "text", "text": f"\n\n{text}"}] if isinstance(content, str): return content + f"\n\n{text}" # Fallback: coerce unexpected types to str to avoid TypeError return str(content) + f"\n\n{text}" @staticmethod def _build_hard_stop_update(last_msg, content: str | list) -> dict: """Clear tool-call metadata so forced-stop messages serialize as plain assistant text.""" update = { "tool_calls": [], "content": content, } additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {}) for key in ("tool_calls", "function_call"): additional_kwargs.pop(key, None) update["additional_kwargs"] = additional_kwargs response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {}) if response_metadata.get("finish_reason") == "tool_calls": response_metadata["finish_reason"] = "stop" update["response_metadata"] = response_metadata return update def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: warning, hard_stop = self._track_and_check(state, runtime) if hard_stop: # Strip tool_calls from the last AIMessage to force text output messages = state.get("messages", []) last_msg = messages[-1] content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG) stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content)) return {"messages": [stripped_msg]} if warning: # WORKAROUND for v2.0-m1 — see #2724. # # Append the warning to the AIMessage content instead of # injecting a separate HumanMessage. Inserting any non-tool # message between an AIMessage(tool_calls=...) and its # ToolMessage responses breaks OpenAI/Moonshot strict pairing # validation ("tool_call_ids did not have response messages") # because the tools node has not run yet at after_model time. # tool_calls are preserved so the tools node still executes. # # This is a temporary mitigation: mutating an existing # AIMessage to carry framework-authored text leaks loop-warning # text into downstream consumers (MemoryMiddleware fact # extraction, TitleMiddleware, telemetry, model replay) as if # the model said it. The proper fix is to defer warning # injection from after_model to wrap_model_call so every prior # ToolMessage is already in the request — see RFC #2517 (which # lists "loop intervention does not leave invalid # tool-call/tool-message state" as acceptance criteria) and # the prototype on `fix/loop-detection-tool-call-pairing`. messages = state.get("messages", []) last_msg = messages[-1] patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)}) return {"messages": [patched_msg]} return None @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) @override async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) def reset(self, thread_id: str | None = None) -> None: """Clear tracking state. If thread_id given, clear only that thread.""" with self._lock: if thread_id: self._history.pop(thread_id, None) self._warned.pop(thread_id, None) self._tool_freq.pop(thread_id, None) self._tool_freq_warned.pop(thread_id, None) else: self._history.clear() self._warned.clear() self._tool_freq.clear() self._tool_freq_warned.clear()