mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
fix(subagents): use model override for tools and middleware (#2641)
* fix(subagents): use model override for tools and middleware * fix(config): resolve effective subagent model * fix(subagents): defer app config loading * fix(subagents): fully defer config.yaml load in executor __init__ The previous attempt only relocated the explicit get_app_config() call, but left resolve_subagent_model_name(...) running eagerly in __init__. That helper has its own internal get_app_config() fallback, which still fired when both app_config and parent_model were None and config.model == "inherit" — exactly the path unit tests hit, breaking 21 tests in CI with FileNotFoundError: config.yaml. Skip the eager resolve in __init__ when it would require loading the config file, and defer to _create_agent (which already has the app_config or get_app_config() fallback).
This commit is contained in:
+23
-2
@@ -136,11 +136,32 @@ def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = T
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_subagent_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
|
def build_subagent_runtime_middlewares(
|
||||||
|
*,
|
||||||
|
app_config: AppConfig | None = None,
|
||||||
|
model_name: str | None = None,
|
||||||
|
lazy_init: bool = True,
|
||||||
|
) -> list[AgentMiddleware]:
|
||||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||||
return _build_runtime_middlewares(
|
if app_config is None:
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
app_config = get_app_config()
|
||||||
|
|
||||||
|
middlewares = _build_runtime_middlewares(
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
include_uploads=False,
|
include_uploads=False,
|
||||||
include_dangling_tool_call_patch=True,
|
include_dangling_tool_call_patch=True,
|
||||||
lazy_init=lazy_init,
|
lazy_init=lazy_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_name is None and app_config.models:
|
||||||
|
model_name = app_config.models[0].name
|
||||||
|
|
||||||
|
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||||
|
if model_config is not None and model_config.supports_vision:
|
||||||
|
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||||
|
|
||||||
|
middlewares.append(ViewImageMiddleware())
|
||||||
|
|
||||||
|
return middlewares
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
"""Subagent configuration definitions."""
|
"""Subagent configuration definitions."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deerflow.config.app_config import AppConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -29,3 +33,24 @@ class SubagentConfig:
|
|||||||
model: str = "inherit"
|
model: str = "inherit"
|
||||||
max_turns: int = 50
|
max_turns: int = 50
|
||||||
timeout_seconds: int = 900
|
timeout_seconds: int = 900
|
||||||
|
|
||||||
|
|
||||||
|
def _default_model_name(app_config: "AppConfig") -> str:
|
||||||
|
if not app_config.models:
|
||||||
|
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
||||||
|
return app_config.models[0].name
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_subagent_model_name(config: SubagentConfig, parent_model: str | None, *, app_config: "AppConfig | None" = None) -> str:
|
||||||
|
"""Resolve the effective model name a subagent should use."""
|
||||||
|
if config.model != "inherit":
|
||||||
|
return config.model
|
||||||
|
|
||||||
|
if parent_model is not None:
|
||||||
|
return parent_model
|
||||||
|
|
||||||
|
if app_config is None:
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
app_config = get_app_config()
|
||||||
|
return _default_model_name(app_config)
|
||||||
|
|||||||
@@ -20,9 +20,10 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
|
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
|
||||||
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.subagents.config import SubagentConfig
|
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -213,21 +214,6 @@ def _filter_tools(
|
|||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
def _get_model_name(config: SubagentConfig, parent_model: str | None) -> str | None:
|
|
||||||
"""Resolve the model name for a subagent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Subagent configuration.
|
|
||||||
parent_model: The parent agent's model name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model name to use, or None to use default.
|
|
||||||
"""
|
|
||||||
if config.model == "inherit":
|
|
||||||
return parent_model
|
|
||||||
return config.model
|
|
||||||
|
|
||||||
|
|
||||||
class SubagentExecutor:
|
class SubagentExecutor:
|
||||||
"""Executor for running subagents."""
|
"""Executor for running subagents."""
|
||||||
|
|
||||||
@@ -247,9 +233,9 @@ class SubagentExecutor:
|
|||||||
Args:
|
Args:
|
||||||
config: Subagent configuration.
|
config: Subagent configuration.
|
||||||
tools: List of all available tools (will be filtered).
|
tools: List of all available tools (will be filtered).
|
||||||
app_config: Resolved AppConfig; threaded into middleware factories
|
app_config: Resolved AppConfig. When None, ``_create_agent`` falls
|
||||||
at agent-build time. When None, ``_create_agent`` falls back to
|
back to ``get_app_config()`` (matches the lead-agent factory's
|
||||||
``get_app_config()`` (matches the lead-agent factory's pattern).
|
pattern).
|
||||||
parent_model: The parent agent's model name for inheritance.
|
parent_model: The parent agent's model name for inheritance.
|
||||||
sandbox_state: Sandbox state from parent agent.
|
sandbox_state: Sandbox state from parent agent.
|
||||||
thread_data: Thread data from parent agent.
|
thread_data: Thread data from parent agent.
|
||||||
@@ -259,6 +245,13 @@ class SubagentExecutor:
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.app_config = app_config
|
self.app_config = app_config
|
||||||
self.parent_model = parent_model
|
self.parent_model = parent_model
|
||||||
|
# Resolve eagerly only when it does not require loading config.yaml; otherwise defer
|
||||||
|
# to _create_agent (which already loads app_config) so unit tests can construct
|
||||||
|
# executors without a config file present.
|
||||||
|
if config.model != "inherit" or parent_model is not None or app_config is not None:
|
||||||
|
self.model_name: str | None = resolve_subagent_model_name(config, parent_model, app_config=app_config)
|
||||||
|
else:
|
||||||
|
self.model_name = None
|
||||||
self.sandbox_state = sandbox_state
|
self.sandbox_state = sandbox_state
|
||||||
self.thread_data = thread_data
|
self.thread_data = thread_data
|
||||||
self.thread_id = thread_id
|
self.thread_id = thread_id
|
||||||
@@ -276,17 +269,15 @@ class SubagentExecutor:
|
|||||||
|
|
||||||
def _create_agent(self):
|
def _create_agent(self):
|
||||||
"""Create the agent instance."""
|
"""Create the agent instance."""
|
||||||
# Mirror lead-agent factory pattern: prefer explicit app_config,
|
app_config = self.app_config or get_app_config()
|
||||||
# fall back to ambient lookup at agent-build time.
|
if self.model_name is None:
|
||||||
from deerflow.config import get_app_config
|
self.model_name = resolve_subagent_model_name(self.config, self.parent_model, app_config=app_config)
|
||||||
|
model = create_chat_model(name=self.model_name, thinking_enabled=False, app_config=app_config)
|
||||||
resolved_app_config = self.app_config or get_app_config()
|
|
||||||
model_name = _get_model_name(self.config, self.parent_model)
|
|
||||||
model = create_chat_model(name=model_name, thinking_enabled=False, app_config=resolved_app_config)
|
|
||||||
|
|
||||||
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
||||||
|
|
||||||
middlewares = build_subagent_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
|
# Reuse shared middleware composition with lead agent.
|
||||||
|
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True)
|
||||||
|
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
@@ -11,9 +11,16 @@ from langgraph.config import get_stream_writer
|
|||||||
from langgraph.typing import ContextT
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
|
from deerflow.config import get_app_config
|
||||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||||
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
|
from deerflow.subagents.config import resolve_subagent_model_name
|
||||||
|
from deerflow.subagents.executor import (
|
||||||
|
SubagentStatus,
|
||||||
|
cleanup_background_task,
|
||||||
|
get_background_task_result,
|
||||||
|
request_cancel_background_task,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -129,14 +136,19 @@ async def task_tool(
|
|||||||
|
|
||||||
# Inherit parent agent's tool_groups so subagents respect the same restrictions
|
# Inherit parent agent's tool_groups so subagents respect the same restrictions
|
||||||
parent_tool_groups = metadata.get("tool_groups")
|
parent_tool_groups = metadata.get("tool_groups")
|
||||||
|
app_config = None
|
||||||
|
if config.model == "inherit" and parent_model is None:
|
||||||
|
app_config = get_app_config()
|
||||||
|
effective_model = resolve_subagent_model_name(config, parent_model, app_config=app_config)
|
||||||
|
|
||||||
# Subagents should not have subagent tools enabled (prevent recursive nesting)
|
# Subagents should not have subagent tools enabled (prevent recursive nesting)
|
||||||
tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False)
|
tools = get_available_tools(model_name=effective_model, groups=parent_tool_groups, subagent_enabled=False)
|
||||||
|
|
||||||
# Create executor
|
# Create executor
|
||||||
executor = SubagentExecutor(
|
executor = SubagentExecutor(
|
||||||
config=config,
|
config=config,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
app_config=app_config,
|
||||||
parent_model=parent_model,
|
parent_model=parent_model,
|
||||||
sandbox_state=sandbox_state,
|
sandbox_state=sandbox_state,
|
||||||
thread_data=thread_data,
|
thread_data=thread_data,
|
||||||
|
|||||||
@@ -258,6 +258,7 @@ class TestAgentConstruction:
|
|||||||
}
|
}
|
||||||
assert captured["middlewares"] == {
|
assert captured["middlewares"] == {
|
||||||
"app_config": app_config,
|
"app_config": app_config,
|
||||||
|
"model_name": "parent-model",
|
||||||
"lazy_init": True,
|
"lazy_init": True,
|
||||||
}
|
}
|
||||||
assert captured["agent"]["model"] is model
|
assert captured["agent"]["model"] is model
|
||||||
|
|||||||
@@ -223,6 +223,56 @@ def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch):
|
|||||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)
|
get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_tool_uses_subagent_model_override_for_tool_loading(monkeypatch):
|
||||||
|
"""Subagent model overrides should drive model-gated tool loading."""
|
||||||
|
config = SubagentConfig(
|
||||||
|
name="general-purpose",
|
||||||
|
description="General helper",
|
||||||
|
system_prompt="Base system prompt",
|
||||||
|
model="vision-subagent-model",
|
||||||
|
max_turns=50,
|
||||||
|
timeout_seconds=10,
|
||||||
|
)
|
||||||
|
runtime = _make_runtime()
|
||||||
|
runtime.config["metadata"]["model_name"] = "parent-text-model"
|
||||||
|
events = []
|
||||||
|
get_available_tools = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
class DummyExecutor:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def execute_async(self, prompt, task_id=None):
|
||||||
|
return task_id or "generated-task-id"
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"get_background_task_result",
|
||||||
|
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
|
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||||
|
|
||||||
|
output = _run_task_tool(
|
||||||
|
runtime=runtime,
|
||||||
|
description="inspect image",
|
||||||
|
prompt="inspect the uploaded image",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-issue-2543",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output == "Task Succeeded. Result: done"
|
||||||
|
get_available_tools.assert_called_once_with(
|
||||||
|
model_name="vision-subagent-model",
|
||||||
|
groups=None,
|
||||||
|
subagent_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
|
def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
|
||||||
config = _make_subagent_config()
|
config = _make_subagent_config()
|
||||||
runtime = _make_runtime()
|
runtime = _make_runtime()
|
||||||
@@ -371,6 +421,7 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
|||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_app_config", lambda: SimpleNamespace(models=[SimpleNamespace(name="default-model")]))
|
||||||
|
|
||||||
output = _run_task_tool(
|
output = _run_task_tool(
|
||||||
runtime=None,
|
runtime=None,
|
||||||
@@ -381,8 +432,8 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert output == "Task Succeeded. Result: ok"
|
assert output == "Task Succeeded. Result: ok"
|
||||||
# runtime is None → metadata is empty dict → groups=None
|
# runtime is None -> metadata is empty dict -> groups=None, model falls back to app default.
|
||||||
get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False)
|
get_available_tools.assert_called_once_with(model_name="default-model", groups=None, subagent_enabled=False)
|
||||||
|
|
||||||
config = _make_subagent_config()
|
config = _make_subagent_config()
|
||||||
events = []
|
events = []
|
||||||
|
|||||||
@@ -9,11 +9,20 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import (
|
|||||||
ToolErrorHandlingMiddleware,
|
ToolErrorHandlingMiddleware,
|
||||||
build_subagent_runtime_middlewares,
|
build_subagent_runtime_middlewares,
|
||||||
)
|
)
|
||||||
|
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||||
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
|
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
|
||||||
from deerflow.config.guardrails_config import GuardrailsConfig
|
from deerflow.config.guardrails_config import GuardrailsConfig
|
||||||
|
from deerflow.config.model_config import ModelConfig
|
||||||
from deerflow.config.sandbox_config import SandboxConfig
|
from deerflow.config.sandbox_config import SandboxConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
||||||
|
tool_call = {"name": name}
|
||||||
|
if tool_call_id is not None:
|
||||||
|
tool_call["id"] = tool_call_id
|
||||||
|
return SimpleNamespace(tool_call=tool_call)
|
||||||
|
|
||||||
|
|
||||||
def _module(name: str, **attrs):
|
def _module(name: str, **attrs):
|
||||||
module = ModuleType(name)
|
module = ModuleType(name)
|
||||||
for key, value in attrs.items():
|
for key, value in attrs.items():
|
||||||
@@ -21,19 +30,62 @@ def _module(name: str, **attrs):
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
def _make_app_config() -> AppConfig:
|
def _make_app_config(*, supports_vision: bool = False) -> AppConfig:
|
||||||
return AppConfig(
|
return AppConfig(
|
||||||
|
models=[
|
||||||
|
ModelConfig(
|
||||||
|
name="test-model",
|
||||||
|
display_name="test-model",
|
||||||
|
description=None,
|
||||||
|
use="langchain_openai:ChatOpenAI",
|
||||||
|
model="test-model",
|
||||||
|
supports_vision=supports_vision,
|
||||||
|
)
|
||||||
|
],
|
||||||
sandbox=SandboxConfig(use="test"),
|
sandbox=SandboxConfig(use="test"),
|
||||||
guardrails=GuardrailsConfig(enabled=False),
|
guardrails=GuardrailsConfig(enabled=False),
|
||||||
circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
|
circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
tool_call = {"name": name}
|
class FakeMiddleware:
|
||||||
if tool_call_id is not None:
|
def __init__(self, *args, **kwargs):
|
||||||
tool_call["id"] = tool_call_id
|
self.args = args
|
||||||
return SimpleNamespace(tool_call=tool_call)
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
class FakeLLMErrorHandlingMiddleware:
|
||||||
|
def __init__(self, *, app_config):
|
||||||
|
self.app_config = app_config
|
||||||
|
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.llm_error_handling_middleware",
|
||||||
|
_module(
|
||||||
|
"deerflow.agents.middlewares.llm_error_handling_middleware",
|
||||||
|
LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.thread_data_middleware",
|
||||||
|
_module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.sandbox.middleware",
|
||||||
|
_module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.dangling_tool_call_middleware",
|
||||||
|
_module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"deerflow.agents.middlewares.sandbox_audit_middleware",
|
||||||
|
_module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
|
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
|
||||||
@@ -166,3 +218,30 @@ async def test_awrap_tool_call_reraises_graph_interrupt():
|
|||||||
|
|
||||||
with pytest.raises(GraphInterrupt):
|
with pytest.raises(GraphInterrupt):
|
||||||
await middleware.awrap_tool_call(req, _interrupt)
|
await middleware.awrap_tool_call(req, _interrupt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch):
|
||||||
|
app_config = _make_app_config(supports_vision=True)
|
||||||
|
_stub_runtime_middleware_imports(monkeypatch)
|
||||||
|
|
||||||
|
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")
|
||||||
|
|
||||||
|
assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch):
|
||||||
|
app_config = _make_app_config(supports_vision=True)
|
||||||
|
_stub_runtime_middleware_imports(monkeypatch)
|
||||||
|
|
||||||
|
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=None)
|
||||||
|
|
||||||
|
assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch):
|
||||||
|
app_config = _make_app_config(supports_vision=False)
|
||||||
|
_stub_runtime_middleware_imports(monkeypatch)
|
||||||
|
|
||||||
|
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")
|
||||||
|
|
||||||
|
assert not any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
||||||
|
|||||||
Reference in New Issue
Block a user