feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)

* Make loop detection configurable

Expose LoopDetectionMiddleware thresholds through config.yaml while preserving existing defaults and allowing the middleware to be disabled.

Refs bytedance/deer-flow#2517

* feat(loop-detection): add per-tool tool_freq_overrides to Phase 1

Adds ToolFreqOverride model and tool_freq_overrides field to
LoopDetectionConfig, wires it through LoopDetectionMiddleware, and
documents the option in config.example.yaml.

Resolves the gap flagged in the #2586 review: without per-tool overrides,
users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit)
had no way to raise thresholds for one tool without loosening the global
limit for every tool.
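
The lookup this enables can be sketched roughly as below. It is a standalone illustration, not the middleware's code: `effective_thresholds` is a hypothetical helper name, and the `bash` values are examples only. The defaults of 30 and 50 are the global frequency thresholds named elsewhere in this commit.

```python
def effective_thresholds(
    tool_name: str,
    overrides: dict[str, tuple[int, int]],
    global_warn: int = 30,
    global_hard_limit: int = 50,
) -> tuple[int, int]:
    """Return the (warn, hard_limit) pair that applies to one tool."""
    # A tool listed in the overrides uses its own pair; every other tool
    # falls back to the global frequency thresholds.
    return overrides.get(tool_name, (global_warn, global_hard_limit))


overrides = {"bash": (150, 300)}  # example values only
print(effective_thresholds("bash", overrides))       # (150, 300)
print(effective_thresholds("read_file", overrides))  # (30, 50)
```

Only `bash` gets the higher ceiling here; `read_file` and every other tool keep the global limits.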

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring

Add the missing Args entry for tool_freq_overrides, explaining the
(warn, hard_limit) tuple structure and how per-tool thresholds supersede
the global tool_freq_warn / tool_freq_hard_limit for named tools.
Also run ruff format on the three files flagged by the lint check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly

Raise clear ValueError at construction time instead of crashing at
unpack-time inside _track_and_check when bad values are passed:
- tool_freq_overrides: must be 2-tuples of positive ints with hard_limit >= warn
- scalar thresholds: warn_threshold, hard_limit, tool_freq_warn,
  tool_freq_hard_limit must be >= 1, and each hard limit must be >= its
  paired warn value
- window_size, max_tracked_threads must be >= 1
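
A minimal sketch of this kind of eager check for one override entry, assuming a hypothetical helper named `check_override` (not a function in the codebase). A later entry in this same commit moves these checks out of `__init__` and into `LoopDetectionConfig`, so treat this as an illustration of the rules rather than the final code.

```python
def check_override(name: str, value: object) -> tuple[int, int]:
    """Validate one (warn, hard_limit) pair, raising ValueError early."""
    if not (isinstance(value, tuple) and len(value) == 2):
        raise ValueError(f"override for {name!r} must be a (warn, hard_limit) 2-tuple")
    warn, hard_limit = value
    for label, number in (("warn", warn), ("hard_limit", hard_limit)):
        # bool is a subclass of int, so it is excluded explicitly.
        if isinstance(number, bool) or not isinstance(number, int) or number < 1:
            raise ValueError(f"{label} for {name!r} must be a positive int")
    if hard_limit < warn:
        raise ValueError(f"hard_limit for {name!r} must be >= warn")
    return warn, hard_limit
```

Failing at construction time means a bad value surfaces where it was configured, not several tool calls into a run.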

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(test): isolate credential loader directory-path test from real ~/.claude

The test didn't monkeypatch HOME, so on any machine with real Claude Code
credentials at ~/.claude/.credentials.json the function fell through to
those credentials and the assertion failed. Redirecting HOME to a temporary
directory ensures the default credential path doesn't exist during the test.
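
The isolation idea, sketched with the standard library only. The actual test uses pytest's `monkeypatch.setenv`; here `unittest.mock.patch.dict` plays the same role, and `default_credentials_path` is a hypothetical stand-in for the loader's default lookup, not the real function.

```python
import os
import tempfile
from pathlib import Path
from unittest.mock import patch


def default_credentials_path() -> Path:
    """Hypothetical stand-in for the loader's default lookup under $HOME."""
    return Path(os.environ["HOME"]) / ".claude" / ".credentials.json"


with tempfile.TemporaryDirectory() as tmp:
    # Point HOME at an empty directory for the duration of the block, so the
    # default path cannot resolve to real credentials on the host machine.
    with patch.dict(os.environ, {"HOME": tmp}):
        isolated = default_credentials_path()
        assert not isolated.exists()
```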

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style(test): add blank lines after import pytest in TestInitValidation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(loop-detection): collapse dual validation to LoopDetectionConfig

Modifications
  - LoopDetectionMiddleware.__init__: stripped of all ValueError raises;
    becomes a plain field-assignment constructor.
  - LoopDetectionMiddleware.from_config: classmethod that builds the
    middleware from a Pydantic-validated LoopDetectionConfig and handles
    the ToolFreqOverride -> tuple[int, int] conversion.
  - agents/factory.py: SDK construction routed through
    LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the
    defaults path is Pydantic-validated too.
  - agents/lead_agent/agent.py: uses from_config instead of unpacking
    config fields by hand.
  - tests/test_loop_detection_middleware.py: deleted TestInitValidation
    (16 methods exercising the removed __init__ checks); added
    TestFromConfig (4 tests: scalar field mapping, override tuple
    conversion, empty overrides, behavioral smoke test).

Result: one validation layer (Pydantic), zero duplication, no __new__
hacks. Both production construction sites flow through LoopDetectionConfig.
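
A reduced sketch of the resulting shape, using stdlib dataclasses as stand-ins. All class names ending in `Sketch` are hypothetical; the real `ToolFreqOverride` and `LoopDetectionConfig` are Pydantic models that also validate their fields, which is omitted here.

```python
from dataclasses import dataclass, field


@dataclass
class OverrideSketch:
    """Stand-in for ToolFreqOverride (the real model is Pydantic)."""

    warn: int
    hard_limit: int


@dataclass
class ConfigSketch:
    """Stand-in for LoopDetectionConfig, reduced to the frequency fields."""

    tool_freq_warn: int = 30
    tool_freq_hard_limit: int = 50
    tool_freq_overrides: dict[str, OverrideSketch] = field(default_factory=dict)


class MiddlewareSketch:
    """Stand-in for LoopDetectionMiddleware: plain assignment, no checks."""

    def __init__(self, tool_freq_warn, tool_freq_hard_limit, tool_freq_overrides=None):
        self.tool_freq_warn = tool_freq_warn
        self.tool_freq_hard_limit = tool_freq_hard_limit
        self.tool_freq_overrides = tool_freq_overrides or {}

    @classmethod
    def from_config(cls, config: ConfigSketch) -> "MiddlewareSketch":
        # Flatten each override object into the (warn, hard_limit) tuple
        # that the constructor stores.
        return cls(
            tool_freq_warn=config.tool_freq_warn,
            tool_freq_hard_limit=config.tool_freq_hard_limit,
            tool_freq_overrides={
                name: (o.warn, o.hard_limit)
                for name, o in config.tool_freq_overrides.items()
            },
        )
```

The constructor stays trivial because anything reaching it through `from_config` has already been through the config layer.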

Test results
  make test   -> 2977 passed, 18 skipped, 0 failed (137s)
  make format -> All checks passed; 411 files left unchanged

* feat(agents): make loop_detection configurable in create_deerflow_agent

Adds a `loop_detection: bool | AgentMiddleware = True` field to
RuntimeFeatures, mirroring the existing pattern used by `sandbox`,
`memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware
or replace it with a custom instance built from their own
LoopDetectionConfig — e.g.
`LoopDetectionMiddleware.from_config(my_cfg)` — instead of being stuck
with the hardcoded defaults previously installed by the SDK factory.

The lead-agent path (which already reads AppConfig.loop_detection) is
unchanged, and the default `True` preserves prior always-on behavior for
all existing callers.
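
The three states of the flag can be sketched as follows. The classes are stand-ins for `AgentMiddleware` and `LoopDetectionMiddleware`, and `resolve_loop_detection` is a hypothetical helper name used only to show the branching.

```python
class AgentMiddlewareSketch:
    """Stand-in for langchain's AgentMiddleware base class."""


class DefaultLoopDetection(AgentMiddlewareSketch):
    """Stand-in for the default LoopDetectionMiddleware."""


def resolve_loop_detection(flag):
    """Return the middleware to append, or None when the feature is off."""
    if flag is False:
        return None  # feature disabled: nothing is added to the chain
    if isinstance(flag, AgentMiddlewareSketch):
        return flag  # caller-supplied instance replaces the default
    return DefaultLoopDetection()  # True: install the default middleware
```

Passing a custom instance keeps its position in the chain, so it still sits where the default would have been.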

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

---------

Co-authored-by: knight0940 <631532668@qq.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Tao Liu
2026-05-07 16:15:15 +08:00
committed by GitHub
parent 27559f3675
commit daa3ffc29b
13 changed files with 406 additions and 12 deletions
@@ -173,7 +173,7 @@ def _assemble_from_features(
9. MemoryMiddleware (memory feature) 9. MemoryMiddleware (memory feature)
10. ViewImageMiddleware (vision feature) 10. ViewImageMiddleware (vision feature)
11. SubagentLimitMiddleware (subagent feature) 11. SubagentLimitMiddleware (subagent feature)
12. LoopDetectionMiddleware (always) 12. LoopDetectionMiddleware (loop_detection feature)
13. ClarificationMiddleware (always last) 13. ClarificationMiddleware (always last)
Two-phase ordering: Two-phase ordering:
@@ -272,10 +272,15 @@ def _assemble_from_features(
extra_tools.append(task_tool) extra_tools.append(task_tool)
# --- [12] LoopDetection (always) --- # --- [12] LoopDetection ---
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware if feat.loop_detection is not False:
if isinstance(feat.loop_detection, AgentMiddleware):
chain.append(feat.loop_detection)
else:
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.loop_detection_config import LoopDetectionConfig
chain.append(LoopDetectionMiddleware()) chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
# --- [13] Clarification (always last among built-ins) --- # --- [13] Clarification (always last among built-ins) ---
chain.append(ClarificationMiddleware()) chain.append(ClarificationMiddleware())
@@ -31,6 +31,7 @@ class RuntimeFeatures:
vision: bool | AgentMiddleware = False vision: bool | AgentMiddleware = False
auto_title: bool | AgentMiddleware = False auto_title: bool | AgentMiddleware = False
guardrail: Literal[False] | AgentMiddleware = False guardrail: Literal[False] | AgentMiddleware = False
loop_detection: bool | AgentMiddleware = True
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -299,7 +299,9 @@ def _build_middlewares(
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents)) middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops # LoopDetectionMiddleware — detect and break repetitive tool call loops
middlewares.append(LoopDetectionMiddleware()) loop_detection_config = resolved_app_config.loop_detection
if loop_detection_config.enabled:
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
# Inject custom middlewares before ClarificationMiddleware # Inject custom middlewares before ClarificationMiddleware
if custom_middlewares: if custom_middlewares:
@@ -12,18 +12,23 @@ Detection strategy:
response so the agent is forced to produce a final text answer. response so the agent is forced to produce a final text answer.
""" """
from __future__ import annotations
import hashlib import hashlib
import json import json
import logging import logging
import threading import threading
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy from copy import deepcopy
from typing import override from typing import TYPE_CHECKING, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.loop_detection_config import LoopDetectionConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor # Defaults — can be overridden via constructor
@@ -139,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
class LoopDetectionMiddleware(AgentMiddleware[AgentState]): class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops. """Detects and breaks repetitive tool call loops.
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
construct via :meth:`from_config` to ensure values pass Pydantic validation.
Args: Args:
warn_threshold: Number of identical tool call sets before injecting warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3. a warning message. Default: 3.
@@ -154,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 30. Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50. forcing a stop. Default: 50.
tool_freq_overrides: Per-tool overrides for frequency thresholds,
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
that specific tool. Tools not listed here fall back to the global
thresholds. Useful for raising limits on intentionally
high-frequency tools (e.g. ``bash`` in batch pipelines) without
weakening protection on all other tools. Default: ``None``
(no overrides).
""" """
def __init__( def __init__(
@@ -164,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS, max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN, tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT, tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
): ):
super().__init__() super().__init__()
self.warn_threshold = warn_threshold self.warn_threshold = warn_threshold
@@ -172,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self.max_tracked_threads = max_tracked_threads self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit self.tool_freq_hard_limit = tool_freq_hard_limit
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
self._lock = threading.Lock() self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict() self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set) self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
@classmethod
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
"""Construct from a Pydantic-validated config, trusting its validation."""
return cls(
warn_threshold=config.warn_threshold,
hard_limit=config.hard_limit,
window_size=config.window_size,
max_tracked_threads=config.max_tracked_threads,
tool_freq_warn=config.tool_freq_warn,
tool_freq_hard_limit=config.tool_freq_hard_limit,
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
)
def _get_thread_id(self, runtime: Runtime) -> str: def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking.""" """Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = runtime.context.get("thread_id") if runtime.context else None
@@ -279,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
freq[name] += 1 freq[name] += 1
tc_count = freq[name] tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit: if name in self._tool_freq_overrides:
eff_warn, eff_hard = self._tool_freq_overrides[name]
else:
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
if tc_count >= eff_hard:
logger.error( logger.error(
"Tool frequency hard limit reached — forcing stop", "Tool frequency hard limit reached — forcing stop",
extra={ extra={
@@ -290,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
) )
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn: if tc_count >= eff_warn:
warned = self._tool_freq_warned[thread_id] warned = self._tool_freq_warned[thread_id]
if name not in warned: if name not in warned:
warned.add(name) warned.add(name)
@@ -1,5 +1,6 @@
from .app_config import get_app_config from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config from .extensions_config import ExtensionsConfig, get_extensions_config
from .loop_detection_config import LoopDetectionConfig
from .memory_config import MemoryConfig, get_memory_config from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig from .skill_evolution_config import SkillEvolutionConfig
@@ -20,6 +21,7 @@ __all__ = [
"SkillsConfig", "SkillsConfig",
"ExtensionsConfig", "ExtensionsConfig",
"get_extensions_config", "get_extensions_config",
"LoopDetectionConfig",
"MemoryConfig", "MemoryConfig",
"get_memory_config", "get_memory_config",
"get_tracing_config", "get_tracing_config",
@@ -15,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
from deerflow.config.database_config import DatabaseConfig from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.run_events_config import RunEventsConfig
@@ -100,6 +101,7 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -0,0 +1,73 @@
"""Configuration for loop detection middleware."""
from pydantic import BaseModel, Field, model_validator
class ToolFreqOverride(BaseModel):
"""Per-tool frequency threshold override.
Can be higher or lower than the global defaults. Commonly used to raise
thresholds for high-frequency tools like bash in batch workflows (e.g.
RNA-seq pipelines) without weakening protection on every other tool.
"""
warn: int = Field(ge=1)
hard_limit: int = Field(ge=1)
@model_validator(mode="after")
def _validate(self) -> "ToolFreqOverride":
if self.hard_limit < self.warn:
raise ValueError("hard_limit must be >= warn")
return self
class LoopDetectionConfig(BaseModel):
"""Configuration for repetitive tool-call loop detection."""
enabled: bool = Field(
default=True,
description="Whether to enable repetitive tool-call loop detection",
)
warn_threshold: int = Field(
default=3,
ge=1,
description="Number of identical tool-call sets before injecting a warning",
)
hard_limit: int = Field(
default=5,
ge=1,
description="Number of identical tool-call sets before forcing a stop",
)
window_size: int = Field(
default=20,
ge=1,
description="Number of recent tool-call sets to track per thread",
)
max_tracked_threads: int = Field(
default=100,
ge=1,
description="Maximum number of thread histories to keep in memory",
)
tool_freq_warn: int = Field(
default=30,
ge=1,
description="Number of calls to the same tool type before injecting a frequency warning",
)
tool_freq_hard_limit: int = Field(
default=50,
ge=1,
description="Number of calls to the same tool type before forcing a stop",
)
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
default_factory=dict,
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
)
@model_validator(mode="after")
def validate_thresholds(self) -> "LoopDetectionConfig":
"""Ensure hard stop cannot happen before the warning threshold."""
if self.hard_limit < self.warn_threshold:
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
if self.tool_freq_hard_limit < self.tool_freq_warn:
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
return self
@@ -192,6 +192,7 @@ def test_agent_features_defaults():
assert f.vision is False assert f.vision is False
assert f.auto_title is False assert f.auto_title is False
assert f.guardrail is False assert f.guardrail is False
assert f.loop_detection is True
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
assert loop_idx == clar_idx - 1 assert loop_idx == clar_idx - 1
# ---------------------------------------------------------------------------
# 30b. loop_detection=False skips LoopDetectionMiddleware
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_disabled(mock_create_agent):
mock_create_agent.return_value = MagicMock()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=False),
)
call_kwargs = mock_create_agent.call_args[1]
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
assert "LoopDetectionMiddleware" not in mw_types
# ---------------------------------------------------------------------------
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_custom_middleware(mock_create_agent):
from langchain.agents.middleware import AgentMiddleware as AM
mock_create_agent.return_value = MagicMock()
class MyLoopDetection(AM):
pass
custom = MyLoopDetection()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
)
call_kwargs = mock_create_agent.call_args[1]
middleware = call_kwargs["middleware"]
assert custom in middleware
mw_types = [type(m).__name__ for m in middleware]
# Default LoopDetectionMiddleware must not also appear.
assert "LoopDetectionMiddleware" not in mw_types
# Custom replacement still sits immediately before ClarificationMiddleware.
assert mw_types[-1] == "ClarificationMiddleware"
assert mw_types[-2] == "MyLoopDetection"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# 31. plan_mode=True adds TodoMiddleware # 31. plan_mode=True adds TodoMiddleware
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch): def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
_clear_claude_code_env(monkeypatch) _clear_claude_code_env(monkeypatch)
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
monkeypatch.setenv("HOME", str(tmp_path))
cred_dir = tmp_path / "claude-creds-dir" cred_dir = tmp_path / "claude-creds-dir"
cred_dir.mkdir() cred_dir.mkdir()
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir)) monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
@@ -8,17 +8,20 @@ from unittest.mock import MagicMock
import pytest import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import AppConfig
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig from deerflow.config.summarization_config import SummarizationConfig
def _make_app_config(models: list[ModelConfig]) -> AppConfig: def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
return AppConfig( return AppConfig(
models=models, models=models,
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"), sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
loop_detection=loop_detection or LoopDetectionConfig(),
) )
@@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
assert middlewares[0] == "base-middleware" assert middlewares[0] == "base-middleware"
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(
warn_threshold=7,
hard_limit=9,
window_size=30,
max_tracked_threads=40,
tool_freq_warn=50,
tool_freq_hard_limit=60,
),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
assert loop_detection.warn_threshold == 7
assert loop_detection.hard_limit == 9
assert loop_detection.window_size == 30
assert loop_detection.max_tracked_threads == 40
assert loop_detection.tool_freq_warn == 50
assert loop_detection.tool_freq_hard_limit == 60
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(enabled=False),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch): def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)]) app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork") app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
@@ -0,0 +1,72 @@
"""Tests for loop detection configuration."""
import pytest
from deerflow.config.loop_detection_config import LoopDetectionConfig
class TestLoopDetectionConfig:
def test_defaults_match_middleware_defaults(self):
config = LoopDetectionConfig()
assert config.enabled is True
assert config.warn_threshold == 3
assert config.hard_limit == 5
assert config.window_size == 20
assert config.max_tracked_threads == 100
assert config.tool_freq_warn == 30
assert config.tool_freq_hard_limit == 50
def test_accepts_custom_values(self):
config = LoopDetectionConfig(
enabled=False,
warn_threshold=10,
hard_limit=20,
window_size=50,
max_tracked_threads=200,
tool_freq_warn=60,
tool_freq_hard_limit=80,
)
assert config.enabled is False
assert config.warn_threshold == 10
assert config.hard_limit == 20
assert config.window_size == 50
assert config.max_tracked_threads == 200
assert config.tool_freq_warn == 60
assert config.tool_freq_hard_limit == 80
def test_rejects_zero_thresholds(self):
with pytest.raises(ValueError):
LoopDetectionConfig(warn_threshold=0)
with pytest.raises(ValueError):
LoopDetectionConfig(hard_limit=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_warn=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_hard_limit=0)
def test_rejects_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
def test_tool_freq_override_valid(self):
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
override = config.tool_freq_overrides["bash"]
assert override.warn == 150
assert override.hard_limit == 300
def test_tool_freq_override_rejects_zero_warn(self):
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
@@ -648,6 +648,37 @@ class TestToolFrequencyDetection:
assert result is not None assert result is not None
assert "read_file" in result["messages"][0].content assert "read_file" in result["messages"][0].content
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
mw = LoopDetectionMiddleware(
tool_freq_warn=5,
tool_freq_hard_limit=10,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
for i in range(10):
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None, f"unexpected trigger on call {i + 1}"
def test_non_override_tool_falls_back_to_global(self):
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
mw = LoopDetectionMiddleware(
tool_freq_warn=3,
tool_freq_hard_limit=6,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
def test_hash_detection_takes_priority(self): def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls.""" """Hash-based hard stop fires before frequency check for identical calls."""
mw = LoopDetectionMiddleware( mw = LoopDetectionMiddleware(
@@ -668,3 +699,48 @@ class TestToolFrequencyDetection:
msg = result["messages"][0] msg = result["messages"][0]
assert isinstance(msg, AIMessage) assert isinstance(msg, AIMessage)
assert _HARD_STOP_MSG in msg.content assert _HARD_STOP_MSG in msg.content
class TestFromConfig:
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
@staticmethod
def _config(**kwargs):
from deerflow.config.loop_detection_config import LoopDetectionConfig
return LoopDetectionConfig(**kwargs)
def test_scalar_fields_mapped(self):
config = self._config(
warn_threshold=4,
hard_limit=8,
window_size=15,
max_tracked_threads=50,
tool_freq_warn=20,
tool_freq_hard_limit=40,
)
mw = LoopDetectionMiddleware.from_config(config)
assert mw.warn_threshold == 4
assert mw.hard_limit == 8
assert mw.window_size == 15
assert mw.max_tracked_threads == 50
assert mw.tool_freq_warn == 20
assert mw.tool_freq_hard_limit == 40
def test_overrides_converted_to_tuples(self):
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
mw = LoopDetectionMiddleware.from_config(config)
assert mw._tool_freq_overrides == {"bash": (50, 100)}
def test_empty_overrides(self):
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
@@ -15,7 +15,7 @@
# ============================================================================ # ============================================================================
# Bump this number when the config schema changes. # Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml. # Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 8 config_version: 9
# ============================================================================ # ============================================================================
# Logging # Logging
@@ -506,6 +506,29 @@ tools:
tool_search: tool_search:
enabled: false enabled: false
# ============================================================================
# Loop Detection Configuration
# ============================================================================
# Detect and interrupt repeated identical tool-call loops.
# Frequency thresholds are safety limits for repeated use of the same tool type.
loop_detection:
enabled: true
warn_threshold: 3
hard_limit: 5
window_size: 20
max_tracked_threads: 100
tool_freq_warn: 30
tool_freq_hard_limit: 50
# Per-tool overrides for tool_freq_warn / tool_freq_hard_limit. Values can be
# higher or lower than the global defaults. Commonly used to raise thresholds
# for high-frequency tools like bash in batch workflows (e.g. RNA-seq pipelines)
# without weakening protection on every other tool.
# tool_freq_overrides:
# bash:
# warn: 150
# hard_limit: 300
# ============================================================================ # ============================================================================
# Sandbox Configuration # Sandbox Configuration
# ============================================================================ # ============================================================================