refactor: thread app_config through middleware factories (#2652)

* refactor: thread app_config through middleware factories

Continues the incremental config-refactor sequence (#2611 root, #2612 lead
path) one layer deeper into the middleware factories. Two ambient lookups
inside _build_runtime_middlewares are eliminated and the LLMErrorHandling
band-aid removed:

- _build_runtime_middlewares / build_lead_runtime_middlewares /
  build_subagent_runtime_middlewares now require app_config: AppConfig.
- get_guardrails_config() inside the factory is replaced with
  app_config.guardrails (semantically identical — same default-factory
  GuardrailsConfig — verified by direct equality check).
- LLMErrorHandlingMiddleware.__init__ now requires app_config and reads
  circuit_breaker fields directly. The class-level
  circuit_failure_threshold / circuit_recovery_timeout_sec defaults are
  removed along with the `try/except (FileNotFoundError, RuntimeError):
  pass` band-aid, in keeping with the let-it-crash invariant that the rest
  of the refactor enforces.

Caller chain (already-resolved app_config sources):
- _build_middlewares in lead_agent/agent.py: reorder so
  resolved_app_config = app_config or get_app_config() is computed BEFORE
  build_lead_runtime_middlewares is called, then passed as kwarg.
- SubagentExecutor: optional app_config parameter (mirrors the lead-agent
  pattern); _create_agent does the same `or get_app_config()` fallback at
  agent-build time, so task_tool callers don't need to plumb app_config
  through yet (typed-context plumbing for tool runtimes is a separate
  refactor).

Tests:
- test_llm_error_handling_middleware: _make_app_config helper using
  AppConfig(sandbox=SandboxConfig(use="test")) — same minimal-config
  pattern conftest already uses. The three direct
  LLMErrorHandlingMiddleware() calls, each followed by post-construction
  circuit_breaker mutation, now fold cleanly into
  _build_middleware(circuit_failure_threshold=...,
  circuit_recovery_timeout_sec=...).

Verification:
- tests/test_llm_error_handling_middleware.py — 14 passed
- tests/test_subagent_executor.py — 28 passed
- tests/test_tool_error_handling_middleware.py — 6 passed
- tests/test_task_tool_core_logic.py — 18 passed (verifies task_tool
  unchanged behavior)
- Full suite: 2697 passed, 3 skipped. The single intermittent failure in
  tests/test_client_e2e.py::test_tool_call_produces_events is pre-existing
  LLM flakiness (the test asserts the model decided to call a tool;
  reproduces 1/3 on unchanged main as well).

* fix: address middleware app config review comments

* fix: satisfy app config annotation lint

* test: cover explicit app config middleware wiring

---------

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
greatmengqi
2026-04-30 12:41:09 +08:00
committed by GitHub
parent 74081a85a6
commit 38714b6ceb
8 changed files with 236 additions and 34 deletions
@@ -217,6 +217,40 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
    """An explicit ``app_config`` must be forwarded to the shared lead factory.

    ``get_app_config`` on the lead-agent module is replaced with a function
    that raises, so any ambient config lookup made while building the
    middleware list fails the test immediately.
    """
    app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
    captured: dict[str, object] = {}

    def _raise_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

    def _fake_build_lead_runtime_middlewares(*, app_config, lazy_init):
        # Record exactly what the code under test passes to the factory.
        captured["app_config"] = app_config
        captured["lazy_init"] = lazy_init
        return ["base-middleware"]

    monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
    monkeypatch.setattr(
        lead_agent_module,
        "build_lead_runtime_middlewares",
        _fake_build_lead_runtime_middlewares,
    )
    # Stub the optional middleware factories so they return None.
    monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
    monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)

    middlewares = lead_agent_module._build_middlewares(
        {"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
        model_name="safe-model",
        app_config=app_config,
    )

    assert captured == {
        "app_config": app_config,
        "lazy_init": True,
    }
    # The list returned by the shared factory forms the head of the result.
    assert middlewares[0] == "base-middleware"
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
monkeypatch.setattr(
lead_agent_module,
@@ -11,6 +11,13 @@ from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_app_config() -> AppConfig:
    """Minimal AppConfig for middleware tests; circuit_breaker uses defaults."""
    return AppConfig(sandbox=SandboxConfig(use="test"))
class FakeError(Exception):
@@ -31,7 +38,7 @@ class FakeError(Exception):
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
    """Build a middleware from the minimal test config, then override attributes.

    Args:
        **attrs: Attribute names mapped to the integer values to set on the
            instance after construction, e.g. ``circuit_failure_threshold=3``.

    Returns:
        The ``LLMErrorHandlingMiddleware`` instance with the overrides applied.
    """
    # The constructor is called with app_config only; the earlier no-argument
    # construction was immediately overwritten and has been dropped.
    middleware = LLMErrorHandlingMiddleware(app_config=_make_app_config())
    for key, value in attrs.items():
        setattr(middleware, key, value)
    return middleware
@@ -226,9 +233,7 @@ def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) ->
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
request: Any = {"messages": []}
@@ -284,8 +289,7 @@ def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pyte
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware = _build_middleware(circuit_failure_threshold=3)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
request: Any = {"messages": []}
@@ -386,9 +390,7 @@ async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.Monk
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
async def async_failing_handler(request: Any) -> Any:
+90
View File
@@ -17,6 +17,7 @@ import asyncio
import sys
import threading
from datetime import datetime
from types import ModuleType
from unittest.mock import MagicMock, patch
import pytest
@@ -153,6 +154,13 @@ def mock_agent():
return agent
def _module(name: str, **attrs) -> ModuleType:
    """Create a stand-in module object for injection into ``sys.modules``.

    Args:
        name: Value used as the new module's ``__name__``.
        **attrs: Attributes to set on the module, keyed by attribute name.

    Returns:
        A fresh ``ModuleType`` instance carrying the requested attributes.
    """
    module = ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    return module
# Helper to create real message objects
class _MsgHelper:
"""Helper to create real message objects from fixture classes."""
@@ -176,6 +184,88 @@ def msg(classes):
return _MsgHelper(classes)
# -----------------------------------------------------------------------------
# Agent Construction Tests
# -----------------------------------------------------------------------------
class TestAgentConstruction:
    """Test _create_agent() wiring before execution starts."""

    def test_create_agent_threads_explicit_app_config_to_model_and_middlewares(
        self,
        classes,
        base_config,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Explicit app_config must flow into both model and middleware factories."""
        import deerflow.config as config_module
        from deerflow.subagents import executor as executor_module

        SubagentExecutor = classes["SubagentExecutor"]
        # Opaque sentinels: the test checks only identity and forwarding, so
        # no real config, model, or agent objects are needed.
        app_config = object()
        model = object()
        middlewares = [object()]
        agent = object()
        captured: dict[str, dict] = {}

        def fake_get_app_config():
            raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

        def fake_create_chat_model(**kwargs):
            captured["model"] = kwargs
            return model

        def fake_build_subagent_runtime_middlewares(**kwargs):
            captured["middlewares"] = kwargs
            return middlewares

        def fake_create_agent(**kwargs):
            captured["agent"] = kwargs
            return agent

        monkeypatch.setattr(config_module, "get_app_config", fake_get_app_config)
        monkeypatch.setattr(
            executor_module,
            "create_chat_model",
            fake_create_chat_model,
        )
        monkeypatch.setattr(executor_module, "create_agent", fake_create_agent)
        # Swap the whole middleware module for a stub exposing only the factory.
        # NOTE(review): presumably _create_agent imports this module at call
        # time, which is why the sys.modules entry is replaced — confirm.
        monkeypatch.setitem(
            sys.modules,
            "deerflow.agents.middlewares.tool_error_handling_middleware",
            _module(
                "deerflow.agents.middlewares.tool_error_handling_middleware",
                build_subagent_runtime_middlewares=fake_build_subagent_runtime_middlewares,
            ),
        )

        executor = SubagentExecutor(
            config=base_config,
            tools=[],
            app_config=app_config,
            parent_model="parent-model",
        )

        result = executor._create_agent()

        assert result is agent
        assert captured["model"] == {
            "name": "parent-model",
            "thinking_enabled": False,
            "app_config": app_config,
        }
        assert captured["middlewares"] == {
            "app_config": app_config,
            "lazy_init": True,
        }
        assert captured["agent"]["model"] is model
        assert captured["agent"]["middleware"] is middlewares
        assert captured["agent"]["tools"] == []
        assert captured["agent"]["system_prompt"] == base_config.system_prompt
# -----------------------------------------------------------------------------
# Async Execution Path Tests
# -----------------------------------------------------------------------------
@@ -1,10 +1,32 @@
from types import SimpleNamespace
import sys
from types import ModuleType, SimpleNamespace
import pytest
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphInterrupt
from deerflow.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import (
ToolErrorHandlingMiddleware,
build_subagent_runtime_middlewares,
)
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.sandbox_config import SandboxConfig
def _module(name: str, **attrs) -> ModuleType:
    """Create a stand-in module object for injection into ``sys.modules``.

    Args:
        name: Value used as the new module's ``__name__``.
        **attrs: Attributes to set on the module, keyed by attribute name.

    Returns:
        A fresh ``ModuleType`` instance carrying the requested attributes.
    """
    module = ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    return module
def _make_app_config() -> AppConfig:
    """Return an AppConfig with guardrails disabled and explicit breaker values.

    The circuit breaker is built with ``failure_threshold=7`` and
    ``recovery_timeout_sec=11``.
    """
    return AppConfig(
        sandbox=SandboxConfig(use="test"),
        guardrails=GuardrailsConfig(enabled=False),
        circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
    )
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
@@ -14,6 +36,56 @@ def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
return SimpleNamespace(tool_call=tool_call)
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
    """The subagent factory must hand its ``app_config`` to the LLM error middleware.

    The middleware modules named below are replaced in ``sys.modules`` with
    stubs, so the test observes only what the factory passes to each
    constructor.
    """
    captured: dict[str, object] = {}

    class FakeMiddleware:
        """Generic stand-in that accepts and stores any constructor arguments."""

        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        """Stand-in that records the keyword-only ``app_config`` it receives."""

        def __init__(self, *, app_config):
            captured["app_config"] = app_config

    app_config = _make_app_config()
    # NOTE(review): presumably the factory imports these modules at call time,
    # which is why the sys.modules entries are swapped — confirm.
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.llm_error_handling_middleware",
        _module(
            "deerflow.agents.middlewares.llm_error_handling_middleware",
            LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.thread_data_middleware",
        _module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.sandbox.middleware",
        _module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.dangling_tool_call_middleware",
        _module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.sandbox_audit_middleware",
        _module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
    )

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)

    # Identity check: the same config object must be forwarded.
    assert captured["app_config"] is app_config
    assert len(middlewares) == 6
    assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
def test_wrap_tool_call_passthrough_on_success():
middleware = ToolErrorHandlingMiddleware()
req = _request()