"""Middleware to detect and break repetitive tool call loops. P0 safety: prevents the agent from calling the same tool with the same arguments indefinitely until the recursion limit kills the run. Detection strategy: 1. After each model response, hash the tool calls (name + args). 2. Track recent hashes in a sliding window. 3. If the same hash appears >= warn_threshold times, queue a "you are repeating yourself — wrap up" warning for the current thread/run. The warning is **injected at the next model call** (in ``wrap_model_call``) as a ``HumanMessage`` appended to the message list, *after* all ToolMessage responses to the previous AIMessage(tool_calls). 4. If it appears >= hard_limit times, strip all tool_calls from the response so the agent is forced to produce a final text answer. Why the warning is injected at ``wrap_model_call`` instead of ``after_model``: ``after_model`` fires immediately after the model emits an ``AIMessage`` that may carry ``tool_calls``. The tools node has not run yet, so no matching ``ToolMessage`` exists in the history. Any message we add here lands *between* the assistant's tool_calls and their responses. OpenAI/Moonshot reject the next request with ``"tool_call_ids did not have response messages"`` because their validators require the assistant's tool_calls to be followed immediately by tool messages. Anthropic also disallows mid-stream ``SystemMessage``. By deferring the warning to ``wrap_model_call``, every prior ToolMessage is already present in the request's message list and the warning is appended at the end — pairing intact, no ``AIMessage`` semantics are mutated. Queued warnings are intentionally transient. If a run ends before the next model request drains a queued warning, ``after_agent`` drops it instead of carrying it into a later invocation for the same thread. The hard-stop path still forces termination when the configured safety limit is reached. """ from __future__ import annotations import hashlib import json import logging import threading from collections import OrderedDict, defaultdict from collections.abc import Awaitable, Callable from copy import deepcopy from typing import TYPE_CHECKING, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse from langchain_core.messages import HumanMessage from langgraph.runtime import Runtime if TYPE_CHECKING: from deerflow.config.loop_detection_config import LoopDetectionConfig logger = logging.getLogger(__name__) # Defaults — can be overridden via constructor _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls _DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls _DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit _DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type _DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type _MAX_PENDING_WARNINGS_PER_RUN = 4 def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]: """Normalize tool call args to a dict plus an optional fallback key. Some providers serialize ``args`` as a JSON string instead of a dict. We defensively parse those cases so loop detection does not crash while still preserving a stable fallback key for non-dict payloads. """ if isinstance(raw_args, dict): return raw_args, None if isinstance(raw_args, str): try: parsed = json.loads(raw_args) except (TypeError, ValueError, json.JSONDecodeError): return {}, raw_args if isinstance(parsed, dict): return parsed, None return {}, json.dumps(parsed, sort_keys=True, default=str) if raw_args is None: return {}, None return {}, json.dumps(raw_args, sort_keys=True, default=str) def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str: """Derive a stable key from salient args without overfitting to noise.""" if name == "read_file" and fallback_key is None: path = args.get("path") or "" start_line = args.get("start_line") end_line = args.get("end_line") bucket_size = 200 try: start_line = int(start_line) if start_line is not None else 1 except (TypeError, ValueError): start_line = 1 try: end_line = int(end_line) if end_line is not None else start_line except (TypeError, ValueError): end_line = start_line start_line, end_line = sorted((start_line, end_line)) bucket_start = max(start_line, 1) bucket_end = max(end_line, 1) bucket_start = (bucket_start - 1) // bucket_size bucket_end = (bucket_end - 1) // bucket_size return f"{path}:{bucket_start}-{bucket_end}" # write_file / str_replace are content-sensitive: same path may be updated # with different payloads during iteration. Using only salient fields (path) # can collapse distinct calls, so we hash full args to reduce false positives. if name in {"write_file", "str_replace"}: if fallback_key is not None: return fallback_key return json.dumps(args, sort_keys=True, default=str) salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd") stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None} if stable_args: return json.dumps(stable_args, sort_keys=True, default=str) if fallback_key is not None: return fallback_key return json.dumps(args, sort_keys=True, default=str) def _hash_tool_calls(tool_calls: list[dict]) -> str: """Deterministic hash of a set of tool calls (name + stable key). This is intended to be order-independent: the same multiset of tool calls should always produce the same hash, regardless of their input order. """ # Normalize each tool call to a stable (name, key) structure. normalized: list[str] = [] for tc in tool_calls: name = tc.get("name", "") args, fallback_key = _normalize_tool_call_args(tc.get("args", {})) key = _stable_tool_key(name, args, fallback_key) normalized.append(f"{name}:{key}") # Sort so permutations of the same multiset of calls yield the same ordering. normalized.sort() blob = json.dumps(normalized, sort_keys=True, default=str) return hashlib.md5(blob.encode()).hexdigest()[:12] _WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." _TOOL_FREQ_WARNING_MSG = ( "[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." ) _HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far." _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far." class LoopDetectionMiddleware(AgentMiddleware[AgentState]): """Detects and breaks repetitive tool call loops. Threshold parameters are validated upstream by :class:`LoopDetectionConfig`; construct via :meth:`from_config` to ensure values pass Pydantic validation. Args: warn_threshold: Number of identical tool call sets before injecting a warning message. Default: 3. hard_limit: Number of identical tool call sets before stripping tool_calls entirely. Default: 5. window_size: Size of the sliding window for tracking calls. Default: 20. max_tracked_threads: Maximum number of threads to track before evicting the least recently used. Default: 100. tool_freq_warn: Number of calls to the same tool *type* (regardless of arguments) before injecting a frequency warning. Catches cross-file read loops that hash-based detection misses. Default: 30. tool_freq_hard_limit: Number of calls to the same tool type before forcing a stop. Default: 50. tool_freq_overrides: Per-tool overrides for frequency thresholds, keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for that specific tool. Tools not listed here fall back to the global thresholds. Useful for raising limits on intentionally high-frequency tools (e.g. ``bash`` in batch pipelines) without weakening protection on all other tools. Default: ``None`` (no overrides). """ def __init__( self, warn_threshold: int = _DEFAULT_WARN_THRESHOLD, hard_limit: int = _DEFAULT_HARD_LIMIT, window_size: int = _DEFAULT_WINDOW_SIZE, max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS, tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN, tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT, tool_freq_overrides: dict[str, tuple[int, int]] | None = None, ): super().__init__() self.warn_threshold = warn_threshold self.hard_limit = hard_limit self.window_size = window_size self.max_tracked_threads = max_tracked_threads self.tool_freq_warn = tool_freq_warn self.tool_freq_hard_limit = tool_freq_hard_limit self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {} self._lock = threading.Lock() self._history: OrderedDict[str, list[str]] = OrderedDict() self._warned: dict[str, set[str]] = defaultdict(set) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) # Per-thread/run queue of warnings to inject at the next model call. # Populated by ``after_model`` (detection) and drained by # ``wrap_model_call`` (injection); see module docstring. self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list) self._pending_warning_touch_order: OrderedDict[tuple[str, str], None] = OrderedDict() self._max_pending_warning_keys = max(1, self.max_tracked_threads * 2) @classmethod def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware: """Construct from a Pydantic-validated config, trusting its validation.""" return cls( warn_threshold=config.warn_threshold, hard_limit=config.hard_limit, window_size=config.window_size, max_tracked_threads=config.max_tracked_threads, tool_freq_warn=config.tool_freq_warn, tool_freq_hard_limit=config.tool_freq_hard_limit, tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()}, ) def _get_thread_id(self, runtime: Runtime) -> str: """Extract thread_id from runtime context for per-thread tracking.""" thread_id = runtime.context.get("thread_id") if runtime.context else None if thread_id: return str(thread_id) return "default" def _get_run_id(self, runtime: Runtime) -> str: """Extract run_id from runtime context for per-run warning scoping.""" run_id = runtime.context.get("run_id") if runtime.context else None if run_id: return str(run_id) return "default" def _pending_key(self, runtime: Runtime) -> tuple[str, str]: """Return the pending-warning key for the current thread/run.""" return self._get_thread_id(runtime), self._get_run_id(runtime) def _evict_if_needed(self) -> None: """Evict least recently used threads if over the limit. Must be called while holding self._lock. """ while len(self._history) > self.max_tracked_threads: evicted_id, _ = self._history.popitem(last=False) self._warned.pop(evicted_id, None) self._tool_freq.pop(evicted_id, None) self._tool_freq_warned.pop(evicted_id, None) for key in list(self._pending_warnings): if key[0] == evicted_id: self._drop_pending_warning_key_locked(key) logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id) def _drop_pending_warning_key_locked(self, key: tuple[str, str]) -> None: """Drop all pending-warning bookkeeping for one thread/run key. Must be called while holding self._lock. """ self._pending_warnings.pop(key, None) self._pending_warning_touch_order.pop(key, None) def _touch_pending_warning_key_locked(self, key: tuple[str, str]) -> None: """Mark a pending-warning key as recently used. Must be called while holding self._lock. """ self._pending_warning_touch_order[key] = None self._pending_warning_touch_order.move_to_end(key) def _prune_pending_warning_state_locked(self, protected_key: tuple[str, str]) -> None: """Cap pending-warning state across abnormal or concurrent runs. Must be called while holding self._lock. """ overflow = len(self._pending_warning_touch_order) - self._max_pending_warning_keys if overflow <= 0: return candidates = [key for key in self._pending_warning_touch_order if key != protected_key] for key in candidates[:overflow]: self._drop_pending_warning_key_locked(key) def _queue_pending_warning(self, runtime: Runtime, warning: str) -> None: """Queue one transient warning for the current thread/run with caps.""" pending_key = self._pending_key(runtime) with self._lock: warnings = self._pending_warnings[pending_key] if warning not in warnings: warnings.append(warning) if len(warnings) > _MAX_PENDING_WARNINGS_PER_RUN: del warnings[: len(warnings) - _MAX_PENDING_WARNINGS_PER_RUN] self._touch_pending_warning_key_locked(pending_key) self._prune_pending_warning_state_locked(protected_key=pending_key) def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]: """Track tool calls and check for loops. Two detection layers: 1. **Hash-based** (existing): catches identical tool call sets. 2. **Frequency-based** (new): catches the same *tool type* being called many times with varying arguments (e.g. ``read_file`` on 40 different files). Returns: (warning_message_or_none, should_hard_stop) """ messages = state.get("messages", []) if not messages: return None, False last_msg = messages[-1] if getattr(last_msg, "type", None) != "ai": return None, False tool_calls = getattr(last_msg, "tool_calls", None) if not tool_calls: return None, False thread_id = self._get_thread_id(runtime) call_hash = _hash_tool_calls(tool_calls) with self._lock: # Touch / create entry (move to end for LRU) if thread_id in self._history: self._history.move_to_end(thread_id) else: self._history[thread_id] = [] self._evict_if_needed() history = self._history[thread_id] history.append(call_hash) if len(history) > self.window_size: history[:] = history[-self.window_size :] warned_hashes = self._warned.get(thread_id) if warned_hashes is not None: warned_hashes.intersection_update(history) if not warned_hashes: self._warned.pop(thread_id, None) count = history.count(call_hash) tool_names = [tc.get("name", "?") for tc in tool_calls] # --- Layer 1: hash-based (identical call sets) --- if count >= self.hard_limit: logger.error( "Loop hard limit reached — forcing stop", extra={ "thread_id": thread_id, "call_hash": call_hash, "count": count, "tools": tool_names, }, ) return _HARD_STOP_MSG, True if count >= self.warn_threshold: warned = self._warned[thread_id] if call_hash not in warned: warned.add(call_hash) logger.warning( "Repetitive tool calls detected — injecting warning", extra={ "thread_id": thread_id, "call_hash": call_hash, "count": count, "tools": tool_names, }, ) return _WARNING_MSG, False # --- Layer 2: per-tool-type frequency --- freq = self._tool_freq[thread_id] for tc in tool_calls: name = tc.get("name", "") if not name: continue freq[name] += 1 tc_count = freq[name] if name in self._tool_freq_overrides: eff_warn, eff_hard = self._tool_freq_overrides[name] else: eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit if tc_count >= eff_hard: logger.error( "Tool frequency hard limit reached — forcing stop", extra={ "thread_id": thread_id, "tool_name": name, "count": tc_count, }, ) return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True if tc_count >= eff_warn: warned = self._tool_freq_warned[thread_id] if name not in warned: warned.add(name) logger.warning( "Tool frequency warning — too many calls to same tool type", extra={ "thread_id": thread_id, "tool_name": name, "count": tc_count, }, ) return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False return None, False @staticmethod def _append_text(content: str | list | None, text: str) -> str | list: """Append *text* to AIMessage content, handling str, list, and None. When content is a list of content blocks (e.g. Anthropic thinking mode), we append a new ``{"type": "text", ...}`` block instead of concatenating a string to a list, which would raise ``TypeError``. """ if content is None: return text if isinstance(content, list): return [*content, {"type": "text", "text": f"\n\n{text}"}] if isinstance(content, str): return content + f"\n\n{text}" # Fallback: coerce unexpected types to str to avoid TypeError return str(content) + f"\n\n{text}" @staticmethod def _build_hard_stop_update(last_msg, content: str | list) -> dict: """Clear tool-call metadata so forced-stop messages serialize as plain assistant text.""" update = { "tool_calls": [], "content": content, } additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {}) for key in ("tool_calls", "function_call"): additional_kwargs.pop(key, None) update["additional_kwargs"] = additional_kwargs response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {}) if response_metadata.get("finish_reason") == "tool_calls": response_metadata["finish_reason"] = "stop" update["response_metadata"] = response_metadata return update def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: warning, hard_stop = self._track_and_check(state, runtime) if hard_stop: # Strip tool_calls from the last AIMessage to force text output. # Once tool_calls are stripped, the AIMessage no longer requires # matching ToolMessage responses, so mutating it in place here # is safe for OpenAI/Moonshot pairing validators. messages = state.get("messages", []) last_msg = messages[-1] content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG) stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content)) return {"messages": [stripped_msg]} if warning: # Defer injection to the next model call. We must NOT alter the # AIMessage(tool_calls=...) here (would put framework words in # the model's mouth, polluting downstream consumers like # MemoryMiddleware), nor insert a separate non-tool message # (would break OpenAI/Moonshot tool-call pairing because the # tools node has not produced ToolMessage responses yet). The # warning is delivered via ``wrap_model_call`` below. self._queue_pending_warning(runtime, warning) return None return None def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None: """Drop stale pending warnings for previous runs in this thread.""" thread_id, current_run_id = self._pending_key(runtime) with self._lock: for key in list(self._pending_warnings): if key[0] == thread_id and key[1] != current_run_id: self._drop_pending_warning_key_locked(key) def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None: """Drop pending warnings owned by the current thread/run.""" pending_key = self._pending_key(runtime) with self._lock: self._drop_pending_warning_key_locked(pending_key) @staticmethod def _format_warning_message(warnings: list[str]) -> str: """Merge pending warnings into one prompt message.""" deduped = list(dict.fromkeys(warnings)) return "\n\n".join(deduped) @override def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None: self._clear_other_run_pending_warnings(runtime) return None @override async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None: self._clear_other_run_pending_warnings(runtime) return None @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) @override async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) @override def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None: self._clear_current_run_pending_warnings(runtime) return None @override async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None: self._clear_current_run_pending_warnings(runtime) return None def _drain_pending_warnings(self, runtime: Runtime) -> list[str]: """Pop and return all queued warnings for *runtime*'s thread/run.""" pending_key = self._pending_key(runtime) with self._lock: warnings = self._pending_warnings.pop(pending_key, []) self._pending_warning_touch_order.pop(pending_key, None) return warnings def _augment_request(self, request: ModelRequest) -> ModelRequest: """Append queued loop warnings (if any) to the outgoing message list. The warning is placed *after* every existing message, including the ToolMessage responses to the previous AIMessage(tool_calls). This keeps ``assistant tool_calls -> tool_messages`` pairing intact for OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage restriction (we use HumanMessage), and never mutates an existing AIMessage. """ warnings = self._drain_pending_warnings(request.runtime) if not warnings: return request new_messages = [ *request.messages, HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"), ] return request.override(messages=new_messages) @override def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: return handler(self._augment_request(request)) @override async def awrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: return await handler(self._augment_request(request)) def reset(self, thread_id: str | None = None) -> None: """Clear tracking state. If thread_id given, clear only that thread.""" with self._lock: if thread_id: self._history.pop(thread_id, None) self._warned.pop(thread_id, None) self._tool_freq.pop(thread_id, None) self._tool_freq_warned.pop(thread_id, None) for key in list(self._pending_warnings): if key[0] == thread_id: self._drop_pending_warning_key_locked(key) else: self._history.clear() self._warned.clear() self._tool_freq.clear() self._tool_freq_warned.clear() self._pending_warnings.clear() self._pending_warning_touch_order.clear()