mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)
* Make loop detection configurable Expose LoopDetectionMiddleware thresholds through config.yaml while preserving existing defaults and allowing the middleware to be disabled. Refs bytedance/deer-flow#2517 * feat(loop-detection): add per-tool tool_freq_overrides to Phase 1 Adds ToolFreqOverride model and tool_freq_overrides field to LoopDetectionConfig, wires it through LoopDetectionMiddleware, and documents the option in config.example.yaml. Resolves the gap flagged in the #2586 review: without per-tool overrides, users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit) had no way to raise thresholds for one tool without loosening the global limit for every tool. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring Add the missing Args entry for tool_freq_overrides, explaining the (warn, hard_limit) tuple structure and how per-tool thresholds supersede the global tool_freq_warn / tool_freq_hard_limit for named tools. Also run ruff format on the three files flagged by the lint check. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly Raise clear ValueError at construction time instead of crashing at unpack-time inside _track_and_check when bad values are passed: - tool_freq_overrides: must be 2-tuples of positive ints with hard_limit >= warn - scalar thresholds: warn_threshold, hard_limit, tool_freq_warn, tool_freq_hard_limit must be >= 1 and hard limits must >= their warn pairs - window_size, max_tracked_threads must be >= 1 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(test): isolate credential loader directory-path test from real ~/.claude The test didn't monkeypatch HOME, so on any machine with real Claude Code credentials at ~/.claude/.credentials.json the function fell through to those credentials and the assertion failed. Adding HOME redirect ensures the default credential path doesn't exist during the test. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * style(test): add blank lines after import pytest in TestInitValidation Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(loop-detection): collapse dual validation to LoopDetectionConfig Modifications - LoopDetectionMiddleware.__init__: stripped of all ValueError raises; becomes a plain field-assignment constructor. - LoopDetectionMiddleware.from_config: classmethod that builds the middleware from a Pydantic-validated LoopDetectionConfig and handles the ToolFreqOverride -> tuple[int, int] conversion. - agents/factory.py: SDK construction routed through LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the defaults path is Pydantic-validated too. - agents/lead_agent/agent.py: uses from_config instead of unpacking config fields by hand. - tests/test_loop_detection_middleware.py: deleted TestInitValidation (16 methods exercising the removed __init__ checks); added TestFromConfig (4 tests: scalar field mapping, override tuple conversion, empty overrides, behavioral smoke test). Result: one validation layer (Pydantic), zero duplication, no __new__ hacks. Both production construction sites flow through LoopDetectionConfig. Test results make test -> 2977 passed, 18 skipped, 0 failed (137s) make format -> All checks passed; 411 files left unchanged * feat(agents): make loop_detection configurable in create_deerflow_agent Adds a `loop_detection: bool | AgentMiddleware = True` field to RuntimeFeatures, mirroring the existing pattern used by `sandbox`, `memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware or replace it with a custom instance built from their own LoopDetectionConfig — e.g. `LoopDetectionMiddleware.from_config(my_cfg)` — instead of being stuck with the hardcoded defaults previously installed by the SDK factory. The lead-agent path (which already reads AppConfig.loop_detection) is unchanged, and the default `True` preserves prior always-on behavior for all existing callers. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> --------- Co-authored-by: knight0940 <631532668@qq.com> Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -173,7 +173,7 @@ def _assemble_from_features(
|
||||
9. MemoryMiddleware (memory feature)
|
||||
10. ViewImageMiddleware (vision feature)
|
||||
11. SubagentLimitMiddleware (subagent feature)
|
||||
12. LoopDetectionMiddleware (always)
|
||||
12. LoopDetectionMiddleware (loop_detection feature)
|
||||
13. ClarificationMiddleware (always last)
|
||||
|
||||
Two-phase ordering:
|
||||
@@ -272,10 +272,15 @@ def _assemble_from_features(
|
||||
|
||||
extra_tools.append(task_tool)
|
||||
|
||||
# --- [12] LoopDetection (always) ---
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
# --- [12] LoopDetection ---
|
||||
if feat.loop_detection is not False:
|
||||
if isinstance(feat.loop_detection, AgentMiddleware):
|
||||
chain.append(feat.loop_detection)
|
||||
else:
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
chain.append(LoopDetectionMiddleware())
|
||||
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
|
||||
|
||||
# --- [13] Clarification (always last among built-ins) ---
|
||||
chain.append(ClarificationMiddleware())
|
||||
|
||||
@@ -31,6 +31,7 @@ class RuntimeFeatures:
|
||||
vision: bool | AgentMiddleware = False
|
||||
auto_title: bool | AgentMiddleware = False
|
||||
guardrail: Literal[False] | AgentMiddleware = False
|
||||
loop_detection: bool | AgentMiddleware = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -299,7 +299,9 @@ def _build_middlewares(
|
||||
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
||||
|
||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||
middlewares.append(LoopDetectionMiddleware())
|
||||
loop_detection_config = resolved_app_config.loop_detection
|
||||
if loop_detection_config.enabled:
|
||||
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
|
||||
|
||||
# Inject custom middlewares before ClarificationMiddleware
|
||||
if custom_middlewares:
|
||||
|
||||
@@ -12,18 +12,23 @@ Detection strategy:
|
||||
response so the agent is forced to produce a final text answer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import override
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
@@ -139,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
|
||||
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
|
||||
construct via :meth:`from_config` to ensure values pass Pydantic validation.
|
||||
|
||||
Args:
|
||||
warn_threshold: Number of identical tool call sets before injecting
|
||||
a warning message. Default: 3.
|
||||
@@ -154,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
Default: 30.
|
||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
||||
forcing a stop. Default: 50.
|
||||
tool_freq_overrides: Per-tool overrides for frequency thresholds,
|
||||
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
|
||||
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
|
||||
that specific tool. Tools not listed here fall back to the global
|
||||
thresholds. Useful for raising limits on intentionally
|
||||
high-frequency tools (e.g. ``bash`` in batch pipelines) without
|
||||
weakening protection on all other tools. Default: ``None``
|
||||
(no overrides).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -164,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
||||
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
@@ -172,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self.tool_freq_warn = tool_freq_warn
|
||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
||||
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread, per-tool-type cumulative call counts
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
|
||||
"""Construct from a Pydantic-validated config, trusting its validation."""
|
||||
return cls(
|
||||
warn_threshold=config.warn_threshold,
|
||||
hard_limit=config.hard_limit,
|
||||
window_size=config.window_size,
|
||||
max_tracked_threads=config.max_tracked_threads,
|
||||
tool_freq_warn=config.tool_freq_warn,
|
||||
tool_freq_hard_limit=config.tool_freq_hard_limit,
|
||||
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
|
||||
)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
@@ -279,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
freq[name] += 1
|
||||
tc_count = freq[name]
|
||||
|
||||
if tc_count >= self.tool_freq_hard_limit:
|
||||
if name in self._tool_freq_overrides:
|
||||
eff_warn, eff_hard = self._tool_freq_overrides[name]
|
||||
else:
|
||||
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
|
||||
|
||||
if tc_count >= eff_hard:
|
||||
logger.error(
|
||||
"Tool frequency hard limit reached — forcing stop",
|
||||
extra={
|
||||
@@ -290,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
)
|
||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
||||
|
||||
if tc_count >= self.tool_freq_warn:
|
||||
if tc_count >= eff_warn:
|
||||
warned = self._tool_freq_warned[thread_id]
|
||||
if name not in warned:
|
||||
warned.add(name)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .app_config import get_app_config
|
||||
from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||
from .loop_detection_config import LoopDetectionConfig
|
||||
from .memory_config import MemoryConfig, get_memory_config
|
||||
from .paths import Paths, get_paths
|
||||
from .skill_evolution_config import SkillEvolutionConfig
|
||||
@@ -20,6 +21,7 @@ __all__ = [
|
||||
"SkillsConfig",
|
||||
"ExtensionsConfig",
|
||||
"get_extensions_config",
|
||||
"LoopDetectionConfig",
|
||||
"MemoryConfig",
|
||||
"get_memory_config",
|
||||
"get_tracing_config",
|
||||
|
||||
@@ -15,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
@@ -100,6 +101,7 @@ class AppConfig(BaseModel):
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Configuration for loop detection middleware."""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class ToolFreqOverride(BaseModel):
|
||||
"""Per-tool frequency threshold override.
|
||||
|
||||
Can be higher or lower than the global defaults. Commonly used to raise
|
||||
thresholds for high-frequency tools like bash in batch workflows (e.g.
|
||||
RNA-seq pipelines) without weakening protection on every other tool.
|
||||
"""
|
||||
|
||||
warn: int = Field(ge=1)
|
||||
hard_limit: int = Field(ge=1)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self) -> "ToolFreqOverride":
|
||||
if self.hard_limit < self.warn:
|
||||
raise ValueError("hard_limit must be >= warn")
|
||||
return self
|
||||
|
||||
|
||||
class LoopDetectionConfig(BaseModel):
|
||||
"""Configuration for repetitive tool-call loop detection."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable repetitive tool-call loop detection",
|
||||
)
|
||||
warn_threshold: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
description="Number of identical tool-call sets before injecting a warning",
|
||||
)
|
||||
hard_limit: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Number of identical tool-call sets before forcing a stop",
|
||||
)
|
||||
window_size: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
description="Number of recent tool-call sets to track per thread",
|
||||
)
|
||||
max_tracked_threads: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
description="Maximum number of thread histories to keep in memory",
|
||||
)
|
||||
tool_freq_warn: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
description="Number of calls to the same tool type before injecting a frequency warning",
|
||||
)
|
||||
tool_freq_hard_limit: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
description="Number of calls to the same tool type before forcing a stop",
|
||||
)
|
||||
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
|
||||
default_factory=dict,
|
||||
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_thresholds(self) -> "LoopDetectionConfig":
|
||||
"""Ensure hard stop cannot happen before the warning threshold."""
|
||||
if self.hard_limit < self.warn_threshold:
|
||||
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
|
||||
if self.tool_freq_hard_limit < self.tool_freq_warn:
|
||||
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
|
||||
return self
|
||||
Reference in New Issue
Block a user