mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)
* Make loop detection configurable Expose LoopDetectionMiddleware thresholds through config.yaml while preserving existing defaults and allowing the middleware to be disabled. Refs bytedance/deer-flow#2517 * feat(loop-detection): add per-tool tool_freq_overrides to Phase 1 Adds ToolFreqOverride model and tool_freq_overrides field to LoopDetectionConfig, wires it through LoopDetectionMiddleware, and documents the option in config.example.yaml. Resolves the gap flagged in the #2586 review: without per-tool overrides, users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit) had no way to raise thresholds for one tool without loosening the global limit for every tool. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring Add the missing Args entry for tool_freq_overrides, explaining the (warn, hard_limit) tuple structure and how per-tool thresholds supersede the global tool_freq_warn / tool_freq_hard_limit for named tools. Also run ruff format on the three files flagged by the lint check. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly Raise clear ValueError at construction time instead of crashing at unpack-time inside _track_and_check when bad values are passed: - tool_freq_overrides: must be 2-tuples of positive ints with hard_limit >= warn - scalar thresholds: warn_threshold, hard_limit, tool_freq_warn, tool_freq_hard_limit must be >= 1 and hard limits must >= their warn pairs - window_size, max_tracked_threads must be >= 1 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(test): isolate credential loader directory-path test from real ~/.claude The test didn't monkeypatch HOME, so on any machine with real Claude Code credentials at ~/.claude/.credentials.json the function fell through to those credentials and the assertion failed. Adding HOME redirect ensures the default credential path doesn't exist during the test. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * style(test): add blank lines after import pytest in TestInitValidation Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(loop-detection): collapse dual validation to LoopDetectionConfig Modifications - LoopDetectionMiddleware.__init__: stripped of all ValueError raises; becomes a plain field-assignment constructor. - LoopDetectionMiddleware.from_config: classmethod that builds the middleware from a Pydantic-validated LoopDetectionConfig and handles the ToolFreqOverride -> tuple[int, int] conversion. - agents/factory.py: SDK construction routed through LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the defaults path is Pydantic-validated too. - agents/lead_agent/agent.py: uses from_config instead of unpacking config fields by hand. - tests/test_loop_detection_middleware.py: deleted TestInitValidation (16 methods exercising the removed __init__ checks); added TestFromConfig (4 tests: scalar field mapping, override tuple conversion, empty overrides, behavioral smoke test). Result: one validation layer (Pydantic), zero duplication, no __new__ hacks. Both production construction sites flow through LoopDetectionConfig. Test results make test -> 2977 passed, 18 skipped, 0 failed (137s) make format -> All checks passed; 411 files left unchanged * feat(agents): make loop_detection configurable in create_deerflow_agent Adds a `loop_detection: bool | AgentMiddleware = True` field to RuntimeFeatures, mirroring the existing pattern used by `sandbox`, `memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware or replace it with a custom instance built from their own LoopDetectionConfig — e.g. `LoopDetectionMiddleware.from_config(my_cfg)` — instead of being stuck with the hardcoded defaults previously installed by the SDK factory. The lead-agent path (which already reads AppConfig.loop_detection) is unchanged, and the default `True` preserves prior always-on behavior for all existing callers. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> --------- Co-authored-by: knight0940 <631532668@qq.com> Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -192,6 +192,7 @@ def test_agent_features_defaults():
|
||||
assert f.vision is False
|
||||
assert f.auto_title is False
|
||||
assert f.guardrail is False
|
||||
assert f.loop_detection is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
|
||||
assert loop_idx == clar_idx - 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 30b. loop_detection=False skips LoopDetectionMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
@patch("deerflow.agents.factory.create_agent")
|
||||
def test_loop_detection_disabled(mock_create_agent):
|
||||
mock_create_agent.return_value = MagicMock()
|
||||
create_deerflow_agent(
|
||||
_make_mock_model(),
|
||||
features=RuntimeFeatures(sandbox=False, loop_detection=False),
|
||||
)
|
||||
|
||||
call_kwargs = mock_create_agent.call_args[1]
|
||||
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
|
||||
assert "LoopDetectionMiddleware" not in mw_types
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
|
||||
# ---------------------------------------------------------------------------
|
||||
@patch("deerflow.agents.factory.create_agent")
|
||||
def test_loop_detection_custom_middleware(mock_create_agent):
|
||||
from langchain.agents.middleware import AgentMiddleware as AM
|
||||
|
||||
mock_create_agent.return_value = MagicMock()
|
||||
|
||||
class MyLoopDetection(AM):
|
||||
pass
|
||||
|
||||
custom = MyLoopDetection()
|
||||
create_deerflow_agent(
|
||||
_make_mock_model(),
|
||||
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
|
||||
)
|
||||
|
||||
call_kwargs = mock_create_agent.call_args[1]
|
||||
middleware = call_kwargs["middleware"]
|
||||
assert custom in middleware
|
||||
mw_types = [type(m).__name__ for m in middleware]
|
||||
# Default LoopDetectionMiddleware must not also appear.
|
||||
assert "LoopDetectionMiddleware" not in mw_types
|
||||
# Custom replacement still sits immediately before ClarificationMiddleware.
|
||||
assert mw_types[-1] == "ClarificationMiddleware"
|
||||
assert mw_types[-2] == "MyLoopDetection"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 31. plan_mode=True adds TodoMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
|
||||
|
||||
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
|
||||
_clear_claude_code_env(monkeypatch)
|
||||
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
cred_dir = tmp_path / "claude-creds-dir"
|
||||
cred_dir.mkdir()
|
||||
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
|
||||
|
||||
@@ -8,17 +8,20 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
|
||||
|
||||
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
|
||||
def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
|
||||
return AppConfig(
|
||||
models=models,
|
||||
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
|
||||
loop_detection=loop_detection or LoopDetectionConfig(),
|
||||
)
|
||||
|
||||
|
||||
@@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
|
||||
assert middlewares[0] == "base-middleware"
|
||||
|
||||
|
||||
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[_make_model("safe-model", supports_thinking=False)],
|
||||
loop_detection=LoopDetectionConfig(
|
||||
warn_threshold=7,
|
||||
hard_limit=9,
|
||||
window_size=30,
|
||||
max_tracked_threads=40,
|
||||
tool_freq_warn=50,
|
||||
tool_freq_hard_limit=60,
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="safe-model",
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
|
||||
assert loop_detection.warn_threshold == 7
|
||||
assert loop_detection.hard_limit == 9
|
||||
assert loop_detection.window_size == 30
|
||||
assert loop_detection.max_tracked_threads == 40
|
||||
assert loop_detection.tool_freq_warn == 50
|
||||
assert loop_detection.tool_freq_hard_limit == 60
|
||||
|
||||
|
||||
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[_make_model("safe-model", supports_thinking=False)],
|
||||
loop_detection=LoopDetectionConfig(enabled=False),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="safe-model",
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
|
||||
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Tests for loop detection configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
|
||||
class TestLoopDetectionConfig:
|
||||
def test_defaults_match_middleware_defaults(self):
|
||||
config = LoopDetectionConfig()
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.warn_threshold == 3
|
||||
assert config.hard_limit == 5
|
||||
assert config.window_size == 20
|
||||
assert config.max_tracked_threads == 100
|
||||
assert config.tool_freq_warn == 30
|
||||
assert config.tool_freq_hard_limit == 50
|
||||
|
||||
def test_accepts_custom_values(self):
|
||||
config = LoopDetectionConfig(
|
||||
enabled=False,
|
||||
warn_threshold=10,
|
||||
hard_limit=20,
|
||||
window_size=50,
|
||||
max_tracked_threads=200,
|
||||
tool_freq_warn=60,
|
||||
tool_freq_hard_limit=80,
|
||||
)
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.warn_threshold == 10
|
||||
assert config.hard_limit == 20
|
||||
assert config.window_size == 50
|
||||
assert config.max_tracked_threads == 200
|
||||
assert config.tool_freq_warn == 60
|
||||
assert config.tool_freq_hard_limit == 80
|
||||
|
||||
def test_rejects_zero_thresholds(self):
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(warn_threshold=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(hard_limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_warn=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_hard_limit=0)
|
||||
|
||||
def test_rejects_hard_limit_below_warn_threshold(self):
|
||||
with pytest.raises(ValueError, match="hard_limit"):
|
||||
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
|
||||
|
||||
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
|
||||
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
|
||||
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
|
||||
|
||||
def test_tool_freq_override_valid(self):
|
||||
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
|
||||
override = config.tool_freq_overrides["bash"]
|
||||
assert override.warn == 150
|
||||
assert override.hard_limit == 300
|
||||
|
||||
def test_tool_freq_override_rejects_zero_warn(self):
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
|
||||
|
||||
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
|
||||
with pytest.raises(ValueError, match="hard_limit"):
|
||||
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
|
||||
@@ -648,6 +648,37 @@ class TestToolFrequencyDetection:
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
|
||||
def test_override_tool_uses_override_thresholds(self):
|
||||
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
tool_freq_warn=5,
|
||||
tool_freq_hard_limit=10,
|
||||
tool_freq_overrides={"bash": (50, 100)},
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
|
||||
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
|
||||
for i in range(10):
|
||||
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
|
||||
assert result is None, f"unexpected trigger on call {i + 1}"
|
||||
|
||||
def test_non_override_tool_falls_back_to_global(self):
|
||||
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
tool_freq_warn=3,
|
||||
tool_freq_hard_limit=6,
|
||||
tool_freq_overrides={"bash": (50, 100)},
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
|
||||
for i in range(2):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
|
||||
|
||||
# 3rd read_file call hits global warn=3 (read_file has no override)
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
|
||||
def test_hash_detection_takes_priority(self):
|
||||
"""Hash-based hard stop fires before frequency check for identical calls."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
@@ -668,3 +699,48 @@ class TestToolFrequencyDetection:
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
|
||||
|
||||
@staticmethod
|
||||
def _config(**kwargs):
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
return LoopDetectionConfig(**kwargs)
|
||||
|
||||
def test_scalar_fields_mapped(self):
|
||||
config = self._config(
|
||||
warn_threshold=4,
|
||||
hard_limit=8,
|
||||
window_size=15,
|
||||
max_tracked_threads=50,
|
||||
tool_freq_warn=20,
|
||||
tool_freq_hard_limit=40,
|
||||
)
|
||||
mw = LoopDetectionMiddleware.from_config(config)
|
||||
assert mw.warn_threshold == 4
|
||||
assert mw.hard_limit == 8
|
||||
assert mw.window_size == 15
|
||||
assert mw.max_tracked_threads == 50
|
||||
assert mw.tool_freq_warn == 20
|
||||
assert mw.tool_freq_hard_limit == 40
|
||||
|
||||
def test_overrides_converted_to_tuples(self):
|
||||
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
|
||||
mw = LoopDetectionMiddleware.from_config(config)
|
||||
assert mw._tool_freq_overrides == {"bash": (50, 100)}
|
||||
|
||||
def test_empty_overrides(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config())
|
||||
assert mw._tool_freq_overrides == {}
|
||||
|
||||
def test_constructed_middleware_detects_loops(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
|
||||
Reference in New Issue
Block a user