mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(subagents): use model override for tools and middleware (#2641)
* fix(subagents): use model override for tools and middleware * fix(config): resolve effective subagent model * fix(subagents): defer app config loading * fix(subagents): fully defer config.yaml load in executor __init__ The previous attempt only relocated the explicit get_app_config() call, but left resolve_subagent_model_name(...) running eagerly in __init__. That helper has its own internal get_app_config() fallback, which still fired when both app_config and parent_model were None and config.model == "inherit" — exactly the path unit tests hit, breaking 21 tests in CI with FileNotFoundError: config.yaml. Skip the eager resolve in __init__ when it would require loading the config file, and defer to _create_agent (which already has the app_config or get_app_config() fallback).
This commit is contained in:
@@ -258,6 +258,7 @@ class TestAgentConstruction:
|
||||
}
|
||||
assert captured["middlewares"] == {
|
||||
"app_config": app_config,
|
||||
"model_name": "parent-model",
|
||||
"lazy_init": True,
|
||||
}
|
||||
assert captured["agent"]["model"] is model
|
||||
|
||||
@@ -223,6 +223,56 @@ def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch):
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)
|
||||
|
||||
|
||||
def test_task_tool_uses_subagent_model_override_for_tool_loading(monkeypatch):
|
||||
"""Subagent model overrides should drive model-gated tool loading."""
|
||||
config = SubagentConfig(
|
||||
name="general-purpose",
|
||||
description="General helper",
|
||||
system_prompt="Base system prompt",
|
||||
model="vision-subagent-model",
|
||||
max_turns=50,
|
||||
timeout_seconds=10,
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
runtime.config["metadata"]["model_name"] = "parent-text-model"
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="inspect image",
|
||||
prompt="inspect the uploaded image",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-issue-2543",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
get_available_tools.assert_called_once_with(
|
||||
model_name="vision-subagent-model",
|
||||
groups=None,
|
||||
subagent_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
runtime = _make_runtime()
|
||||
@@ -371,6 +421,7 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
monkeypatch.setattr(task_tool_module, "get_app_config", lambda: SimpleNamespace(models=[SimpleNamespace(name="default-model")]))
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=None,
|
||||
@@ -381,8 +432,8 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: ok"
|
||||
# runtime is None → metadata is empty dict → groups=None
|
||||
get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False)
|
||||
# runtime is None -> metadata is empty dict -> groups=None, model falls back to app default.
|
||||
get_available_tools.assert_called_once_with(model_name="default-model", groups=None, subagent_enabled=False)
|
||||
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
|
||||
@@ -9,11 +9,20 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import (
|
||||
ToolErrorHandlingMiddleware,
|
||||
build_subagent_runtime_middlewares,
|
||||
)
|
||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
||||
tool_call = {"name": name}
|
||||
if tool_call_id is not None:
|
||||
tool_call["id"] = tool_call_id
|
||||
return SimpleNamespace(tool_call=tool_call)
|
||||
|
||||
|
||||
def _module(name: str, **attrs):
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
@@ -21,19 +30,62 @@ def _module(name: str, **attrs):
|
||||
return module
|
||||
|
||||
|
||||
def _make_app_config() -> AppConfig:
|
||||
def _make_app_config(*, supports_vision: bool = False) -> AppConfig:
|
||||
return AppConfig(
|
||||
models=[
|
||||
ModelConfig(
|
||||
name="test-model",
|
||||
display_name="test-model",
|
||||
description=None,
|
||||
use="langchain_openai:ChatOpenAI",
|
||||
model="test-model",
|
||||
supports_vision=supports_vision,
|
||||
)
|
||||
],
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
guardrails=GuardrailsConfig(enabled=False),
|
||||
circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
|
||||
)
|
||||
|
||||
|
||||
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
||||
tool_call = {"name": name}
|
||||
if tool_call_id is not None:
|
||||
tool_call["id"] = tool_call_id
|
||||
return SimpleNamespace(tool_call=tool_call)
|
||||
def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FakeMiddleware:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
class FakeLLMErrorHandlingMiddleware:
|
||||
def __init__(self, *, app_config):
|
||||
self.app_config = app_config
|
||||
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"deerflow.agents.middlewares.llm_error_handling_middleware",
|
||||
_module(
|
||||
"deerflow.agents.middlewares.llm_error_handling_middleware",
|
||||
LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
|
||||
),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"deerflow.agents.middlewares.thread_data_middleware",
|
||||
_module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"deerflow.sandbox.middleware",
|
||||
_module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"deerflow.agents.middlewares.dangling_tool_call_middleware",
|
||||
_module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"deerflow.agents.middlewares.sandbox_audit_middleware",
|
||||
_module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
|
||||
)
|
||||
|
||||
|
||||
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -166,3 +218,30 @@ async def test_awrap_tool_call_reraises_graph_interrupt():
|
||||
|
||||
with pytest.raises(GraphInterrupt):
|
||||
await middleware.awrap_tool_call(req, _interrupt)
|
||||
|
||||
|
||||
def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch):
|
||||
app_config = _make_app_config(supports_vision=True)
|
||||
_stub_runtime_middleware_imports(monkeypatch)
|
||||
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")
|
||||
|
||||
assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
||||
|
||||
|
||||
def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch):
|
||||
app_config = _make_app_config(supports_vision=True)
|
||||
_stub_runtime_middleware_imports(monkeypatch)
|
||||
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=None)
|
||||
|
||||
assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
||||
|
||||
|
||||
def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch):
|
||||
app_config = _make_app_config(supports_vision=False)
|
||||
_stub_runtime_middleware_imports(monkeypatch)
|
||||
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")
|
||||
|
||||
assert not any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
||||
|
||||
Reference in New Issue
Block a user