refactor(config): eliminate global mutable state, wire DeerFlowContext into runtime

- Freeze all config models (AppConfig + 15 sub-configs) with frozen=True
- Purify from_file() — remove 9 load_*_from_dict() side-effect calls
- Replace mtime/reload/push/pop machinery with single ContextVar + init_app_config()
- Delete 10 sub-module globals and their getters/setters/loaders
- Migrate 50+ consumers from get_*_config() to get_app_config().xxx

- Expand DeerFlowContext: app_config + thread_id + agent_name (frozen dataclass)
- Wire into Gateway runtime (worker.py) and DeerFlowClient via context= parameter
- Remove sandbox_id from runtime.context — flows through ThreadState.sandbox only
- Middleware/tools access runtime.context directly via Runtime[DeerFlowContext] generic
- resolve_context() retained at server entry points for LangGraph Server fallback
This commit is contained in:
greatmengqi
2026-04-13 23:49:31 +08:00
parent c4d273a68a
commit edf345cd72
111 changed files with 4848 additions and 4079 deletions
+27 -37
View File
@@ -6,17 +6,20 @@ import pytest
import yaml
from pydantic import ValidationError
from deerflow.config.acp_config import ACPAgentConfig, get_acp_agents, load_acp_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def setup_function():
"""Reset ACP config before each test."""
load_acp_config_from_dict({})
def _make_config(acp_agents: dict | None = None) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
acp_agents={name: ACPAgentConfig(**cfg) for name, cfg in (acp_agents or {}).items()},
)
def test_load_acp_config_sets_agents():
load_acp_config_from_dict(
def test_acp_agents_via_app_config():
cfg = _make_config(
{
"claude_code": {
"command": "claude-code-acp",
@@ -26,39 +29,33 @@ def test_load_acp_config_sets_agents():
}
}
)
agents = get_acp_agents()
agents = cfg.acp_agents
assert "claude_code" in agents
assert agents["claude_code"].command == "claude-code-acp"
assert agents["claude_code"].description == "Claude Code for coding tasks"
assert agents["claude_code"].model is None
def test_load_acp_config_multiple_agents():
load_acp_config_from_dict(
def test_multiple_agents():
cfg = _make_config(
{
"claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"},
"codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"},
}
)
agents = get_acp_agents()
agents = cfg.acp_agents
assert len(agents) == 2
assert agents["codex"].args == ["--flag"]
def test_load_acp_config_empty_clears_agents():
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert len(get_acp_agents()) == 1
load_acp_config_from_dict({})
assert len(get_acp_agents()) == 0
def test_empty_acp_agents():
cfg = _make_config({})
assert cfg.acp_agents == {}
def test_load_acp_config_none_clears_agents():
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert len(get_acp_agents()) == 1
load_acp_config_from_dict(None)
assert get_acp_agents() == {}
def test_default_acp_agents_empty():
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
assert cfg.acp_agents == {}
def test_acp_agent_config_defaults():
@@ -79,8 +76,8 @@ def test_acp_agent_config_env_default_is_empty():
assert cfg.env == {}
def test_load_acp_config_preserves_env():
load_acp_config_from_dict(
def test_acp_agent_preserves_env():
cfg = _make_config(
{
"codex": {
"command": "codex-acp",
@@ -90,8 +87,7 @@ def test_load_acp_config_preserves_env():
}
}
)
cfg = get_acp_agents()["codex"]
assert cfg.env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
assert cfg.acp_agents["codex"].env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
def test_acp_agent_config_with_model():
@@ -115,13 +111,7 @@ def test_acp_agent_config_missing_description_raises():
ACPAgentConfig(command="my-agent")
def test_get_acp_agents_returns_empty_by_default():
"""After clearing, should return empty dict."""
load_acp_config_from_dict({})
assert get_acp_agents() == {}
def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, monkeypatch):
def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -157,9 +147,9 @@ def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, mo
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8")
AppConfig.from_file(str(config_path))
assert set(get_acp_agents()) == {"codex"}
app = AppConfig.from_file(str(config_path))
assert set(app.acp_agents) == {"codex"}
config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8")
AppConfig.from_file(str(config_path))
assert get_acp_agents() == {}
app = AppConfig.from_file(str(config_path))
assert app.acp_agents == {}
+43 -33
View File
@@ -1,12 +1,11 @@
from __future__ import annotations
import json
import os
from pathlib import Path
import yaml
from deerflow.config.app_config import get_app_config, reset_app_config
from deerflow.config.app_config import AppConfig
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@@ -32,50 +31,61 @@ def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch):
def test_init_then_get(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_path, model_name="first-model", supports_thinking=False)
_write_config(config_path, model_name="test-model", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
initial = get_app_config()
assert initial.models[0].supports_thinking is False
config = AppConfig.from_file(str(config_path))
AppConfig.init(config)
_write_config(config_path, model_name="first-model", supports_thinking=True)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded.models[0].supports_thinking is True
assert reloaded is not initial
finally:
reset_app_config()
result = AppConfig.current()
assert result is config
assert result.models[0].name == "test-model"
def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
config_a = tmp_path / "config-a.yaml"
config_b = tmp_path / "config-b.yaml"
def test_init_replaces_previous(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_a, model_name="model-a", supports_thinking=False)
_write_config(config_b, model_name="model-b", supports_thinking=True)
_write_config(config_path, model_name="model-a", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_a))
reset_app_config()
try:
first = get_app_config()
assert first.models[0].name == "model-a"
config_a = AppConfig.from_file(str(config_path))
AppConfig.init(config_a)
assert AppConfig.current().models[0].name == "model-a"
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_b))
second = get_app_config()
assert second.models[0].name == "model-b"
assert second is not first
finally:
reset_app_config()
_write_config(config_path, model_name="model-b", supports_thinking=True)
config_b = AppConfig.from_file(str(config_path))
AppConfig.init(config_b)
assert AppConfig.current().models[0].name == "model-b"
assert AppConfig.current() is config_b
def test_config_version_check(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
config_path.write_text(
yaml.safe_dump(
{
"config_version": 1,
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [],
}
),
encoding="utf-8",
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config = AppConfig.from_file(str(config_path))
assert config is not None
+68 -121
View File
@@ -5,25 +5,21 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import deerflow.config.app_config as app_config_module
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.config.checkpointer_config import (
CheckpointerConfig,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_config(checkpointer: CheckpointerConfig | None = None) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), checkpointer=checkpointer)
@pytest.fixture(autouse=True)
def reset_state():
"""Reset singleton state before each test."""
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
yield
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
@@ -33,24 +29,18 @@ def reset_state():
class TestCheckpointerConfig:
def test_load_memory_config(self):
load_checkpointer_config_from_dict({"type": "memory"})
config = get_checkpointer_config()
assert config is not None
def test_memory_config(self):
config = CheckpointerConfig(type="memory")
assert config.type == "memory"
assert config.connection_string is None
def test_load_sqlite_config(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
config = get_checkpointer_config()
assert config is not None
def test_sqlite_config(self):
config = CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")
assert config.type == "sqlite"
assert config.connection_string == "/tmp/test.db"
def test_load_postgres_config(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
config = get_checkpointer_config()
assert config is not None
def test_postgres_config(self):
config = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
assert config.type == "postgres"
assert config.connection_string == "postgresql://localhost/db"
@@ -58,14 +48,9 @@ class TestCheckpointerConfig:
config = CheckpointerConfig(type="memory")
assert config.connection_string is None
def test_set_config_to_none(self):
load_checkpointer_config_from_dict({"type": "memory"})
set_checkpointer_config(None)
assert get_checkpointer_config() is None
def test_invalid_type_raises(self):
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
CheckpointerConfig(type="unknown")
# ---------------------------------------------------------------------------
@@ -78,58 +63,78 @@ class TestGetCheckpointer:
"""get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
with patch.object(AppConfig, "current", return_value=_make_config()):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_returns_in_memory_saver_when_config_not_found(self):
from langgraph.checkpoint.memory import InMemorySaver
with patch.object(AppConfig, "current", side_effect=FileNotFoundError):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_memory_returns_in_memory_saver(self):
load_checkpointer_config_from_dict({"type": "memory"})
from langgraph.checkpoint.memory import InMemorySaver
cp = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp = get_checkpointer()
assert isinstance(cp, InMemorySaver)
def test_memory_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
cp2 = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp1 = get_checkpointer()
cp2 = get_checkpointer()
assert cp1 is cp2
def test_reset_clears_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
reset_checkpointer()
cp2 = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp1 = get_checkpointer()
reset_checkpointer()
cp2 = get_checkpointer()
assert cp1 is not cp2
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}),
):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
get_checkpointer()
def test_postgres_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}),
):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
get_checkpointer()
def test_postgres_raises_when_connection_string_missing(self):
load_checkpointer_config_from_dict({"type": "postgres"})
cfg = _make_config(CheckpointerConfig(type="postgres"))
mock_saver = MagicMock()
mock_module = MagicMock()
mock_module.PostgresSaver = mock_saver
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}),
):
reset_checkpointer()
with pytest.raises(ValueError, match="connection_string is required"):
get_checkpointer()
def test_sqlite_creates_saver(self):
"""SQLite checkpointer is created when package is available."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
@@ -142,7 +147,10 @@ class TestGetCheckpointer:
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
):
reset_checkpointer()
cp = get_checkpointer()
@@ -152,7 +160,7 @@ class TestGetCheckpointer:
def test_postgres_creates_saver(self):
"""Postgres checkpointer is created when packages are available."""
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
@@ -165,7 +173,10 @@ class TestGetCheckpointer:
mock_pg_module = MagicMock()
mock_pg_module.PostgresSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}),
):
reset_checkpointer()
cp = get_checkpointer()
@@ -195,7 +206,7 @@ class TestAsyncCheckpointer:
mock_module.AsyncSqliteSaver = mock_saver_cls
with (
patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.object(AppConfig, "current", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
patch(
@@ -221,12 +232,10 @@ class TestAsyncCheckpointer:
class TestAppConfigLoadsCheckpointer:
def test_load_checkpointer_section(self):
"""load_checkpointer_config_from_dict populates the global config."""
set_checkpointer_config(None)
load_checkpointer_config_from_dict({"type": "memory"})
cfg = get_checkpointer_config()
assert cfg is not None
assert cfg.type == "memory"
"""AppConfig with checkpointer section has the correct config."""
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None
assert cfg.checkpointer.type == "memory"
# ---------------------------------------------------------------------------
@@ -237,68 +246,6 @@ class TestAppConfigLoadsCheckpointer:
class TestClientCheckpointerFallback:
def test_client_uses_config_checkpointer_when_none_provided(self):
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=None)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert "checkpointer" in captured_kwargs
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
def test_client_explicit_checkpointer_takes_precedence(self):
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
explicit_cp = MagicMock()
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=explicit_cp)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert captured_kwargs["checkpointer"] is explicit_cp
# This is a structural test — verifying the fallback path exists.
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None
+6 -4
View File
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.config.app_config import AppConfig
class TestCheckpointerNoneFix:
"""Tests that checkpointer context managers return InMemorySaver instead of None."""
@@ -14,11 +16,11 @@ class TestCheckpointerNoneFix:
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
# Mock get_app_config to return a config with checkpointer=None
# Mock AppConfig.get to return a config with checkpointer=None
mock_config = MagicMock()
mock_config.checkpointer = None
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
async with make_checkpointer() as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
@@ -37,11 +39,11 @@ class TestCheckpointerNoneFix:
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
from deerflow.agents.checkpointer.provider import checkpointer_context
# Mock get_app_config to return a config with checkpointer=None
# Mock AppConfig.get to return a config with checkpointer=None
mock_config = MagicMock()
mock_config.checkpointer = None
with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
with checkpointer_context() as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
+69 -52
View File
@@ -18,6 +18,7 @@ from app.gateway.routers.models import ModelResponse, ModelsListResponse
from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
from app.gateway.routers.uploads import UploadResponse
from deerflow.client import DeerFlowClient
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import Paths
from deerflow.uploads.manager import PathTraversalError
@@ -44,7 +45,7 @@ def mock_app_config():
@pytest.fixture
def client(mock_app_config):
"""Create a DeerFlowClient with mocked config loading."""
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
return DeerFlowClient()
@@ -66,7 +67,7 @@ class TestClientInit:
def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
@@ -77,7 +78,7 @@ class TestClientInit:
assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config):
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="invalid name with spaces!")
with pytest.raises(ValueError, match="Invalid agent name"):
@@ -85,15 +86,17 @@ class TestClientInit:
def test_custom_config_path(self, mock_app_config):
with (
patch("deerflow.client.reload_app_config") as mock_reload,
patch("deerflow.client.get_app_config", return_value=mock_app_config),
patch.object(AppConfig, "from_file", return_value=mock_app_config) as mock_from_file,
patch.object(AppConfig, "init") as mock_init,
patch.object(AppConfig, "current", return_value=mock_app_config),
):
DeerFlowClient(config_path="/tmp/custom.yaml")
mock_reload.assert_called_once_with("/tmp/custom.yaml")
mock_from_file.assert_called_once_with("/tmp/custom.yaml")
mock_init.assert_called_once_with(mock_app_config)
def test_checkpointer_stored(self, mock_app_config):
cp = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
c = DeerFlowClient(checkpointer=cp)
assert c._checkpointer is cp
@@ -249,8 +252,8 @@ class TestStream:
# Verify context passed to agent.stream
agent.stream.assert_called_once()
call_kwargs = agent.stream.call_args.kwargs
assert call_kwargs["context"]["thread_id"] == "t1"
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
ctx = call_kwargs["context"]
assert ctx.app_config is client._app_config
def test_custom_mode_is_normalized_to_string(self, client):
"""stream() forwards custom events even when the mode is not a plain string."""
@@ -1089,7 +1092,7 @@ class TestMcpConfig:
ext_config = MagicMock()
ext_config.mcp_servers = {"github": server}
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
with patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)):
result = client.get_mcp_config()
assert "mcp_servers" in result
@@ -1114,10 +1117,12 @@ class TestMcpConfig:
# Pre-set agent to verify it gets invalidated
client._agent = MagicMock()
# Set initial AppConfig with current extensions
AppConfig.init(MagicMock(extensions=current_config))
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
):
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
@@ -1179,8 +1184,8 @@ class TestSkillsManagement:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
result = client.update_skill("test-skill", enabled=False)
assert result["enabled"] is False
@@ -1311,35 +1316,40 @@ class TestMemoryManagement:
assert result == data
def test_get_memory_config(self, client):
config = MagicMock()
config.enabled = True
config.storage_path = ".deer-flow/memory.json"
config.debounce_seconds = 30
config.max_facts = 100
config.fact_confidence_threshold = 0.7
config.injection_enabled = True
config.max_injection_tokens = 2000
mem_config = MagicMock()
mem_config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30
mem_config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True
mem_config.max_injection_tokens = 2000
with patch("deerflow.config.memory_config.get_memory_config", return_value=config):
app_cfg = MagicMock()
app_cfg.memory = mem_config
with patch.object(AppConfig, "current", return_value=app_cfg):
result = client.get_memory_config()
assert result["enabled"] is True
assert result["max_facts"] == 100
def test_get_memory_status(self, client):
config = MagicMock()
config.enabled = True
config.storage_path = ".deer-flow/memory.json"
config.debounce_seconds = 30
config.max_facts = 100
config.fact_confidence_threshold = 0.7
config.injection_enabled = True
config.max_injection_tokens = 2000
mem_config = MagicMock()
mem_config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30
mem_config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True
mem_config.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_config
data = {"version": "1.0", "facts": []}
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch.object(AppConfig, "current", return_value=app_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=data),
):
result = client.get_memory_status()
@@ -1783,10 +1793,10 @@ class TestScenarioConfigManagement:
reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
client._agent = MagicMock() # Simulate existing agent
AppConfig.init(MagicMock(extensions=current_config))
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
):
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
assert "my-mcp" in mcp_result["mcp_servers"]
@@ -1815,8 +1825,8 @@ class TestScenarioConfigManagement:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
skill_result = client.update_skill("code-gen", enabled=False)
assert skill_result["enabled"] is False
@@ -2001,8 +2011,10 @@ class TestScenarioMemoryWorkflow:
refreshed = client.reload_memory()
assert len(refreshed["facts"]) == 2
app_cfg = MagicMock()
app_cfg.memory = config
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch.object(AppConfig, "current", return_value=app_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data),
):
status = client.get_memory_status()
@@ -2065,8 +2077,8 @@ class TestScenarioSkillInstallAndUse:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
toggled = client.update_skill("my-analyzer", enabled=False)
assert toggled["enabled"] is False
@@ -2198,7 +2210,7 @@ class TestGatewayConformance:
model.supports_thinking = False
mock_app_config.models = [model]
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
client = DeerFlowClient()
result = client.list_models()
@@ -2217,7 +2229,7 @@ class TestGatewayConformance:
mock_app_config.models = [model]
mock_app_config.get_model_config.return_value = model
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
client = DeerFlowClient()
result = client.get_model("test-model")
@@ -2287,7 +2299,7 @@ class TestGatewayConformance:
ext_config = MagicMock()
ext_config.mcp_servers = {"test": server}
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
with patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)):
result = client.get_mcp_config()
parsed = McpConfigResponse(**result)
@@ -2313,9 +2325,9 @@ class TestGatewayConformance:
config_file.write_text("{}")
with (
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.reload_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=ext_config)),
):
result = client.update_mcp_config({"srv": server.model_dump.return_value})
@@ -2346,7 +2358,10 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
with patch.object(AppConfig, "current", return_value=app_cfg):
result = client.get_memory_config()
parsed = MemoryConfigResponse(**result)
@@ -2363,6 +2378,8 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
memory_data = {
"version": "1.0",
"lastUpdated": "",
@@ -2380,7 +2397,7 @@ class TestGatewayConformance:
}
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg),
patch.object(AppConfig, "current", return_value=app_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data),
):
result = client.get_memory_status()
@@ -2671,8 +2688,8 @@ class TestConfigUpdateErrors:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
with pytest.raises(RuntimeError, match="disappeared"):
client.update_skill("ghost-skill", enabled=False)
@@ -3042,10 +3059,10 @@ class TestBugAgentInvalidationInconsistency:
config_file = Path(tmp) / "ext.json"
config_file.write_text("{}")
AppConfig.init(MagicMock(extensions=current_config))
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),
):
client.update_mcp_config({})
@@ -3077,8 +3094,8 @@ class TestBugAgentInvalidationInconsistency:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
client.update_skill("s1", enabled=False)
+73
View File
@@ -0,0 +1,73 @@
"""Verify that all sub-config Pydantic models are frozen (immutable).
Frozen models reject attribute assignment after construction, raising
pydantic.ValidationError. This test collects every BaseModel subclass
defined in the deerflow.config package and asserts that mutation is
blocked.
"""
import inspect
import pkgutil
import pytest
from pydantic import BaseModel, ValidationError
import deerflow.config as config_pkg
def _collect_config_models() -> list[type[BaseModel]]:
"""Walk deerflow.config.* and return all concrete BaseModel subclasses."""
import importlib
models: list[type[BaseModel]] = []
package_path = config_pkg.__path__
package_prefix = config_pkg.__name__ + "."
for _importer, modname, _ispkg in pkgutil.walk_packages(package_path, prefix=package_prefix):
try:
mod = importlib.import_module(modname)
except Exception:
continue
for _name, obj in inspect.getmembers(mod, inspect.isclass):
if (
issubclass(obj, BaseModel)
and obj is not BaseModel
and obj.__module__ == mod.__name__
):
models.append(obj)
return models
_EXCLUDED: set[str] = set()
_ALL_MODELS = [m for m in _collect_config_models() if m.__name__ not in _EXCLUDED]
# Sanity: make sure we actually collected a meaningful set.
assert len(_ALL_MODELS) >= 15, f"Expected at least 15 config models, found {len(_ALL_MODELS)}: {[m.__name__ for m in _ALL_MODELS]}"
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_is_frozen(model_cls: type[BaseModel]):
"""Every sub-config model must have frozen=True in its model_config."""
cfg = model_cls.model_config
assert cfg.get("frozen") is True, (
f"{model_cls.__name__} is not frozen. "
f"Add `model_config = ConfigDict(frozen=True)` or add `frozen=True` to the existing ConfigDict."
)
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
"""Constructing then mutating any field must raise ValidationError."""
# Build a minimal instance -- use model_construct to skip validation for
# required fields, then pick the first field to try mutating.
fields = list(model_cls.model_fields.keys())
if not fields:
pytest.skip(f"{model_cls.__name__} has no fields")
instance = model_cls.model_construct()
first_field = fields[0]
with pytest.raises(ValidationError):
setattr(instance, first_field, "MUTATED")
+6 -4
View File
@@ -3,12 +3,14 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
import yaml
from fastapi.testclient import TestClient
from deerflow.config.app_config import AppConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -331,7 +333,7 @@ class TestMemoryFilePath:
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
path = storage._get_memory_file_path(None)
@@ -344,7 +346,7 @@ class TestMemoryFilePath:
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
path = storage._get_memory_file_path("code-reviewer")
@@ -356,7 +358,7 @@ class TestMemoryFilePath:
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
path_global = storage._get_memory_file_path(None)
+86
View File
@@ -0,0 +1,86 @@
"""Tests for DeerFlowContext and resolve_context()."""
from dataclasses import FrozenInstanceError
from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.sandbox_config import SandboxConfig
def _make_config(**overrides) -> AppConfig:
defaults = {"sandbox": SandboxConfig(use="test")}
defaults.update(overrides)
return AppConfig(**defaults)
class TestDeerFlowContext:
def test_frozen(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
with pytest.raises(FrozenInstanceError):
ctx.app_config = _make_config()
def test_fields(self):
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1", agent_name="test-agent")
assert ctx.thread_id == "t1"
assert ctx.agent_name == "test-agent"
assert ctx.app_config is config
def test_agent_name_default(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
assert ctx.agent_name is None
def test_thread_id_required(self):
with pytest.raises(TypeError):
DeerFlowContext(app_config=_make_config()) # type: ignore[call-arg]
class TestResolveContext:
def test_returns_typed_context_directly(self):
"""Gateway/Client path: runtime.context is DeerFlowContext → return as-is."""
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1")
runtime = MagicMock()
runtime.context = ctx
assert resolve_context(runtime) is ctx
def test_fallback_from_configurable(self):
"""LangGraph Server path: runtime.context is None → construct from ContextVar + configurable."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t2", "agent_name": "ag"}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == "t2"
assert ctx.agent_name == "ag"
assert ctx.app_config is config
def test_fallback_empty_configurable(self):
"""LangGraph Server path with no thread_id in configurable → empty string."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == ""
assert ctx.agent_name is None
def test_fallback_from_dict_context(self):
"""Legacy path: runtime.context is a dict → extract from dict directly."""
runtime = MagicMock()
runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
config = _make_config()
with patch.object(AppConfig, "current", return_value=config):
ctx = resolve_context(runtime)
assert ctx.thread_id == "old-dict"
assert ctx.agent_name == "from-dict"
assert ctx.app_config is config
+6 -4
View File
@@ -5,11 +5,13 @@ from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.app_config import AppConfig
@pytest.fixture
def mock_app_config():
"""Mock the app config to return tool configurations."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 5,
@@ -67,7 +69,7 @@ class TestWebSearchTool:
def test_search_with_custom_config(self, mock_exa_client):
"""Test search respects custom configuration values."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 10,
@@ -195,7 +197,7 @@ class TestWebFetchTool:
def test_fetch_reads_web_fetch_config(self, mock_exa_client):
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"}
mock_config.return_value.get_tool_config.return_value = tool_config
@@ -215,7 +217,7 @@ class TestWebFetchTool:
def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock()
+4 -2
View File
@@ -3,10 +3,12 @@
import json
from unittest.mock import MagicMock, patch
from deerflow.config.app_config import AppConfig
class TestWebSearchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch("deerflow.community.firecrawl.tools.get_app_config")
@patch.object(AppConfig, "current")
def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
search_config = MagicMock()
search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
@@ -36,7 +38,7 @@ class TestWebSearchTool:
class TestWebFetchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch("deerflow.community.firecrawl.tools.get_app_config")
@patch.object(AppConfig, "current")
def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}
+12 -7
View File
@@ -333,12 +333,17 @@ class TestGuardrailsConfig:
assert config.provider.use == "deerflow.guardrails.builtin:AllowlistProvider"
assert config.provider.config == {"denied_tools": ["bash"]}
def test_singleton_load_and_get(self):
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict, reset_guardrails_config
def test_guardrails_config_via_app_config(self):
from unittest.mock import patch
try:
load_guardrails_config_from_dict({"enabled": True, "provider": {"use": "test:Foo"}})
config = get_guardrails_config()
from deerflow.config.app_config import AppConfig
from deerflow.config.guardrails_config import GuardrailProviderConfig, GuardrailsConfig
from deerflow.config.sandbox_config import SandboxConfig
cfg = AppConfig(
sandbox=SandboxConfig(use="test"),
guardrails=GuardrailsConfig(enabled=True, provider=GuardrailProviderConfig(use="test:Foo")),
)
with patch.object(AppConfig, "current", return_value=cfg):
config = AppConfig.current().guardrails
assert config.enabled is True
finally:
reset_guardrails_config()
+4 -3
View File
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
from deerflow.community.infoquest import tools
from deerflow.community.infoquest.infoquest_client import InfoQuestClient
from deerflow.config.app_config import AppConfig
class TestInfoQuestClient:
@@ -149,8 +150,8 @@ class TestInfoQuestClient:
mock_get_client.assert_called_once()
mock_client.fetch.assert_called_once_with("https://example.com")
@patch("deerflow.community.infoquest.tools.get_app_config")
def test_get_infoquest_client(self, mock_get_app_config):
@patch.object(AppConfig, "current")
def test_get_infoquest_client(self, mock_get):
"""Test _get_infoquest_client function with config."""
mock_config = MagicMock()
# Add image_search config to the side_effect
@@ -159,7 +160,7 @@ class TestInfoQuestClient:
MagicMock(model_extra={"fetch_time": 10, "timeout": 30, "navigation_timeout": 60}), # web_fetch config
MagicMock(model_extra={"image_search_time_range": 7, "image_size": "l"}), # image_search config
]
mock_get_app_config.return_value = mock_config
mock_get.return_value = mock_config
client = tools._get_infoquest_client()
+10 -20
View File
@@ -6,7 +6,8 @@ from types import SimpleNamespace
import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.tools.builtins.invoke_acp_agent_tool import (
_build_acp_mcp_servers,
_build_mcp_servers,
@@ -18,7 +19,6 @@ from deerflow.tools.tools import get_available_tools
def test_build_mcp_servers_filters_disabled_and_maps_transports():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"]),
@@ -40,11 +40,9 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
}
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_acp_mcp_servers_formats_list_payload():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
@@ -77,7 +75,6 @@ def test_build_acp_mcp_servers_formats_list_payload():
]
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_permission_response_prefers_allow_once():
@@ -665,25 +662,20 @@ async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch,
def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(monkeypatch):
from deerflow.config.acp_config import load_acp_config_from_dict
load_acp_config_from_dict(
{
"codex": {
"command": "codex-acp",
"args": [],
"description": "Codex CLI",
}
}
)
fake_config = SimpleNamespace(
tools=[],
models=[],
tool_search=SimpleNamespace(enabled=False),
acp_agents={
"codex": ACPAgentConfig(
command="codex-acp",
args=[],
description="Codex CLI",
)
},
get_model_config=lambda name: None,
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: fake_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: fake_config))
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
@@ -691,5 +683,3 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
tools = get_available_tools(include_mcp=True, subagent_enabled=False)
assert "invoke_acp_agent" in [tool.name for tool in tools]
load_acp_config_from_dict({})
+3 -2
View File
@@ -9,6 +9,7 @@ import pytest
import deerflow.community.jina_ai.jina_client as jina_client_module
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.community.jina_ai.tools import web_fetch_tool
from deerflow.config.app_config import AppConfig
@pytest.fixture
@@ -154,7 +155,7 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
assert result.startswith("Error:")
@@ -170,7 +171,7 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
assert "Hello world" in result
@@ -40,7 +40,7 @@ def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
with caplog.at_level("WARNING"):
resolved = lead_agent_module._resolve_model_name("missing-model")
@@ -57,7 +57,7 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
resolved = lead_agent_module._resolve_model_name(None)
@@ -67,7 +67,7 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
app_config = _make_app_config([])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
with pytest.raises(
ValueError,
@@ -81,7 +81,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
import deerflow.tools as tools_module
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
@@ -128,7 +128,8 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
AppConfig.init(app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
@@ -140,11 +141,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
app_config = _make_app_config([_make_model("default", supports_thinking=False)])
patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
AppConfig.init(patched)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: patched))
captured: dict[str, object] = {}
fake_model = object()
+5 -4
View File
@@ -4,12 +4,13 @@ from types import SimpleNamespace
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
assert prompt_module._build_custom_mounts_section() == ""
@@ -20,7 +21,7 @@ def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
SimpleNamespace(container_path="/mnt/reference", read_only=True),
]
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
section = prompt_module._build_custom_mounts_section()
@@ -37,7 +38,7 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
sandbox=SimpleNamespace(mounts=mounts),
skills=SimpleNamespace(container_path="/mnt/skills"),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
@@ -55,7 +56,7 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
sandbox=SimpleNamespace(mounts=[]),
skills=SimpleNamespace(container_path="/mnt/skills"),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
+10 -9
View File
@@ -3,6 +3,7 @@ from types import SimpleNamespace
from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
from deerflow.config.agents_config import AgentConfig
from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill
@@ -58,11 +59,11 @@ def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
skills = [_make_skill("skill1")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr(
"deerflow.config.get_app_config",
lambda: SimpleNamespace(
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
),
)),
)
result = get_skills_prompt_section(available_skills=None)
@@ -72,11 +73,11 @@ def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [])
monkeypatch.setattr(
"deerflow.config.get_app_config",
lambda: SimpleNamespace(
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
),
)),
)
result = get_skills_prompt_section(available_skills=None)
@@ -90,7 +91,7 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
enabled_result = get_skills_prompt_section(available_skills=None)
assert "Skill Self-Evolution" in enabled_result
@@ -106,7 +107,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from deerflow.agents.lead_agent import agent as lead_agent_module
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: MagicMock()))
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
@@ -118,7 +119,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = MockModelConfig()
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_app_config))
captured_skills = []
@@ -1,5 +1,6 @@
from types import SimpleNamespace
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.tools import get_available_tools
@@ -22,7 +23,7 @@ def _make_config(*, allow_host_bash: bool, sandbox_use: str = "deerflow.sandbox.
def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=False))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _make_config(allow_host_bash=False)))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
@@ -35,7 +36,7 @@ def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=True))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _make_config(allow_host_bash=True)))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
@@ -52,7 +53,7 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
allow_host_bash=False,
extra_tools=[SimpleNamespace(name="shell", group="bash", use="deerflow.sandbox.tools:bash_tool")],
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
@@ -70,7 +71,7 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
allow_host_bash=False,
sandbox_use="deerflow.community.aio_sandbox:AioSandboxProvider",
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
@@ -312,7 +313,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
@@ -334,7 +335,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@@ -358,7 +359,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@@ -474,7 +475,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
@@ -10,12 +10,22 @@ from deerflow.agents.middlewares.loop_detection_middleware import (
LoopDetectionMiddleware,
_hash_tool_calls,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime(thread_id="test-thread"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
runtime.context = _make_context(thread_id)
return runtime
@@ -293,10 +303,10 @@ class TestLoopDetection:
assert isinstance(mw._lock, type(mw._lock))
def test_fallback_thread_id_when_missing(self):
"""When runtime context has no thread_id, should use 'default'."""
"""When runtime context has empty thread_id, should use 'default'."""
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime = MagicMock()
runtime.context = {}
runtime.context = _make_context("")
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
+6 -7
View File
@@ -1,21 +1,20 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def _make_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
@@ -55,7 +54,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
+17 -11
View File
@@ -11,7 +11,13 @@ from deerflow.agents.memory.storage import (
create_empty_memory,
get_memory_storage,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _app_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
class TestCreateEmptyMemory:
@@ -53,7 +59,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
@@ -87,7 +93,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
memory = storage.load()
assert isinstance(memory, dict)
@@ -103,7 +109,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
result = storage.save(test_memory)
@@ -122,7 +128,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
# First load
memory1 = storage.load()
@@ -150,19 +156,19 @@ class TestGetMemoryStorage:
def test_returns_file_memory_storage_by_default(self):
"""Should return FileMemoryStorage by default."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
def test_falls_back_to_file_memory_storage_on_error(self):
"""Should fall back to FileMemoryStorage if configured storage fails to load."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="non.existent.StorageClass")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="non.existent.StorageClass")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
def test_returns_singleton_instance(self):
"""Should return the same instance on subsequent calls."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage1 = get_memory_storage()
storage2 = get_memory_storage()
assert storage1 is storage2
@@ -173,11 +179,11 @@ class TestGetMemoryStorage:
def get_storage():
# get_memory_storage is called concurrently from multiple threads while
# get_memory_config is patched once around thread creation. This verifies
# AppConfig.get is patched once around thread creation. This verifies
# that the singleton initialization remains thread-safe.
results.append(get_memory_storage())
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
threads = [threading.Thread(target=get_storage) for _ in range(10)]
for t in threads:
t.start()
@@ -191,13 +197,13 @@ class TestGetMemoryStorage:
def test_get_memory_storage_invalid_class_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
# Using a built-in function instead of a class
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="os.path.join")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="os.path.join")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
def test_get_memory_storage_non_subclass_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
# Using 'dict' as a class that is not a MemoryStorage subclass
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="builtins.dict")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="builtins.dict")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
+18 -26
View File
@@ -10,7 +10,9 @@ from deerflow.agents.memory.updater import (
import_memory_data,
update_memory_fact,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
@@ -31,11 +33,8 @@ def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, obje
}
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def _memory_config(**overrides: object) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig().model_copy(update=overrides))
def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
@@ -67,8 +66,7 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
@@ -88,8 +86,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
@@ -132,8 +129,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
@@ -160,8 +156,7 @@ def test_apply_updates_preserves_source_error() -> None:
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
@@ -184,8 +179,7 @@ def test_apply_updates_ignores_empty_source_error() -> None:
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
@@ -532,7 +526,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -555,7 +549,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -577,7 +571,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -602,7 +596,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -646,8 +640,7 @@ class TestFactDeduplicationCaseInsensitive:
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
@@ -677,8 +670,7 @@ class TestFactDeduplicationCaseInsensitive:
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
@@ -704,7 +696,7 @@ class TestReinforcementHint:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -729,7 +721,7 @@ class TestReinforcementHint:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -754,7 +746,7 @@ class TestReinforcementHint:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
+5 -5
View File
@@ -72,8 +72,8 @@ class FakeChatModel(BaseChatModel):
def _patch_factory(monkeypatch, app_config: AppConfig, model_class=FakeChatModel):
"""Patch get_app_config, resolve_class, and tracing for isolated unit tests."""
monkeypatch.setattr(factory_module, "get_app_config", lambda: app_config)
"""Patch AppConfig.get, resolve_class, and tracing for isolated unit tests."""
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: model_class)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
@@ -96,7 +96,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
def test_raises_when_model_not_found(monkeypatch):
cfg = _make_app_config([_make_model("only-model")])
monkeypatch.setattr(factory_module, "get_app_config", lambda: cfg)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: cfg))
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
with pytest.raises(ValueError, match="ghost-model"):
@@ -744,7 +744,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
supports_thinking=True,
when_thinking_enabled=wte,
)
model.extra_body = {"top_k": 20}
model = model.model_copy(update={"extra_body": {"top_k": 20}})
cfg = _make_app_config([model])
_patch_factory(monkeypatch, cfg)
@@ -771,7 +771,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
supports_thinking=True,
when_thinking_enabled=wte,
)
model.extra_body = {"top_k": 20}
model = model.model_copy(update={"extra_body": {"top_k": 20}})
cfg = _make_app_config([model])
_patch_factory(monkeypatch, cfg)
@@ -3,13 +3,24 @@
import importlib
from types import SimpleNamespace
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
present_file_tool_module = importlib.import_module("deerflow.tools.builtins.present_file_tool")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime(outputs_path: str) -> SimpleNamespace:
return SimpleNamespace(
state={"thread_data": {"outputs_path": outputs_path}},
context={"thread_id": "thread-1"},
context=_make_context("thread-1"),
)
+4 -3
View File
@@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import patch
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
from deerflow.sandbox.tools import glob_tool, grep_tool
@@ -104,7 +105,7 @@ def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: SimpleNamespace(get_tool_config=lambda name: None)))
result = grep_tool.func(
runtime=runtime,
@@ -325,8 +326,8 @@ def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
monkeypatch.setattr(
"deerflow.sandbox.tools.get_app_config",
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50}))),
)
result = glob_tool.func(
+24 -25
View File
@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
@@ -617,18 +618,25 @@ def test_apply_cwd_prefix_quotes_path_with_spaces() -> None:
def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None:
"""Bash commands referencing MCP filesystem server paths should be allowed."""
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.config.sandbox_config import SandboxConfig
mock_config = ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=True,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
)
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=mock_config):
def _make_app_config(enabled: bool) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
extensions=ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=enabled,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
),
)
with patch.object(AppConfig, "current", return_value=_make_app_config(True)):
# Should not raise - MCP filesystem paths are allowed
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
@@ -637,19 +645,10 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
# Disabled servers should not expose paths
disabled_config = ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=False,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
)
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# Disabled servers should not expose paths
with patch.object(AppConfig, "current", return_value=_make_app_config(False)):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# ---------- Custom mount path tests ----------
@@ -757,7 +756,7 @@ def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
@@ -786,7 +785,7 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
+2 -1
View File
@@ -2,13 +2,14 @@ from types import SimpleNamespace
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.skills.security_scanner import scan_skill_content
@pytest.mark.anyio
async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None))
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
+21 -18
View File
@@ -4,9 +4,20 @@ from types import SimpleNamespace
import anyio
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _skill_content(name: str, description: str = "Demo skill") -> str:
return f"---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"
@@ -23,9 +34,7 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
refresh_calls = []
async def _refresh():
@@ -34,7 +43,7 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
result = anyio.run(
skill_manage_module.skill_manage_tool.coroutine,
@@ -67,9 +76,7 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
async def _refresh():
return None
@@ -77,7 +84,7 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
@@ -107,10 +114,9 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
runtime = SimpleNamespace(context={}, config={"configurable": {}})
runtime = SimpleNamespace(context=_make_context(""), config={"configurable": {}})
with pytest.raises(ValueError, match="built-in skill"):
anyio.run(
@@ -131,8 +137,7 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
refresh_calls = []
async def _refresh():
@@ -141,7 +146,7 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-sync"}, config={"configurable": {"thread_id": "thread-sync"}})
runtime = SimpleNamespace(context=_make_context("thread-sync"), config={"configurable": {"thread_id": "thread-sync"}})
result = skill_manage_module.skill_manage_tool.func(
runtime=runtime,
action="create",
@@ -159,9 +164,7 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
async def _refresh():
return None
@@ -169,7 +172,7 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):
+11 -8
View File
@@ -6,6 +6,9 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.routers import skills as skills_router
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.skills.manager import get_skill_history_file
from deerflow.skills.types import Skill
@@ -43,8 +46,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
@@ -93,8 +95,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
get_skill_history_file("demo-skill").write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
@@ -135,8 +136,7 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
@@ -179,9 +179,12 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
refresh_calls.append("refresh")
enabled_state["value"] = False
_app_cfg = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ExtensionsConfig(mcp_servers={}, skills={}))
monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
monkeypatch.setattr("app.gateway.routers.skills.get_extensions_config", lambda: SimpleNamespace(mcp_servers={}, skills={}))
monkeypatch.setattr("app.gateway.routers.skills.reload_extensions_config", lambda: None)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(AppConfig, "init", staticmethod(lambda _cfg: None))
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
+117 -349
View File
@@ -3,40 +3,38 @@
Covers:
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
- get_timeout_for() / get_max_turns_for() resolution logic
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
- AppConfig.subagents field access via AppConfig.current()
- registry.get_subagent_config() applies config overrides
- registry.list_subagents() applies overrides for all agents
- Polling timeout calculation in task_tool is consistent with config
"""
from unittest.mock import patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.subagents_config import (
SubagentOverrideConfig,
SubagentsAppConfig,
get_subagents_app_config,
load_subagents_config_from_dict,
)
from deerflow.subagents.config import SubagentConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _reset_subagents_config(
def _make_config(
timeout_seconds: int = 900,
*,
max_turns: int | None = None,
agents: dict | None = None,
) -> None:
"""Reset global subagents config to a known state."""
load_subagents_config_from_dict(
{
"timeout_seconds": timeout_seconds,
"max_turns": max_turns,
"agents": agents or {},
}
) -> AppConfig:
"""Build an AppConfig with the given subagents settings."""
return AppConfig(
sandbox=SandboxConfig(use="test"),
subagents=SubagentsAppConfig(
timeout_seconds=timeout_seconds,
max_turns=max_turns,
agents={k: SubagentOverrideConfig(**v) for k, v in (agents or {}).items()},
),
)
@@ -51,364 +49,134 @@ class TestSubagentOverrideConfig:
assert override.timeout_seconds is None
assert override.max_turns is None
def test_explicit_value(self):
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42)
assert override.timeout_seconds == 300
assert override.max_turns == 42
def test_rejects_zero(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=0)
def test_rejects_negative(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=-1)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=-1)
def test_minimum_valid_value(self):
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
assert override.timeout_seconds == 1
assert override.max_turns == 1
# ---------------------------------------------------------------------------
# SubagentsAppConfig defaults and validation
# ---------------------------------------------------------------------------
class TestSubagentsAppConfigDefaults:
def test_default_timeout(self):
config = SubagentsAppConfig()
assert config.timeout_seconds == 900
def test_default_max_turns_override_is_none(self):
config = SubagentsAppConfig()
assert config.max_turns is None
def test_default_agents_empty(self):
config = SubagentsAppConfig()
assert config.agents == {}
def test_custom_global_runtime_overrides(self):
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
assert config.timeout_seconds == 1800
assert config.max_turns == 120
def test_rejects_zero_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=0)
def test_explicit_values(self):
override = SubagentOverrideConfig(timeout_seconds=120, max_turns=50)
assert override.timeout_seconds == 120
assert override.max_turns == 50
def test_rejects_negative_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=-60)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=-60)
with pytest.raises(Exception):
SubagentOverrideConfig(timeout_seconds=-1)
def test_rejects_zero_timeout(self):
with pytest.raises(Exception):
SubagentOverrideConfig(timeout_seconds=0)
# ---------------------------------------------------------------------------
# SubagentsAppConfig resolution helpers
# SubagentsAppConfig model
# ---------------------------------------------------------------------------
class TestRuntimeResolution:
def test_returns_global_default_when_no_override(self):
class TestSubagentsAppConfig:
def test_default_timeout_is_900(self):
config = SubagentsAppConfig()
assert config.timeout_seconds == 900
assert config.max_turns is None
assert config.agents == {}
def test_custom_defaults(self):
config = SubagentsAppConfig(timeout_seconds=300, max_turns=50)
assert config.timeout_seconds == 300
assert config.max_turns == 50
# ---------------------------------------------------------------------------
# get_timeout_for / get_max_turns_for
# ---------------------------------------------------------------------------
class TestTimeoutResolution:
def test_global_timeout_for_unknown_agent(self):
config = SubagentsAppConfig(timeout_seconds=600)
assert config.get_timeout_for("unknown") == 600
def test_per_agent_timeout_overrides_global(self):
config = SubagentsAppConfig(
timeout_seconds=600,
agents={"bash": SubagentOverrideConfig(timeout_seconds=120)},
)
assert config.get_timeout_for("bash") == 120
assert config.get_timeout_for("general-purpose") == 600
def test_per_agent_override_none_falls_back_to_global(self):
config = SubagentsAppConfig(
timeout_seconds=600,
agents={"bash": SubagentOverrideConfig(timeout_seconds=None)},
)
assert config.get_timeout_for("bash") == 600
assert config.get_timeout_for("unknown-agent") == 600
assert config.get_max_turns_for("general-purpose", 100) == 100
class TestMaxTurnsResolution:
def test_builtin_default_when_no_override(self):
config = SubagentsAppConfig()
assert config.get_max_turns_for("bash", 60) == 60
def test_returns_per_agent_override_when_set(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=120,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
)
assert config.get_timeout_for("bash") == 300
assert config.get_max_turns_for("bash", 60) == 80
def test_global_max_turns_overrides_builtin(self):
config = SubagentsAppConfig(max_turns=100)
assert config.get_max_turns_for("bash", 60) == 100
def test_other_agents_still_use_global_default(self):
def test_per_agent_max_turns_overrides_global(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=140,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
max_turns=100,
agents={"bash": SubagentOverrideConfig(max_turns=30)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 140
assert config.get_max_turns_for("bash", 60) == 30
assert config.get_max_turns_for("general-purpose", 60) == 100
def test_agent_with_none_override_falls_back_to_global(self):
def test_per_agent_override_none_falls_back(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=150,
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
max_turns=100,
agents={"bash": SubagentOverrideConfig(max_turns=None)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 150
assert config.get_max_turns_for("bash", 60) == 100
def test_multiple_per_agent_overrides(self):
config = SubagentsAppConfig(
# ---------------------------------------------------------------------------
# AppConfig.subagents via AppConfig.current()
# ---------------------------------------------------------------------------
class TestAppConfigSubagents:
def test_load_global_timeout(self):
cfg = _make_config(timeout_seconds=300, max_turns=120)
with patch.object(AppConfig, "current", return_value=cfg):
sub = AppConfig.current().subagents
assert sub.timeout_seconds == 300
assert sub.max_turns == 120
def test_load_with_per_agent_overrides(self):
cfg = _make_config(
timeout_seconds=900,
max_turns=120,
agents={
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
)
assert config.get_timeout_for("general-purpose") == 1800
assert config.get_timeout_for("bash") == 120
assert config.get_max_turns_for("general-purpose", 100) == 200
assert config.get_max_turns_for("bash", 60) == 80
# ---------------------------------------------------------------------------
# load_subagents_config_from_dict / get_subagents_app_config singleton
# ---------------------------------------------------------------------------
class TestLoadSubagentsConfig:
def teardown_method(self):
"""Restore defaults after each test."""
_reset_subagents_config()
def test_load_global_timeout(self):
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
assert get_subagents_app_config().timeout_seconds == 300
assert get_subagents_app_config().max_turns == 120
def test_load_with_per_agent_overrides(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 1800
assert cfg.get_timeout_for("bash") == 60
assert cfg.get_max_turns_for("general-purpose", 100) == 200
assert cfg.get_max_turns_for("bash", 60) == 80
with patch.object(AppConfig, "current", return_value=cfg):
sub = AppConfig.current().subagents
assert sub.get_timeout_for("general-purpose") == 1800
assert sub.get_timeout_for("bash") == 60
assert sub.get_max_turns_for("general-purpose", 100) == 200
assert sub.get_max_turns_for("bash", 60) == 80
def test_load_partial_override(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 600,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
}
cfg = _make_config(
timeout_seconds=600,
agents={"bash": {"timeout_seconds": 120, "max_turns": 70}},
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 600
assert cfg.get_timeout_for("bash") == 120
assert cfg.get_max_turns_for("general-purpose", 100) == 100
assert cfg.get_max_turns_for("bash", 60) == 70
with patch.object(AppConfig, "current", return_value=cfg):
sub = AppConfig.current().subagents
assert sub.get_timeout_for("general-purpose") == 600
assert sub.get_timeout_for("bash") == 120
assert sub.get_max_turns_for("general-purpose", 100) == 100
assert sub.get_max_turns_for("bash", 60) == 70
def test_load_empty_dict_uses_defaults(self):
load_subagents_config_from_dict({})
cfg = get_subagents_app_config()
assert cfg.timeout_seconds == 900
assert cfg.max_turns is None
assert cfg.agents == {}
def test_load_replaces_previous_config(self):
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
assert get_subagents_app_config().timeout_seconds == 100
assert get_subagents_app_config().max_turns == 90
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
assert get_subagents_app_config().timeout_seconds == 200
assert get_subagents_app_config().max_turns == 110
def test_singleton_returns_same_instance_between_calls(self):
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
assert get_subagents_app_config() is get_subagents_app_config()
# ---------------------------------------------------------------------------
# registry.get_subagent_config runtime overrides applied
# ---------------------------------------------------------------------------
class TestRegistryGetSubagentConfig:
def teardown_method(self):
_reset_subagents_config()
def test_returns_none_for_unknown_agent(self):
from deerflow.subagents.registry import get_subagent_config
assert get_subagent_config("nonexistent") is None
def test_returns_config_for_builtin_agents(self):
from deerflow.subagents.registry import get_subagent_config
assert get_subagent_config("general-purpose") is not None
assert get_subagent_config("bash") is not None
def test_default_timeout_preserved_when_no_config(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=900)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 900
assert config.max_turns == 100
def test_global_timeout_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 1800
assert config.max_turns == 140
def test_per_agent_runtime_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.timeout_seconds == 120
assert bash_config.max_turns == 80
def test_per_agent_override_does_not_affect_other_agents(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
gp_config = get_subagent_config("general-purpose")
assert gp_config.timeout_seconds == 900
assert gp_config.max_turns == 120
def test_builtin_config_object_is_not_mutated(self):
"""Registry must return a new object, leaving the builtin default intact."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
returned = get_subagent_config("bash")
assert returned.timeout_seconds == 42
assert returned.max_turns == 88
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
def test_config_preserves_other_fields(self):
"""Applying runtime overrides must not change other SubagentConfig fields."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=300, max_turns=140)
original = BUILTIN_SUBAGENTS["general-purpose"]
overridden = get_subagent_config("general-purpose")
assert overridden.name == original.name
assert overridden.description == original.description
assert overridden.max_turns == 140
assert overridden.model == original.model
assert overridden.tools == original.tools
assert overridden.disallowed_tools == original.disallowed_tools
# ---------------------------------------------------------------------------
# registry.list_subagents all agents get overrides
# ---------------------------------------------------------------------------
class TestRegistryListSubagents:
def teardown_method(self):
_reset_subagents_config()
def test_lists_both_builtin_agents(self):
from deerflow.subagents.registry import list_subagents
names = {cfg.name for cfg in list_subagents()}
assert "general-purpose" in names
assert "bash" in names
def test_all_returned_configs_get_global_override(self):
from deerflow.subagents.registry import list_subagents
_reset_subagents_config(timeout_seconds=123, max_turns=77)
for cfg in list_subagents():
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
def test_per_agent_overrides_reflected_in_list(self):
from deerflow.subagents.registry import list_subagents
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
by_name = {cfg.name: cfg for cfg in list_subagents()}
assert by_name["general-purpose"].timeout_seconds == 1800
assert by_name["bash"].timeout_seconds == 60
assert by_name["general-purpose"].max_turns == 200
assert by_name["bash"].max_turns == 80
# ---------------------------------------------------------------------------
# Polling timeout calculation (logic extracted from task_tool)
# ---------------------------------------------------------------------------
class TestPollingTimeoutCalculation:
"""Verify the formula (timeout_seconds + 60) // 5 is correct for various inputs."""
@pytest.mark.parametrize(
"timeout_seconds, expected_max_polls",
[
(900, 192), # default 15 min → (900+60)//5 = 192
(300, 72), # 5 min → (300+60)//5 = 72
(1800, 372), # 30 min → (1800+60)//5 = 372
(60, 24), # 1 min → (60+60)//5 = 24
(1, 12), # minimum → (1+60)//5 = 12
],
)
def test_polling_timeout_formula(self, timeout_seconds: int, expected_max_polls: int):
dummy_config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
timeout_seconds=timeout_seconds,
)
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
assert max_poll_count == expected_max_polls
def test_polling_timeout_exceeds_execution_timeout(self):
"""Safety-net polling window must always be longer than the execution timeout."""
for timeout_seconds in [60, 300, 900, 1800]:
dummy_config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
timeout_seconds=timeout_seconds,
)
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
polling_window_seconds = max_poll_count * 5
assert polling_window_seconds > timeout_seconds
def test_load_empty_uses_defaults(self):
cfg = _make_config()
with patch.object(AppConfig, "current", return_value=cfg):
sub = AppConfig.current().subagents
assert sub.timeout_seconds == 900
assert sub.max_turns is None
assert sub.agents == {}
+11 -1
View File
@@ -8,6 +8,9 @@ from unittest.mock import MagicMock
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.subagents.config import SubagentConfig
# Use module import so tests can patch the exact symbols referenced inside task_tool().
@@ -24,6 +27,13 @@ class FakeSubagentStatus(Enum):
TIMED_OUT = "timed_out"
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime() -> SimpleNamespace:
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
return SimpleNamespace(
@@ -35,7 +45,7 @@ def _make_runtime() -> SimpleNamespace:
"outputs_path": "/tmp/outputs",
},
},
context={"thread_id": "thread-1"},
context=_make_context("thread-1"),
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
)
+23 -26
View File
@@ -1,58 +1,55 @@
import pytest
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
def _as_posix(path: str) -> str:
return path.replace("\\", "/")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
class TestThreadDataMiddleware:
def test_before_agent_returns_paths_when_thread_id_present_in_context(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
from langgraph.runtime import Runtime
result = middleware.before_agent(state={}, runtime=Runtime(context={"thread_id": "thread-123"}))
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-123")))
assert result is not None
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-123/user-data/workspace")
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-123/user-data/uploads")
assert _as_posix(result["thread_data"]["outputs_path"]).endswith("threads/thread-123/user-data/outputs")
def test_before_agent_uses_thread_id_from_configurable_when_context_is_none(self, tmp_path, monkeypatch):
def test_before_agent_uses_thread_id_from_context(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
runtime = Runtime(context=None)
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
lambda: {"configurable": {"thread_id": "thread-from-config"}},
)
from langgraph.runtime import Runtime
result = middleware.before_agent(state={}, runtime=runtime)
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-config")))
assert result is not None
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-from-config/user-data/workspace")
assert runtime.context is None
def test_before_agent_uses_thread_id_from_configurable_when_context_missing_thread_id(self, tmp_path, monkeypatch):
def test_before_agent_uses_thread_id_from_typed_context(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
runtime = Runtime(context={})
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
lambda: {"configurable": {"thread_id": "thread-from-config"}},
)
from langgraph.runtime import Runtime
result = middleware.before_agent(state={}, runtime=runtime)
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-dict")))
assert result is not None
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-config/user-data/uploads")
assert runtime.context == {}
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-dict/user-data/uploads")
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
def test_before_agent_raises_clear_error_when_thread_id_missing(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
lambda: {"configurable": {}},
)
from langgraph.runtime import Runtime
with pytest.raises(ValueError, match="Thread ID is required in runtime context or config.configurable"):
middleware.before_agent(state={}, runtime=Runtime(context=None))
with pytest.raises(ValueError, match="Thread ID is required"):
middleware.before_agent(state={}, runtime=Runtime(context=_make_context("")))
+1 -36
View File
@@ -3,7 +3,7 @@
import pytest
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
from deerflow.config.title_config import TitleConfig
class TestTitleConfig:
@@ -44,21 +44,6 @@ class TestTitleConfig:
with pytest.raises(ValueError):
TitleConfig(max_chars=201)
def test_get_set_config(self):
"""Test global config getter and setter."""
original_config = get_title_config()
# Set new config
new_config = TitleConfig(enabled=False, max_words=10)
set_title_config(new_config)
# Verify it was set
assert get_title_config().enabled is False
assert get_title_config().max_words == 10
# Restore original config
set_title_config(original_config)
class TestTitleMiddleware:
"""Tests for TitleMiddleware."""
@@ -68,23 +53,3 @@ class TestTitleMiddleware:
middleware = TitleMiddleware()
assert middleware is not None
assert middleware.state_schema is not None
# TODO: Add integration tests with mock Runtime
# def test_should_generate_title(self):
# """Test title generation trigger logic."""
# pass
# def test_generate_title(self):
# """Test title generation."""
# pass
# def test_after_agent_hook(self):
# """Test after_agent hook."""
# pass
# TODO: Add integration tests
# - Test with real LangGraph runtime
# - Test title persistence with checkpointer
# - Test fallback behavior when LLM fails
# - Test concurrent title generation
+107 -117
View File
@@ -1,137 +1,127 @@
"""Core behavior tests for TitleMiddleware."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.title_config import TitleConfig
def _clone_title_config(config: TitleConfig) -> TitleConfig:
# Avoid mutating shared global config objects across tests.
return TitleConfig(**config.model_dump())
def _make_config(**title_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
def _set_test_title_config(**overrides) -> TitleConfig:
config = _clone_title_config(get_title_config())
for key, value in overrides.items():
setattr(config, key, value)
set_title_config(config)
return config
def _patch_app_config(**title_overrides):
return patch.object(AppConfig, "current", return_value=_make_config(**title_overrides))
class TestTitleMiddlewareCoreLogic:
def setup_method(self):
# Title config is a global singleton; snapshot and restore for test isolation.
self._original = _clone_title_config(get_title_config())
def teardown_method(self):
set_title_config(self._original)
def test_should_generate_title_for_first_complete_exchange(self):
_set_test_title_config(enabled=True)
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="帮我总结这段代码"),
AIMessage(content="好的,我先看结构"),
]
}
with _patch_app_config(enabled=True):
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="帮我总结这段代码"),
AIMessage(content="好的,我先看结构"),
]
}
assert middleware._should_generate_title(state) is True
assert middleware._should_generate_title(state) is True
def test_should_not_generate_title_when_disabled_or_already_set(self):
middleware = TitleMiddleware()
_set_test_title_config(enabled=False)
disabled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": None,
}
assert middleware._should_generate_title(disabled_state) is False
with _patch_app_config(enabled=False):
disabled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": None,
}
assert middleware._should_generate_title(disabled_state) is False
_set_test_title_config(enabled=True)
titled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": "Existing Title",
}
assert middleware._should_generate_title(titled_state) is False
with _patch_app_config(enabled=True):
titled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": "Existing Title",
}
assert middleware._should_generate_title(titled_state) is False
def test_should_not_generate_title_after_second_user_turn(self):
_set_test_title_config(enabled=True)
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="第一问"),
AIMessage(content="第一答"),
HumanMessage(content="第二问"),
AIMessage(content="第二答"),
]
}
with _patch_app_config(enabled=True):
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="第一问"),
AIMessage(content="第一答"),
HumanMessage(content="第二问"),
AIMessage(content="第二答"),
]
}
assert middleware._should_generate_title(state) is False
assert middleware._should_generate_title(state) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
with _patch_app_config(max_chars=12):
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
state = {
"messages": [
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
with _patch_app_config(max_chars=20):
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
]
}
state = {
"messages": [
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert title == "请帮我总结这段代码"
assert title == "请帮我总结这段代码"
def test_generate_title_fallback_for_long_message(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
with _patch_app_config(max_chars=20):
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
AIMessage(content="收到"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
AIMessage(content="收到"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述")
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述")
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
middleware = TitleMiddleware()
@@ -155,29 +145,29 @@ class TestTitleMiddlewareCoreLogic:
def test_sync_generate_title_uses_fallback_without_model(self):
"""Sync path avoids LLM calls and derives a local fallback title."""
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
with _patch_app_config(max_chars=20):
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "请帮我写测试"}
state = {
"messages": [
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "请帮我写测试"}
def test_sync_generate_title_respects_fallback_truncation(self):
"""Sync fallback path still respects max_chars truncation rules."""
_set_test_title_config(max_chars=50)
middleware = TitleMiddleware()
with _patch_app_config(max_chars=50):
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")
+2 -1
View File
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.client import DeerFlowClient
from deerflow.config.app_config import AppConfig
# ---------------------------------------------------------------------------
# _serialize_message — usage_metadata passthrough
@@ -154,7 +155,7 @@ class TestStreamUsageIntegration:
"""Test that stream() emits usage_metadata in messages-tuple and end events."""
def _make_client(self):
with patch("deerflow.client.get_app_config", return_value=_mock_app_config()):
with patch.object(AppConfig, "current", return_value=_mock_app_config()):
return DeerFlowClient()
def test_stream_emits_usage_in_messages_tuple(self):
+10 -15
View File
@@ -6,7 +6,8 @@ import sys
import pytest
from langchain_core.tools import tool as langchain_tool
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
from deerflow.tools.builtins.tool_search import (
DeferredToolRegistry,
get_deferred_registry,
@@ -62,12 +63,12 @@ class TestToolSearchConfig:
config = ToolSearchConfig(enabled=True)
assert config.enabled is True
def test_load_from_dict(self):
config = load_tool_search_config_from_dict({"enabled": True})
def test_validate_from_dict(self):
config = ToolSearchConfig.model_validate({"enabled": True})
assert config.enabled is True
def test_load_from_empty_dict(self):
config = load_tool_search_config_from_dict({})
def test_validate_from_empty_dict(self):
config = ToolSearchConfig.model_validate({})
assert config.enabled is False
@@ -263,7 +264,7 @@ class TestDeferredToolsPromptSection:
mock_config = MagicMock()
mock_config.tool_search = ToolSearchConfig() # disabled by default
monkeypatch.setattr("deerflow.config.get_app_config", lambda: mock_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
def test_empty_when_disabled(self):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
@@ -274,26 +275,20 @@ class TestDeferredToolsPromptSection:
def test_empty_when_enabled_but_no_registry(self, monkeypatch):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
from deerflow.config import get_app_config
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
section = get_deferred_tools_prompt_section()
assert section == ""
def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
from deerflow.config import get_app_config
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
set_deferred_registry(DeferredToolRegistry())
section = get_deferred_tools_prompt_section()
assert section == ""
def test_lists_tool_names(self, registry, monkeypatch):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
from deerflow.config import get_app_config
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
set_deferred_registry(registry)
section = get_deferred_tools_prompt_section()
assert "<available-deferred-tools>" in section
@@ -13,7 +13,10 @@ from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths
from deerflow.config.sandbox_config import SandboxConfig
THREAD_ID = "thread-abc123"
@@ -23,13 +26,20 @@ THREAD_ID = "thread-abc123"
# ---------------------------------------------------------------------------
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _middleware(tmp_path: Path) -> UploadsMiddleware:
return UploadsMiddleware(base_dir=str(tmp_path))
def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock:
rt = MagicMock()
rt.context = {"thread_id": thread_id}
rt.context = _make_context(thread_id or "")
return rt