Merge refactor/config-deerflow-context into release/2.0-rc

Cherry-pick PR #2271's config refactor onto release/2.0-rc.
Used 'git merge -X theirs' to auto-resolve content conflicts in favor of
the PR's design (frozen AppConfig + explicit-parameter passing).

Limitations:
- Release-only changes that overlapped with PR's refactor in 119 files
  are NOT preserved — those files reflect PR's version. Follow-up commits
  on this branch will need to re-apply release-only modifications where
  meaningful.
- See PR #2271 for design rationale.
This commit is contained in:
greatmengqi
2026-04-27 18:16:42 +08:00
227 changed files with 6965 additions and 5578 deletions
+6 -1
View File
@@ -29,6 +29,7 @@ apps with the real middleware — those should not use this module.
from __future__ import annotations
from collections.abc import Callable
from typing import ParamSpec, TypeVar
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
@@ -112,7 +113,11 @@ def make_authed_test_app(
return app
def call_unwrapped[*P, R](decorated: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R:
_P = ParamSpec("_P")
_R = TypeVar("_R")
def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""Invoke the underlying function of a ``@require_permission``-decorated route.
``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all
+27
View File
@@ -68,6 +68,33 @@ def provisioner_module():
# context should mark themselves ``@pytest.mark.no_auto_user``.
@pytest.fixture(autouse=True)
def _auto_app_config_from_file(monkeypatch, request):
"""Replace ``AppConfig.from_file`` with a minimal factory so tests that
(directly or indirectly, e.g. via the LangGraph Server bootstrap path in
``make_lead_agent``) load AppConfig from disk do not need a real
``config.yaml`` on the filesystem.
Tests that want to verify the real ``from_file`` behaviour should mark
themselves with ``@pytest.mark.real_from_file``.
"""
if request.node.get_closest_marker("real_from_file"):
yield
return
try:
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
except ImportError:
yield
return
def _fake_from_file(config_path: str | None = None) -> AppConfig: # noqa: ARG001
return AppConfig(sandbox=SandboxConfig(use="test"))
monkeypatch.setattr(AppConfig, "from_file", _fake_from_file)
yield
@pytest.fixture(autouse=True)
def _auto_user_context(request):
"""Inject a default ``test-user-autouse`` into the contextvar.
+30 -37
View File
@@ -2,21 +2,27 @@
import json
import pytest
import pytest
import yaml
pytestmark = pytest.mark.real_from_file
from pydantic import ValidationError
from deerflow.config.acp_config import ACPAgentConfig, get_acp_agents, load_acp_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def setup_function():
"""Reset ACP config before each test."""
load_acp_config_from_dict({})
def _make_config(acp_agents: dict | None = None) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
acp_agents={name: ACPAgentConfig(**cfg) for name, cfg in (acp_agents or {}).items()},
)
def test_load_acp_config_sets_agents():
load_acp_config_from_dict(
def test_acp_agents_via_app_config():
cfg = _make_config(
{
"claude_code": {
"command": "claude-code-acp",
@@ -26,39 +32,33 @@ def test_load_acp_config_sets_agents():
}
}
)
agents = get_acp_agents()
agents = cfg.acp_agents
assert "claude_code" in agents
assert agents["claude_code"].command == "claude-code-acp"
assert agents["claude_code"].description == "Claude Code for coding tasks"
assert agents["claude_code"].model is None
def test_load_acp_config_multiple_agents():
load_acp_config_from_dict(
def test_multiple_agents():
cfg = _make_config(
{
"claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"},
"codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"},
}
)
agents = get_acp_agents()
agents = cfg.acp_agents
assert len(agents) == 2
assert agents["codex"].args == ["--flag"]
def test_load_acp_config_empty_clears_agents():
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert len(get_acp_agents()) == 1
load_acp_config_from_dict({})
assert len(get_acp_agents()) == 0
def test_empty_acp_agents():
cfg = _make_config({})
assert cfg.acp_agents == {}
def test_load_acp_config_none_clears_agents():
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert len(get_acp_agents()) == 1
load_acp_config_from_dict(None)
assert get_acp_agents() == {}
def test_default_acp_agents_empty():
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
assert cfg.acp_agents == {}
def test_acp_agent_config_defaults():
@@ -79,8 +79,8 @@ def test_acp_agent_config_env_default_is_empty():
assert cfg.env == {}
def test_load_acp_config_preserves_env():
load_acp_config_from_dict(
def test_acp_agent_preserves_env():
cfg = _make_config(
{
"codex": {
"command": "codex-acp",
@@ -90,8 +90,7 @@ def test_load_acp_config_preserves_env():
}
}
)
cfg = get_acp_agents()["codex"]
assert cfg.env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
assert cfg.acp_agents["codex"].env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
def test_acp_agent_config_with_model():
@@ -115,13 +114,7 @@ def test_acp_agent_config_missing_description_raises():
ACPAgentConfig(command="my-agent")
def test_get_acp_agents_returns_empty_by_default():
"""After clearing, should return empty dict."""
load_acp_config_from_dict({})
assert get_acp_agents() == {}
def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, monkeypatch):
def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -157,9 +150,9 @@ def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, mo
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8")
AppConfig.from_file(str(config_path))
assert set(get_acp_agents()) == {"codex"}
app = AppConfig.from_file(str(config_path))
assert set(app.acp_agents) == {"codex"}
config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8")
AppConfig.from_file(str(config_path))
assert get_acp_agents() == {}
app = AppConfig.from_file(str(config_path))
assert app.acp_agents == {}
+39 -121
View File
@@ -1,13 +1,14 @@
from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
import yaml
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
from deerflow.config.app_config import AppConfig
pytestmark = pytest.mark.real_from_file
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@@ -29,149 +30,66 @@ def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> No
)
def _write_config_with_agents_api(
path: Path,
*,
model_name: str,
supports_thinking: bool,
agents_api: dict | None = None,
) -> None:
config = {
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [
{
"name": model_name,
"use": "langchain_openai:ChatOpenAI",
"model": "gpt-test",
"supports_thinking": supports_thinking,
}
],
}
if agents_api is not None:
config["agents_api"] = agents_api
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
def test_app_config_defaults_missing_database_to_sqlite(tmp_path, monkeypatch):
def test_from_file_reads_model_name(tmp_path, monkeypatch):
"""``AppConfig.from_file`` is the only lifecycle method now; there is no
process-global ``init/current``. Each consumer holds its own captured
AppConfig instance.
"""
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_path, model_name="first-model", supports_thinking=False)
_write_config(config_path, model_name="test-model", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config = AppConfig.from_file(str(config_path))
assert config.database.backend == "sqlite"
assert config.database.sqlite_dir == ".deer-flow/data"
assert config.models[0].name == "test-model"
def test_app_config_defaults_empty_database_to_sqlite(tmp_path, monkeypatch):
def test_from_file_each_call_returns_fresh_instance(tmp_path, monkeypatch):
"""Two reads of the same file produce separate AppConfig instances —
no hidden singleton, no memoization. Callers decide when to re-read.
"""
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_path, model_name="model-a", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config_a = AppConfig.from_file(str(config_path))
assert config_a.models[0].name == "model-a"
_write_config(config_path, model_name="model-b", supports_thinking=True)
config_b = AppConfig.from_file(str(config_path))
assert config_b.models[0].name == "model-b"
assert config_a is not config_b
def test_config_version_check(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
config_path.write_text(
yaml.safe_dump(
{
"database": {},
"config_version": 1,
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [],
}
),
encoding="utf-8",
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config = AppConfig.from_file(str(config_path))
assert config.database.backend == "sqlite"
assert config.database.sqlite_dir == ".deer-flow/data"
def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_path, model_name="first-model", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
initial = get_app_config()
assert initial.models[0].supports_thinking is False
_write_config(config_path, model_name="first-model", supports_thinking=True)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded.models[0].supports_thinking is True
assert reloaded is not initial
finally:
reset_app_config()
def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
config_a = tmp_path / "config-a.yaml"
config_b = tmp_path / "config-b.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_a, model_name="model-a", supports_thinking=False)
_write_config(config_b, model_name="model-b", supports_thinking=True)
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_a))
reset_app_config()
try:
first = get_app_config()
assert first.models[0].name == "model-a"
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_b))
second = get_app_config()
assert second.models[0].name == "model-b"
assert second is not first
finally:
reset_app_config()
def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_agents_api(
config_path,
model_name="first-model",
supports_thinking=False,
agents_api={"enabled": True},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
initial = get_app_config()
assert initial.models[0].name == "first-model"
assert get_agents_api_config().enabled is True
_write_config_with_agents_api(
config_path,
model_name="first-model",
supports_thinking=False,
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded is not initial
assert get_agents_api_config().enabled is False
finally:
reset_app_config()
assert config is not None
-14
View File
@@ -174,20 +174,6 @@ def test_protected_post_no_cookie_returns_401(client):
assert res.status_code == 401
def test_protected_post_with_internal_auth_header_passes():
from app.gateway.internal_auth import create_internal_auth_headers
app = _make_app()
client = TestClient(app)
res = client.post(
"/api/threads/abc/runs/stream",
headers=create_internal_auth_headers(),
)
assert res.status_code == 200
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
+43 -123
View File
@@ -5,25 +5,21 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import deerflow.config.app_config as app_config_module
from deerflow.config.checkpointer_config import (
CheckpointerConfig,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
def _make_config(checkpointer: CheckpointerConfig | None = None) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), checkpointer=checkpointer)
@pytest.fixture(autouse=True)
def reset_state():
"""Reset singleton state before each test."""
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
yield
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
@@ -33,24 +29,18 @@ def reset_state():
class TestCheckpointerConfig:
def test_load_memory_config(self):
load_checkpointer_config_from_dict({"type": "memory"})
config = get_checkpointer_config()
assert config is not None
def test_memory_config(self):
config = CheckpointerConfig(type="memory")
assert config.type == "memory"
assert config.connection_string is None
def test_load_sqlite_config(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
config = get_checkpointer_config()
assert config is not None
def test_sqlite_config(self):
config = CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")
assert config.type == "sqlite"
assert config.connection_string == "/tmp/test.db"
def test_load_postgres_config(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
config = get_checkpointer_config()
assert config is not None
def test_postgres_config(self):
config = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
assert config.type == "postgres"
assert config.connection_string == "postgresql://localhost/db"
@@ -58,14 +48,9 @@ class TestCheckpointerConfig:
config = CheckpointerConfig(type="memory")
assert config.connection_string is None
def test_set_config_to_none(self):
load_checkpointer_config_from_dict({"type": "memory"})
set_checkpointer_config(None)
assert get_checkpointer_config() is None
def test_invalid_type_raises(self):
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
CheckpointerConfig(type="unknown")
# ---------------------------------------------------------------------------
@@ -78,58 +63,58 @@ class TestGetCheckpointer:
"""get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
cp = get_checkpointer()
cfg = _make_config()
cp = get_checkpointer(cfg)
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_memory_returns_in_memory_saver(self):
load_checkpointer_config_from_dict({"type": "memory"})
from langgraph.checkpoint.memory import InMemorySaver
cp = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
cp = get_checkpointer(cfg)
assert isinstance(cp, InMemorySaver)
def test_memory_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
cp2 = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
cp1 = get_checkpointer(cfg)
cp2 = get_checkpointer(cfg)
assert cp1 is cp2
def test_reset_clears_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
cp1 = get_checkpointer(cfg)
reset_checkpointer()
cp2 = get_checkpointer()
cp2 = get_checkpointer(cfg)
assert cp1 is not cp2
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
get_checkpointer()
get_checkpointer(cfg)
def test_postgres_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
get_checkpointer()
get_checkpointer(cfg)
def test_postgres_raises_when_connection_string_missing(self):
load_checkpointer_config_from_dict({"type": "postgres"})
cfg = _make_config(CheckpointerConfig(type="postgres"))
mock_saver = MagicMock()
mock_module = MagicMock()
mock_module.PostgresSaver = mock_saver
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
reset_checkpointer()
with pytest.raises(ValueError, match="connection_string is required"):
get_checkpointer()
get_checkpointer(cfg)
def test_sqlite_creates_saver(self):
"""SQLite checkpointer is created when package is available."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
@@ -144,7 +129,7 @@ class TestGetCheckpointer:
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
reset_checkpointer()
cp = get_checkpointer()
cp = get_checkpointer(cfg)
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once()
@@ -225,7 +210,7 @@ class TestGetCheckpointer:
def test_postgres_creates_saver(self):
"""Postgres checkpointer is created when packages are available."""
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
@@ -240,7 +225,7 @@ class TestGetCheckpointer:
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
reset_checkpointer()
cp = get_checkpointer()
cp = get_checkpointer(cfg)
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
@@ -268,7 +253,6 @@ class TestAsyncCheckpointer:
mock_module.AsyncSqliteSaver = mock_saver_cls
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
patch(
@@ -276,7 +260,7 @@ class TestAsyncCheckpointer:
return_value="/tmp/resolved/test.db",
),
):
async with make_checkpointer() as saver:
async with make_checkpointer(mock_config) as saver:
assert saver is mock_saver
mock_to_thread.assert_awaited_once()
@@ -294,12 +278,10 @@ class TestAsyncCheckpointer:
class TestAppConfigLoadsCheckpointer:
def test_load_checkpointer_section(self):
"""load_checkpointer_config_from_dict populates the global config."""
set_checkpointer_config(None)
load_checkpointer_config_from_dict({"type": "memory"})
cfg = get_checkpointer_config()
assert cfg is not None
assert cfg.type == "memory"
"""AppConfig with checkpointer section has the correct config."""
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None
assert cfg.checkpointer.type == "memory"
# ---------------------------------------------------------------------------
@@ -309,69 +291,7 @@ class TestAppConfigLoadsCheckpointer:
class TestClientCheckpointerFallback:
def test_client_uses_config_checkpointer_when_none_provided(self):
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=None)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert "checkpointer" in captured_kwargs
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
def test_client_explicit_checkpointer_takes_precedence(self):
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
explicit_cp = MagicMock()
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=explicit_cp)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert captured_kwargs["checkpointer"] is explicit_cp
"""DeerFlowClient._ensure_agent falls back to get_checkpointer(app_config) when checkpointer=None."""
# This is a structural test — verifying the fallback path exists.
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None
+20 -24
View File
@@ -1,6 +1,6 @@
"""Test for issue #1016: checkpointer should not return None."""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from langgraph.checkpoint.memory import InMemorySaver
@@ -14,42 +14,38 @@ class TestCheckpointerNoneFix:
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
# Mock get_app_config to return a config with checkpointer=None and database=None
mock_config = MagicMock()
mock_config.checkpointer = None
mock_config.database = None
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
async with make_checkpointer() as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
assert isinstance(checkpointer, InMemorySaver)
async with make_checkpointer(mock_config) as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
assert isinstance(checkpointer, InMemorySaver)
# Should be able to call alist() without AttributeError
# This is what LangGraph does and what was failing in issue #1016
result = []
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
result.append(item)
# Should be able to call alist() without AttributeError
# This is what LangGraph does and what was failing in issue #1016
result = []
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
result.append(item)
# Empty list is expected for a fresh checkpointer
assert result == []
# Empty list is expected for a fresh checkpointer
assert result == []
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
from deerflow.runtime.checkpointer.provider import checkpointer_context
# Mock get_app_config to return a config with checkpointer=None
mock_config = MagicMock()
mock_config.checkpointer = None
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
with checkpointer_context() as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
assert isinstance(checkpointer, InMemorySaver)
with checkpointer_context(mock_config) as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
assert isinstance(checkpointer, InMemorySaver)
# Should be able to call list() without AttributeError
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
# Should be able to call list() without AttributeError
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
# Empty list is expected for a fresh checkpointer
assert result == []
# Empty list is expected for a fresh checkpointer
assert result == []
+86 -82
View File
@@ -18,6 +18,7 @@ from app.gateway.routers.models import ModelResponse, ModelsListResponse
from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
from app.gateway.routers.uploads import UploadResponse
from deerflow.client import DeerFlowClient
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import Paths
from deerflow.uploads.manager import PathTraversalError
@@ -44,9 +45,12 @@ def mock_app_config():
@pytest.fixture
def client(mock_app_config):
"""Create a DeerFlowClient with mocked config loading."""
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
return DeerFlowClient()
"""Create a DeerFlowClient holding the mocked config directly.
Passing ``config=`` is the documented post-refactor way to inject a
test AppConfig; nothing relies on process-global state.
"""
return DeerFlowClient(config=mock_app_config)
# ---------------------------------------------------------------------------
@@ -67,8 +71,7 @@ class TestClientInit:
def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
assert c._subagent_enabled is True
@@ -78,24 +81,21 @@ class TestClientInit:
assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config):
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="invalid name with spaces!")
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="../path/traversal")
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="invalid name with spaces!")
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="../path/traversal")
def test_custom_config_path(self, mock_app_config):
with (
patch("deerflow.client.reload_app_config") as mock_reload,
patch("deerflow.client.get_app_config", return_value=mock_app_config),
):
DeerFlowClient(config_path="/tmp/custom.yaml")
mock_reload.assert_called_once_with("/tmp/custom.yaml")
# rather than touching AppConfig.init() / process-global state.
with patch.object(AppConfig, "from_file", return_value=mock_app_config) as mock_from_file:
client = DeerFlowClient(config_path="/tmp/custom.yaml")
mock_from_file.assert_called_once_with("/tmp/custom.yaml")
assert client._app_config is mock_app_config
def test_checkpointer_stored(self, mock_app_config):
cp = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(checkpointer=cp)
c = DeerFlowClient(checkpointer=cp)
assert c._checkpointer is cp
@@ -126,7 +126,7 @@ class TestConfigQueries:
with patch("deerflow.skills.loader.load_skills", return_value=[skill]) as mock_load:
result = client.list_skills()
mock_load.assert_called_once_with(enabled_only=False)
mock_load.assert_called_once_with(client._app_config, enabled_only=False)
assert "skills" in result
assert len(result["skills"]) == 1
@@ -141,7 +141,7 @@ class TestConfigQueries:
def test_list_skills_enabled_only(self, client):
with patch("deerflow.skills.loader.load_skills", return_value=[]) as mock_load:
client.list_skills(enabled_only=True)
mock_load.assert_called_once_with(enabled_only=True)
mock_load.assert_called_once_with(client._app_config, enabled_only=True)
def test_get_memory(self, client):
memory = {"version": "1.0", "facts": []}
@@ -251,8 +251,8 @@ class TestStream:
# Verify context passed to agent.stream
agent.stream.assert_called_once()
call_kwargs = agent.stream.call_args.kwargs
assert call_kwargs["context"]["thread_id"] == "t1"
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
ctx = call_kwargs["context"]
assert ctx.app_config is client._app_config
def test_custom_mode_is_normalized_to_string(self, client):
"""stream() forwards custom events even when the mode is not a plain string."""
@@ -1091,8 +1091,8 @@ class TestMcpConfig:
ext_config = MagicMock()
ext_config.mcp_servers = {"github": server}
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
result = client.get_mcp_config()
client._app_config = MagicMock(extensions=ext_config)
result = client.get_mcp_config()
assert "mcp_servers" in result
assert "github" in result["mcp_servers"]
@@ -1116,10 +1116,11 @@ class TestMcpConfig:
# Pre-set agent to verify it gets invalidated
client._agent = MagicMock()
client._app_config = MagicMock(extensions=current_config)
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
):
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
@@ -1177,12 +1178,12 @@ class TestSkillsManagement:
try:
# Pre-set agent to verify it gets invalidated
client._agent = MagicMock()
client._app_config = MagicMock(extensions=ext_config)
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
result = client.update_skill("test-skill", enabled=False)
assert result["enabled"] is False
@@ -1245,7 +1246,7 @@ class TestMemoryManagement:
assert mock_import.call_count == 1
call_args = mock_import.call_args
assert call_args.args == (imported,)
assert call_args.args == (client._app_config.memory, imported)
assert "user_id" in call_args.kwargs
assert result == imported
@@ -1270,6 +1271,7 @@ class TestMemoryManagement:
confidence=0.88,
)
create_fact.assert_called_once_with(
client._app_config.memory,
content="User prefers concise code reviews.",
category="preference",
confidence=0.88,
@@ -1280,7 +1282,7 @@ class TestMemoryManagement:
data = {"version": "1.0", "facts": []}
with patch("deerflow.agents.memory.updater.delete_memory_fact", return_value=data) as delete_fact:
result = client.delete_memory_fact("fact_123")
delete_fact.assert_called_once_with("fact_123")
delete_fact.assert_called_once_with(client._app_config.memory, "fact_123")
assert result == data
def test_update_memory_fact(self, client):
@@ -1293,6 +1295,7 @@ class TestMemoryManagement:
confidence=0.91,
)
update_fact.assert_called_once_with(
client._app_config.memory,
fact_id="fact_123",
content="User prefers spaces",
category="workflow",
@@ -1308,6 +1311,7 @@ class TestMemoryManagement:
"User prefers spaces",
)
update_fact.assert_called_once_with(
client._app_config.memory,
fact_id="fact_123",
content="User prefers spaces",
category=None,
@@ -1316,37 +1320,40 @@ class TestMemoryManagement:
assert result == data
def test_get_memory_config(self, client):
config = MagicMock()
config.enabled = True
config.storage_path = ".deer-flow/memory.json"
config.debounce_seconds = 30
config.max_facts = 100
config.fact_confidence_threshold = 0.7
config.injection_enabled = True
config.max_injection_tokens = 2000
mem_config = MagicMock()
mem_config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30
mem_config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True
mem_config.max_injection_tokens = 2000
with patch("deerflow.config.memory_config.get_memory_config", return_value=config):
result = client.get_memory_config()
app_cfg = MagicMock()
app_cfg.memory = mem_config
client._app_config = app_cfg
result = client.get_memory_config()
assert result["enabled"] is True
assert result["max_facts"] == 100
def test_get_memory_status(self, client):
config = MagicMock()
config.enabled = True
config.storage_path = ".deer-flow/memory.json"
config.debounce_seconds = 30
config.max_facts = 100
config.fact_confidence_threshold = 0.7
config.injection_enabled = True
config.max_injection_tokens = 2000
mem_config = MagicMock()
mem_config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30
mem_config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True
mem_config.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_config
data = {"version": "1.0", "facts": []}
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=data),
):
client._app_config = app_cfg
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=data):
result = client.get_memory_status()
assert "config" in result
@@ -1800,10 +1807,10 @@ class TestScenarioConfigManagement:
reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
client._agent = MagicMock() # Simulate existing agent
client._app_config = MagicMock(extensions=current_config)
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
):
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
assert "my-mcp" in mcp_result["mcp_servers"]
@@ -1832,8 +1839,7 @@ class TestScenarioConfigManagement:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
skill_result = client.update_skill("code-gen", enabled=False)
assert skill_result["enabled"] is False
@@ -2021,10 +2027,10 @@ class TestScenarioMemoryWorkflow:
refreshed = client.reload_memory()
assert len(refreshed["facts"]) == 2
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data),
):
app_cfg = MagicMock()
app_cfg.memory = config
client._app_config = app_cfg
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data):
status = client.get_memory_status()
assert status["config"]["enabled"] is True
assert len(status["data"]["facts"]) == 2
@@ -2085,8 +2091,7 @@ class TestScenarioSkillInstallAndUse:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
toggled = client.update_skill("my-analyzer", enabled=False)
assert toggled["enabled"] is False
@@ -2220,8 +2225,7 @@ class TestGatewayConformance:
mock_app_config.models = [model]
mock_app_config.token_usage.enabled = True
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
client = DeerFlowClient()
client = DeerFlowClient(config=mock_app_config)
result = client.list_models()
parsed = ModelsListResponse(**result)
@@ -2240,8 +2244,7 @@ class TestGatewayConformance:
mock_app_config.models = [model]
mock_app_config.get_model_config.return_value = model
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
client = DeerFlowClient()
client = DeerFlowClient(config=mock_app_config)
result = client.get_model("test-model")
assert result is not None
@@ -2310,8 +2313,8 @@ class TestGatewayConformance:
ext_config = MagicMock()
ext_config.mcp_servers = {"test": server}
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
result = client.get_mcp_config()
client._app_config = MagicMock(extensions=ext_config)
result = client.get_mcp_config()
parsed = McpConfigResponse(**result)
assert "test" in parsed.mcp_servers
@@ -2335,10 +2338,10 @@ class TestGatewayConformance:
config_file = tmp_path / "extensions_config.json"
config_file.write_text("{}")
client._app_config = MagicMock(extensions=ext_config)
with (
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.reload_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=ext_config)),
):
result = client.update_mcp_config({"srv": server.model_dump.return_value})
@@ -2369,8 +2372,11 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
result = client.get_memory_config()
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
client._app_config = app_cfg
result = client.get_memory_config()
parsed = MemoryConfigResponse(**result)
assert parsed.enabled is True
@@ -2386,6 +2392,8 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
memory_data = {
"version": "1.0",
"lastUpdated": "",
@@ -2402,10 +2410,8 @@ class TestGatewayConformance:
"facts": [],
}
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data),
):
client._app_config = app_cfg
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data):
result = client.get_memory_status()
parsed = MemoryStatusResponse(**result)
@@ -2694,8 +2700,7 @@ class TestConfigUpdateErrors:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
with pytest.raises(RuntimeError, match="disappeared"):
client.update_skill("ghost-skill", enabled=False)
@@ -3074,10 +3079,10 @@ class TestBugAgentInvalidationInconsistency:
config_file = Path(tmp) / "ext.json"
config_file.write_text("{}")
client._app_config = MagicMock(extensions=current_config)
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),
):
client.update_mcp_config({})
@@ -3109,8 +3114,7 @@ class TestBugAgentInvalidationInconsistency:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
client.update_skill("s1", enabled=False)
+19 -35
View File
@@ -56,6 +56,10 @@ def _make_e2e_config() -> AppConfig:
- ``E2E_BASE_URL`` (default: ``https://ark-cn-beijing.bytedance.net/api/v3``)
- ``OPENAI_API_KEY`` (required for LLM tests)
"""
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.summarization_config import SummarizationConfig
from deerflow.config.title_config import TitleConfig
return AppConfig(
models=[
ModelConfig(
@@ -73,6 +77,9 @@ def _make_e2e_config() -> AppConfig:
)
],
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", allow_host_bash=True),
title=TitleConfig(enabled=False),
memory=MemoryConfig(enabled=False),
summarization=SummarizationConfig(enabled=False),
)
@@ -87,7 +94,7 @@ def e2e_env(tmp_path, monkeypatch):
- DEER_FLOW_HOME → tmp_path (all thread data lands in a temp dir)
- Singletons reset so they pick up the new env
- Title/memory/summarization disabled to avoid extra LLM calls
- Title/memory/summarization disabled via AppConfig fields
- AppConfig built programmatically (avoids config.yaml param-name issues)
"""
# 1. Filesystem isolation
@@ -95,30 +102,12 @@ def e2e_env(tmp_path, monkeypatch):
monkeypatch.setattr("deerflow.config.paths._paths", None)
monkeypatch.setattr("deerflow.sandbox.sandbox_provider._default_sandbox_provider", None)
# 2. Inject a clean AppConfig via the global singleton.
config = _make_e2e_config()
monkeypatch.setattr("deerflow.config.app_config._app_config", config)
monkeypatch.setattr("deerflow.config.app_config._app_config_is_custom", True)
# 1b. Override the autouse ``AppConfig.from_file`` stub from conftest
# (minimal test config) with the e2e-specific config that carries a
# real model entry and disables title/memory/summarization.
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda config_path=None: _make_e2e_config()))
# 3. Disable title generation (extra LLM call, non-deterministic)
from deerflow.config.title_config import TitleConfig
monkeypatch.setattr("deerflow.config.title_config._title_config", TitleConfig(enabled=False))
# 4. Disable memory queueing (avoids background threads & file writes)
from deerflow.config.memory_config import MemoryConfig
monkeypatch.setattr(
"deerflow.agents.middlewares.memory_middleware.get_memory_config",
lambda: MemoryConfig(enabled=False),
)
# 5. Ensure summarization is off (default, but be explicit)
from deerflow.config.summarization_config import SummarizationConfig
monkeypatch.setattr("deerflow.config.summarization_config._summarization_config", SummarizationConfig(enabled=False))
# 6. Exclude TitleMiddleware from the chain.
# 2. Exclude TitleMiddleware from the chain.
# It triggers an extra LLM call to generate a thread title, which adds
# non-determinism and cost to E2E tests (title generation is already
# disabled via TitleConfig above, but the middleware still participates
@@ -666,10 +655,9 @@ class TestConfigManagement:
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
# Force reload so the singleton picks up our test file
from deerflow.config.extensions_config import reload_extensions_config
reload_extensions_config()
# Mock from_file so update_mcp_config's internal reload works without config.yaml
e2e_config = _make_e2e_config()
monkeypatch.setattr(AppConfig, "from_file", classmethod(lambda cls, path=None: e2e_config))
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
# Simulate a cached agent
@@ -693,9 +681,9 @@ class TestConfigManagement:
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
from deerflow.config.extensions_config import reload_extensions_config
reload_extensions_config()
# Mock from_file so update_skill's internal reload works without config.yaml
e2e_config = _make_e2e_config()
monkeypatch.setattr(AppConfig, "from_file", classmethod(lambda cls, path=None: e2e_config))
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
c._agent = "fake-agent-placeholder"
@@ -721,10 +709,6 @@ class TestConfigManagement:
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
from deerflow.config.extensions_config import reload_extensions_config
reload_extensions_config()
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
with pytest.raises(ValueError, match="not found"):
c.update_skill("nonexistent-skill-xyz", enabled=True)
+1 -1
View File
@@ -101,7 +101,7 @@ class TestLiveStreaming:
class TestLiveToolUse:
def test_agent_uses_bash_tool(self, client):
"""Agent uses bash tool when asked to run a command."""
if not is_host_bash_allowed():
if not is_host_bash_allowed(client._app_config):
pytest.skip("Host bash is disabled for LocalSandboxProvider in the active config")
events = list(client.stream("Use the bash tool to run: echo 'LIVE_TEST_OK'. Then tell me the output."))
@@ -0,0 +1,82 @@
"""Multi-client isolation regression test.
Phase 2 Task P2-3: ``DeerFlowClient`` now captures its ``AppConfig`` in the
constructor instead of going through a process-global config.
This test pins the resulting invariant: two clients with different configs
can coexist without contending over shared state.
Before P2-3, the shared ``AppConfig._global`` caused the second client's
``init()`` to clobber the first client's config.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from deerflow.client import DeerFlowClient
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
@pytest.fixture
def disable_agent_creation(monkeypatch):
"""Prevent lazy agent creation — we only care about config access."""
monkeypatch.setattr(DeerFlowClient, "_get_or_create_agent", MagicMock(), raising=False)
def test_two_clients_do_not_clobber_each_other(disable_agent_creation):
"""Two clients with distinct configs keep their own AppConfig."""
cfg_a = AppConfig(
sandbox=SandboxConfig(use="test"),
memory=MemoryConfig(enabled=True),
)
cfg_b = AppConfig(
sandbox=SandboxConfig(use="test"),
memory=MemoryConfig(enabled=False),
)
client_a = DeerFlowClient(config=cfg_a)
client_b = DeerFlowClient(config=cfg_b)
# Identity: each client retains its own instance, not a shared ref
assert client_a._app_config is cfg_a
assert client_b._app_config is cfg_b
# Semantic: memory flag differs
assert client_a._app_config.memory.enabled is True
assert client_b._app_config.memory.enabled is False
def test_client_config_precedes_path(disable_agent_creation, tmp_path):
"""When both config= and config_path= are given, config= wins."""
cfg = AppConfig(sandbox=SandboxConfig(use="test"), log_level="debug")
# config_path points at a file that doesn't exist — proves it's unused
bogus_path = str(tmp_path / "nope.yaml")
client = DeerFlowClient(config_path=bogus_path, config=cfg)
assert client._app_config is cfg
assert client._app_config.log_level == "debug"
def test_multi_client_gateway_dict_returns_distinct(disable_agent_creation):
"""get_mcp_config() reads from self._app_config, not process-global."""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
ext_a = ExtensionsConfig(mcp_servers={"server-a": McpServerConfig(enabled=True)})
ext_b = ExtensionsConfig(mcp_servers={"server-b": McpServerConfig(enabled=True)})
cfg_a = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ext_a)
cfg_b = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ext_b)
client_a = DeerFlowClient(config=cfg_a)
client_b = DeerFlowClient(config=cfg_b)
servers_a = client_a.get_mcp_config()["mcp_servers"]
servers_b = client_b.get_mcp_config()["mcp_servers"]
assert set(servers_a.keys()) == {"server-a"}
assert set(servers_b.keys()) == {"server-b"}
+95
View File
@@ -0,0 +1,95 @@
"""Verify that all sub-config Pydantic models are frozen (immutable).
Frozen models reject attribute assignment after construction, raising
pydantic.ValidationError. This test collects every BaseModel subclass
defined in the deerflow.config package and asserts that mutation is
blocked.
"""
import inspect
import pkgutil
import pytest
from pydantic import BaseModel, ValidationError
import deerflow.config as config_pkg
def _collect_config_models() -> list[type[BaseModel]]:
"""Walk deerflow.config.* and return all concrete BaseModel subclasses."""
import importlib
models: list[type[BaseModel]] = []
package_path = config_pkg.__path__
package_prefix = config_pkg.__name__ + "."
for _importer, modname, _ispkg in pkgutil.walk_packages(package_path, prefix=package_prefix):
try:
mod = importlib.import_module(modname)
except Exception:
continue
for _name, obj in inspect.getmembers(mod, inspect.isclass):
if (
issubclass(obj, BaseModel)
and obj is not BaseModel
and obj.__module__ == mod.__name__
):
models.append(obj)
return models
_EXCLUDED: set[str] = set()
_ALL_MODELS = [m for m in _collect_config_models() if m.__name__ not in _EXCLUDED]
# Sanity: make sure we actually collected a meaningful set.
assert len(_ALL_MODELS) >= 15, f"Expected at least 15 config models, found {len(_ALL_MODELS)}: {[m.__name__ for m in _ALL_MODELS]}"
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_is_frozen(model_cls: type[BaseModel]):
"""Every sub-config model must have frozen=True in its model_config."""
cfg = model_cls.model_config
assert cfg.get("frozen") is True, (
f"{model_cls.__name__} is not frozen. "
f"Add `model_config = ConfigDict(frozen=True)` or add `frozen=True` to the existing ConfigDict."
)
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
"""Constructing then mutating any field must raise ValidationError."""
# Build a minimal instance -- use model_construct to skip validation for
# required fields, then pick the first field to try mutating.
fields = list(model_cls.model_fields.keys())
if not fields:
pytest.skip(f"{model_cls.__name__} has no fields")
instance = model_cls.model_construct()
first_field = fields[0]
with pytest.raises(ValidationError):
setattr(instance, first_field, "MUTATED")
def test_extensions_nested_dict_mutation_is_not_blocked_by_pydantic():
"""Regression guard: Pydantic `frozen=True` does NOT deep-freeze container fields.
This test documents the trap — callers MUST compose a new dict and persist
it + reload AppConfig instead of reaching into `extensions.skills[x]`.
If you need the dict to be truly immutable, wrap with Mapping/frozendict.
"""
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
ext = ExtensionsConfig(mcp_servers={}, skills={"a": SkillStateConfig(enabled=True)})
# This is the pre-refactor anti-pattern: Pydantic lets it through because
# the outer model is frozen but the inner dict is a plain builtin. No error.
ext.skills["a"] = SkillStateConfig(enabled=False)
ext.skills["b"] = SkillStateConfig(enabled=True)
# The test asserts the leak exists so a future "add deep-freeze" change
# flips this expectation and forces call-site review.
assert ext.skills["a"].enabled is False
assert "b" in ext.skills
+24 -83
View File
@@ -9,7 +9,9 @@ import pytest
import yaml
from fastapi.testclient import TestClient
from deerflow.config.agents_api_config import AgentsApiConfig, get_agents_api_config, set_agents_api_config
from deerflow.config.memory_config import MemoryConfig
_TEST_MEMORY_CONFIG = MemoryConfig()
# ---------------------------------------------------------------------------
# Helpers
@@ -329,38 +331,26 @@ class TestMemoryFilePath:
def test_global_memory_path(self, tmp_path):
"""None agent_name should return global memory file."""
from deerflow.agents.memory.storage import FileMemoryStorage
from deerflow.config.memory_config import MemoryConfig
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
):
storage = FileMemoryStorage()
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
def test_agent_memory_path(self, tmp_path):
"""Providing agent_name should return per-agent memory file."""
from deerflow.agents.memory.storage import FileMemoryStorage
from deerflow.config.memory_config import MemoryConfig
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
):
storage = FileMemoryStorage()
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path("code-reviewer")
assert path == tmp_path / "agents" / "code-reviewer" / "memory.json"
def test_different_paths_for_different_agents(self, tmp_path):
from deerflow.agents.memory.storage import FileMemoryStorage
from deerflow.config.memory_config import MemoryConfig
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
):
storage = FileMemoryStorage()
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path_global = storage._get_memory_file_path(None)
path_a = storage._get_memory_file_path("agent-a")
path_b = storage._get_memory_file_path("agent-b")
@@ -380,47 +370,32 @@ def _make_test_app(tmp_path: Path):
from fastapi import FastAPI
from app.gateway.routers.agents import router
from deerflow.config.agents_api_config import AgentsApiConfig
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
app = FastAPI()
app.include_router(router)
# The agents router gates every route through ``Depends(get_config)`` and
# only allows access when ``agents_api.enabled`` is true. Wire a permissive
# AppConfig onto ``app.state.config`` so the routes are reachable in tests.
app.state.config = AppConfig(
sandbox=SandboxConfig(use="test"),
agents_api=AgentsApiConfig(enabled=True),
)
return app
@pytest.fixture()
def agent_client(tmp_path):
"""TestClient with agents router, using tmp_path as base_dir."""
import app.gateway.routers.agents as agents_router
paths_instance = _make_paths(tmp_path)
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
set_agents_api_config(AgentsApiConfig(enabled=True))
try:
app = _make_test_app(tmp_path)
with TestClient(app) as client:
client._tmp_path = tmp_path # type: ignore[attr-defined]
yield client
finally:
set_agents_api_config(previous_config)
@pytest.fixture()
def disabled_agent_client(tmp_path):
"""TestClient with agents router while the management API is disabled."""
import app.gateway.routers.agents as agents_router
paths_instance = _make_paths(tmp_path)
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
set_agents_api_config(AgentsApiConfig(enabled=False))
try:
app = _make_test_app(tmp_path)
with TestClient(app) as client:
yield client
finally:
set_agents_api_config(previous_config)
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch("app.gateway.routers.agents.get_paths", return_value=paths_instance):
app = _make_test_app(tmp_path)
with TestClient(app) as client:
client._tmp_path = tmp_path # type: ignore[attr-defined]
yield client
class TestAgentsAPI:
@@ -586,37 +561,3 @@ class TestUserProfileAPI:
response = agent_client.put("/api/user-profile", json={"content": ""})
assert response.status_code == 200
assert response.json()["content"] is None
class TestAgentsApiDisabled:
def test_agents_list_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents")
assert response.status_code == 403
assert "agents_api.enabled=true" in response.json()["detail"]
def test_agent_get_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents/example-agent")
assert response.status_code == 403
def test_agent_name_check_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents/check", params={"name": "example-agent"})
assert response.status_code == 403
def test_agent_create_returns_403(self, disabled_agent_client):
response = disabled_agent_client.post("/api/agents", json={"name": "example-agent", "soul": "blocked"})
assert response.status_code == 403
def test_agent_update_returns_403(self, disabled_agent_client):
response = disabled_agent_client.put("/api/agents/example-agent", json={"description": "blocked"})
assert response.status_code == 403
def test_agent_delete_returns_403(self, disabled_agent_client):
response = disabled_agent_client.delete("/api/agents/example-agent")
assert response.status_code == 403
def test_user_profile_routes_return_403(self, disabled_agent_client):
get_response = disabled_agent_client.get("/api/user-profile")
put_response = disabled_agent_client.put("/api/user-profile", json={"content": "blocked"})
assert get_response.status_code == 403
assert put_response.status_code == 403
+62
View File
@@ -0,0 +1,62 @@
"""Tests for DeerFlowContext and resolve_context()."""
from dataclasses import FrozenInstanceError
from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.sandbox_config import SandboxConfig
def _make_config(**overrides) -> AppConfig:
defaults = {"sandbox": SandboxConfig(use="test")}
defaults.update(overrides)
return AppConfig(**defaults)
class TestDeerFlowContext:
def test_frozen(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
with pytest.raises(FrozenInstanceError):
ctx.app_config = _make_config()
def test_fields(self):
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1", agent_name="test-agent")
assert ctx.thread_id == "t1"
assert ctx.agent_name == "test-agent"
assert ctx.app_config is config
def test_agent_name_default(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
assert ctx.agent_name is None
def test_thread_id_required(self):
with pytest.raises(TypeError):
DeerFlowContext(app_config=_make_config()) # type: ignore[call-arg]
class TestResolveContext:
def test_returns_typed_context_directly(self):
"""Gateway/Client path: runtime.context is DeerFlowContext → return as-is."""
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1")
runtime = MagicMock()
runtime.context = ctx
assert resolve_context(runtime) is ctx
def test_raises_on_none_context(self):
"""Without a typed DeerFlowContext, resolve_context refuses to guess."""
runtime = MagicMock()
runtime.context = None
with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
resolve_context(runtime)
def test_raises_on_dict_context(self):
"""Legacy dict shape is no longer supported — we raise instead of lazily loading AppConfig."""
runtime = MagicMock()
runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
resolve_context(runtime)
+90 -74
View File
@@ -5,20 +5,36 @@ from unittest.mock import MagicMock, patch
import pytest
# --- Phase 2 test helper: injected runtime for community tools ---
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
def _runtime_with_config(config):
"""Build a runtime carrying a custom (possibly mocked) app_config.
``DeerFlowContext`` is a frozen dataclass typed as ``AppConfig`` but
dataclasses don't enforce the type at runtime — handing a Mock through
lets tests exercise the tool's ``get_tool_config`` lookup without going
through a process-global config.
"""
ctx = _P2Ctx.__new__(_P2Ctx)
object.__setattr__(ctx, "app_config", config)
object.__setattr__(ctx, "thread_id", "test-thread")
object.__setattr__(ctx, "agent_name", None)
return _P2NS(context=ctx)
# -------------------------------------------------------------------
@pytest.fixture
def mock_app_config():
"""Mock the app config to return tool configurations."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 5,
"search_type": "auto",
"contents_max_characters": 1000,
"api_key": "test-api-key",
}
mock_config.return_value.get_tool_config.return_value = tool_config
yield mock_config
"""Fixture retained as a pass-through: tests inject config via runtime directly."""
yield
@pytest.fixture
@@ -49,7 +65,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "test query"})
result = web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
parsed = json.loads(result)
assert len(parsed) == 2
@@ -67,30 +83,30 @@ class TestWebSearchTool:
def test_search_with_custom_config(self, mock_exa_client):
"""Test search respects custom configuration values."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 10,
"search_type": "neural",
"contents_max_characters": 2000,
"api_key": "test-key",
}
mock_config.return_value.get_tool_config.return_value = tool_config
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 10,
"search_type": "neural",
"contents_max_characters": 2000,
"api_key": "test-key",
}
fake_config = MagicMock()
fake_config.get_tool_config.return_value = tool_config
mock_response = MagicMock()
mock_response.results = []
mock_exa_client.search.return_value = mock_response
mock_response = MagicMock()
mock_response.results = []
mock_exa_client.search.return_value = mock_response
from deerflow.community.exa.tools import web_search_tool
from deerflow.community.exa.tools import web_search_tool
web_search_tool.invoke({"query": "neural search"})
web_search_tool.func(query="neural search", runtime=_runtime_with_config(fake_config))
mock_exa_client.search.assert_called_once_with(
"neural search",
type="neural",
num_results=10,
contents={"highlights": {"max_characters": 2000}},
)
mock_exa_client.search.assert_called_once_with(
"neural search",
type="neural",
num_results=10,
contents={"highlights": {"max_characters": 2000}},
)
def test_search_with_no_highlights(self, mock_app_config, mock_exa_client):
"""Test search handles results with no highlights."""
@@ -105,7 +121,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
result = web_search_tool.func(query="test", runtime=_P2_RUNTIME)
parsed = json.loads(result)
assert parsed[0]["snippet"] == ""
@@ -118,7 +134,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "nothing"})
result = web_search_tool.func(query="nothing", runtime=_P2_RUNTIME)
parsed = json.loads(result)
assert parsed == []
@@ -129,7 +145,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "error"})
result = web_search_tool.func(query="error", runtime=_P2_RUNTIME)
assert result == "Error: API rate limit exceeded"
@@ -147,7 +163,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result == "# Fetched Page\n\nThis is the page content."
mock_exa_client.get_contents.assert_called_once_with(
@@ -167,7 +183,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result.startswith("# Untitled\n\n")
@@ -179,7 +195,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com/404"})
result = web_fetch_tool.func(url="https://example.com/404", runtime=_P2_RUNTIME)
assert result == "Error: No results found"
@@ -189,16 +205,44 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result == "Error: Connection timeout"
def test_fetch_reads_web_fetch_config(self, mock_exa_client):
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"}
mock_config.return_value.get_tool_config.return_value = tool_config
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"}
fake_config = MagicMock()
fake_config.get_tool_config.return_value = tool_config
mock_result = MagicMock()
mock_result.title = "Page"
mock_result.text = "Content."
mock_response = MagicMock()
mock_response.results = [mock_result]
mock_exa_client.get_contents.return_value = mock_response
from deerflow.community.exa.tools import web_fetch_tool
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
fake_config.get_tool_config.assert_any_call("web_fetch")
def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
def get_tool_config(name):
if name == "web_fetch":
return fetch_config
return None
fake_config = MagicMock()
fake_config.get_tool_config.side_effect = get_tool_config
mock_result = MagicMock()
mock_result.title = "Page"
@@ -209,37 +253,9 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
web_fetch_tool.invoke({"url": "https://example.com"})
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
mock_config.return_value.get_tool_config.assert_any_call("web_fetch")
def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
def get_tool_config(name):
if name == "web_fetch":
return fetch_config
return None
mock_config.return_value.get_tool_config.side_effect = get_tool_config
mock_result = MagicMock()
mock_result.title = "Page"
mock_result.text = "Content."
mock_response = MagicMock()
mock_response.results = [mock_result]
mock_exa_client.get_contents.return_value = mock_response
from deerflow.community.exa.tools import web_fetch_tool
web_fetch_tool.invoke({"url": "https://example.com"})
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
def test_fetch_truncates_long_content(self, mock_app_config, mock_exa_client):
"""Test fetch truncates content to 4096 characters."""
@@ -253,7 +269,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
# "# Long Page\n\n" is 14 chars, content truncated to 4096
content_after_header = result.split("\n\n", 1)[1]
+27 -10
View File
@@ -3,14 +3,31 @@
import json
from unittest.mock import MagicMock, patch
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
def _runtime_with_config(config):
ctx = _P2Ctx.__new__(_P2Ctx)
object.__setattr__(ctx, "app_config", config)
object.__setattr__(ctx, "thread_id", "test-thread")
object.__setattr__(ctx, "agent_name", None)
return _P2NS(context=ctx)
class TestWebSearchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch("deerflow.community.firecrawl.tools.get_app_config")
def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
def test_search_uses_web_search_config(self, mock_firecrawl_cls):
search_config = MagicMock()
search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
mock_get_app_config.return_value.get_tool_config.return_value = search_config
fake_config = MagicMock()
fake_config.get_tool_config.return_value = search_config
mock_result = MagicMock()
mock_result.web = [
@@ -20,7 +37,7 @@ class TestWebSearchTool:
from deerflow.community.firecrawl.tools import web_search_tool
result = web_search_tool.invoke({"query": "test query"})
result = web_search_tool.func(query="test query", runtime=_runtime_with_config(fake_config))
assert json.loads(result) == [
{
@@ -29,15 +46,14 @@ class TestWebSearchTool:
"snippet": "Snippet",
}
]
mock_get_app_config.return_value.get_tool_config.assert_called_with("web_search")
fake_config.get_tool_config.assert_called_with("web_search")
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-search-key")
mock_firecrawl_cls.return_value.search.assert_called_once_with("test query", limit=7)
class TestWebFetchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch("deerflow.community.firecrawl.tools.get_app_config")
def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
def test_fetch_uses_web_fetch_config(self, mock_firecrawl_cls):
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}
@@ -46,7 +62,8 @@ class TestWebFetchTool:
return fetch_config
return None
mock_get_app_config.return_value.get_tool_config.side_effect = get_tool_config
fake_config = MagicMock()
fake_config.get_tool_config.side_effect = get_tool_config
mock_scrape_result = MagicMock()
mock_scrape_result.markdown = "Fetched markdown"
@@ -55,10 +72,10 @@ class TestWebFetchTool:
from deerflow.community.firecrawl.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
assert result == "# Fetched Page\n\nFetched markdown"
mock_get_app_config.return_value.get_tool_config.assert_any_call("web_fetch")
fake_config.get_tool_config.assert_any_call("web_fetch")
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-fetch-key")
mock_firecrawl_cls.return_value.scrape.assert_called_once_with(
"https://example.com",
+55
View File
@@ -0,0 +1,55 @@
"""Tests for the FastAPI get_config dependency.
Phase 2 step 1: introduces the new explicit-config primitive that
resolves ``AppConfig`` from ``request.app.state.config``. After migration,
it is the sole mechanism.
"""
from __future__ import annotations
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def test_get_config_returns_app_state_config():
"""get_config returns the AppConfig stored on app.state.config."""
app = FastAPI()
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
app.state.config = cfg
@app.get("/probe")
def probe(c: AppConfig = Depends(get_config)):
# Identity check: FastAPI must hand us the exact object from app.state
return {"same_identity": c is cfg, "log_level": c.log_level}
client = TestClient(app)
response = client.get("/probe")
assert response.status_code == 200
body = response.json()
assert body["same_identity"] is True
assert body["log_level"] == "info"
def test_get_config_reads_updated_app_state():
"""When app.state.config is swapped (config reload), get_config sees the new value."""
app = FastAPI()
original = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
replacement = original.model_copy(update={"log_level": "debug"})
app.state.config = original
@app.get("/log-level")
def log_level(c: AppConfig = Depends(get_config)):
return {"level": c.log_level}
client = TestClient(app)
assert client.get("/log-level").json() == {"level": "info"}
# Simulate config reload (PUT /mcp/config, etc.)
app.state.config = replacement
assert client.get("/log-level").json() == {"level": "debug"}
+10 -8
View File
@@ -333,12 +333,14 @@ class TestGuardrailsConfig:
assert config.provider.use == "deerflow.guardrails.builtin:AllowlistProvider"
assert config.provider.config == {"denied_tools": ["bash"]}
def test_singleton_load_and_get(self):
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict, reset_guardrails_config
def test_guardrails_config_via_app_config(self):
from deerflow.config.app_config import AppConfig
from deerflow.config.guardrails_config import GuardrailProviderConfig, GuardrailsConfig
from deerflow.config.sandbox_config import SandboxConfig
try:
load_guardrails_config_from_dict({"enabled": True, "provider": {"use": "test:Foo"}})
config = get_guardrails_config()
assert config.enabled is True
finally:
reset_guardrails_config()
cfg = AppConfig(
sandbox=SandboxConfig(use="test"),
guardrails=GuardrailsConfig(enabled=True, provider=GuardrailProviderConfig(use="test:Foo")),
)
config = cfg.guardrails
assert config.enabled is True
+16 -8
View File
@@ -6,6 +6,16 @@ from unittest.mock import MagicMock, patch
from deerflow.community.infoquest import tools
from deerflow.community.infoquest.infoquest_client import InfoQuestClient
# --- Phase 2 test helper: injected runtime for community tools ---
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
# -------------------------------------------------------------------
class TestInfoQuestClient:
def test_infoquest_client_initialization(self):
@@ -130,7 +140,7 @@ class TestInfoQuestClient:
mock_client.web_search.return_value = json.dumps([])
mock_get_client.return_value = mock_client
result = tools.web_search_tool.run("test query")
result = tools.web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
assert result == json.dumps([])
mock_get_client.assert_called_once()
@@ -143,14 +153,13 @@ class TestInfoQuestClient:
mock_client.fetch.return_value = "<html><body>Test content</body></html>"
mock_get_client.return_value = mock_client
result = tools.web_fetch_tool.run("https://example.com")
result = tools.web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result == "# Untitled\n\nTest content"
mock_get_client.assert_called_once()
mock_client.fetch.assert_called_once_with("https://example.com")
@patch("deerflow.community.infoquest.tools.get_app_config")
def test_get_infoquest_client(self, mock_get_app_config):
def test_get_infoquest_client(self):
"""Test _get_infoquest_client function with config."""
mock_config = MagicMock()
# Add image_search config to the side_effect
@@ -159,9 +168,8 @@ class TestInfoQuestClient:
MagicMock(model_extra={"fetch_time": 10, "timeout": 30, "navigation_timeout": 60}), # web_fetch config
MagicMock(model_extra={"image_search_time_range": 7, "image_size": "l"}), # image_search config
]
mock_get_app_config.return_value = mock_config
client = tools._get_infoquest_client()
client = tools._get_infoquest_client(mock_config)
assert client.search_time_range == 24
assert client.fetch_time == 10
@@ -321,7 +329,7 @@ class TestImageSearch:
mock_client.image_search.return_value = json.dumps([{"image_url": "https://example.com/image1.jpg"}])
mock_get_client.return_value = mock_client
result = tools.image_search_tool.run({"query": "test query"})
result = tools.image_search_tool.func(query="test query", runtime=_P2_RUNTIME)
# Check if result is a valid JSON string
result_data = json.loads(result)
@@ -340,7 +348,7 @@ class TestImageSearch:
mock_get_client.return_value = mock_client
# Pass all parameters as a dictionary (extra parameters will be ignored)
tools.image_search_tool.run({"query": "sunset", "time_range": 30, "site": "unsplash.com", "image_size": "l"})
tools.image_search_tool.func(query="sunset", runtime=_P2_RUNTIME)
mock_get_client.assert_called_once()
# image_search_tool only passes query to client.image_search
+9 -21
View File
@@ -6,7 +6,7 @@ from types import SimpleNamespace
import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.tools.builtins.invoke_acp_agent_tool import (
_build_acp_mcp_servers,
_build_mcp_servers,
@@ -18,7 +18,6 @@ from deerflow.tools.tools import get_available_tools
def test_build_mcp_servers_filters_disabled_and_maps_transports():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"]),
@@ -40,11 +39,9 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
}
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_acp_mcp_servers_formats_list_payload():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
@@ -77,7 +74,6 @@ def test_build_acp_mcp_servers_formats_list_payload():
]
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_permission_response_prefers_allow_once():
@@ -669,31 +665,23 @@ async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch,
def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(monkeypatch):
from deerflow.config.acp_config import load_acp_config_from_dict
load_acp_config_from_dict(
{
"codex": {
"command": "codex-acp",
"args": [],
"description": "Codex CLI",
}
}
)
fake_config = SimpleNamespace(
tools=[],
models=[],
tool_search=SimpleNamespace(enabled=False),
acp_agents={
"codex": ACPAgentConfig(
command="codex-acp",
args=[],
description="Codex CLI",
)
},
get_model_config=lambda name: None,
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: fake_config)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
)
tools = get_available_tools(include_mcp=True, subagent_enabled=False)
tools = get_available_tools(include_mcp=True, subagent_enabled=False, app_config=fake_config)
assert "invoke_acp_agent" in [tool.name for tool in tools]
load_acp_config_from_dict({})
+12 -4
View File
@@ -10,6 +10,16 @@ import deerflow.community.jina_ai.jina_client as jina_client_module
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.community.jina_ai.tools import web_fetch_tool
# --- Phase 2 test helper: injected runtime for community tools ---
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
# -------------------------------------------------------------------
@pytest.fixture
def jina_client():
@@ -176,9 +186,8 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
assert result.startswith("Error:")
assert "429" in result
@@ -192,9 +201,8 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
assert "Hello world" in result
assert not result.startswith("Error:")
@@ -8,7 +8,6 @@ import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
@@ -33,7 +32,7 @@ def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
)
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
def test_resolve_model_name_falls_back_to_default(caplog):
app_config = _make_app_config(
[
_make_model("default-model", supports_thinking=False),
@@ -41,16 +40,14 @@ def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
with caplog.at_level("WARNING"):
resolved = lead_agent_module._resolve_model_name("missing-model")
resolved = lead_agent_module._resolve_model_name(app_config, "missing-model")
assert resolved == "default-model"
assert "fallback to default model 'default-model'" in caplog.text
def test_resolve_model_name_uses_default_when_none(monkeypatch):
def test_resolve_model_name_uses_default_when_none():
app_config = _make_app_config(
[
_make_model("default-model", supports_thinking=False),
@@ -58,23 +55,19 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
resolved = lead_agent_module._resolve_model_name(None)
resolved = lead_agent_module._resolve_model_name(app_config, None)
assert resolved == "default-model"
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
def test_resolve_model_name_raises_when_no_models_configured():
app_config = _make_app_config([])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
with pytest.raises(
ValueError,
match="No chat models are configured",
):
lead_agent_module._resolve_model_name("missing-model")
lead_agent_module._resolve_model_name(app_config, "missing-model")
def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch):
@@ -82,13 +75,12 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
import deerflow.tools as tools_module
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda app_config, config, model_name, agent_name=None: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
@@ -105,7 +97,8 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
"is_plan_mode": False,
"subagent_enabled": False,
}
}
},
app_config=app_config,
)
assert captured["name"] == "safe-model"
@@ -113,74 +106,6 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
assert result["model"] is not None
def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
app_config = _make_app_config(
[
_make_model("default-model", supports_thinking=False),
_make_model("context-model", supports_thinking=True),
]
)
import deerflow.tools as tools_module
get_available_tools = MagicMock(return_value=[])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
result = lead_agent_module.make_lead_agent(
{
"context": {
"model_name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
"is_plan_mode": True,
"subagent_enabled": True,
"max_concurrent_subagents": 7,
}
}
)
assert captured == {
"name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
}
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
assert result["model"] is not None
def test_make_lead_agent_rejects_invalid_bootstrap_agent_name(monkeypatch):
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
with pytest.raises(ValueError, match="Invalid agent name"):
lead_agent_module.make_lead_agent(
{
"configurable": {
"model_name": "safe-model",
"thinking_enabled": False,
"is_plan_mode": False,
"subagent_enabled": False,
"is_bootstrap": True,
"agent_name": "../../../tmp/evil",
}
}
)
def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
app_config = _make_app_config(
[
@@ -197,11 +122,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda _ac: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
middlewares = lead_agent_module._build_middlewares(app_config, {"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
@@ -209,12 +133,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
app_config = _make_app_config([_make_model("default", supports_thinking=False)])
patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
from unittest.mock import MagicMock
from unittest.mock import MagicMock
@@ -222,16 +144,16 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
middleware = lead_agent_module._create_summarization_middleware(patched)
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
+21 -32
View File
@@ -4,34 +4,23 @@ from types import SimpleNamespace
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill
def _set_skills_cache_state(*, skills=None, active=False, version=0):
prompt_module._get_cached_skills_prompt_section.cache_clear()
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = skills
prompt_module._enabled_skills_refresh_active = active
prompt_module._enabled_skills_refresh_version = version
prompt_module._enabled_skills_refresh_event.clear()
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
def test_build_custom_mounts_section_returns_empty_when_no_mounts():
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
assert prompt_module._build_custom_mounts_section() == ""
assert prompt_module._build_custom_mounts_section(config) == ""
def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
def test_build_custom_mounts_section_lists_configured_mounts():
mounts = [
SimpleNamespace(container_path="/home/user/shared", read_only=False),
SimpleNamespace(container_path="/mnt/reference", read_only=True),
]
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
section = prompt_module._build_custom_mounts_section()
section = prompt_module._build_custom_mounts_section(config)
assert "**Custom Mounted Directories:**" in section
assert "`/home/user/shared`" in section
@@ -45,15 +34,15 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
config = SimpleNamespace(
sandbox=SimpleNamespace(mounts=mounts),
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template()
prompt = prompt_module.apply_prompt_template(config)
assert "`/home/user/shared`" in prompt
assert "Custom Mounted Directories" in prompt
@@ -63,15 +52,15 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
config = SimpleNamespace(
sandbox=SimpleNamespace(mounts=[]),
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template()
prompt = prompt_module.apply_prompt_template(config)
assert "Treat `/mnt/user-data/workspace` as your default current working directory" in prompt
assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt
@@ -92,8 +81,8 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
)
state = {"skills": [make_skill("first-skill")]}
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
_set_skills_cache_state()
monkeypatch.setattr(prompt_module, "load_skills", lambda *a, **kwargs: list(state["skills"]))
prompt_module._reset_skills_system_prompt_cache_state()
try:
prompt_module.warm_enabled_skills_cache()
@@ -128,7 +117,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
enabled=True,
)
def fake_load_skills(enabled_only=True):
def fake_load_skills(*a, **kwargs):
nonlocal active_loads, max_active_loads, call_count
with lock:
active_loads += 1
@@ -165,7 +154,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
event = threading.Event()
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda: event)
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda *a, **k: event)
with caplog.at_level("WARNING"):
warmed = prompt_module.warm_enabled_skills_cache(timeout_seconds=0.01)
+38 -42
View File
@@ -19,27 +19,40 @@ def _make_skill(name: str) -> Skill:
)
_DEFAULT_SKILLS_CONFIG = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
def _evolution_enabled_config() -> SimpleNamespace:
return SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)
def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills={"non_existent_skill"})
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"non_existent_skill"})
assert result == ""
def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills=set())
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=set())
assert result == ""
def test_get_skills_prompt_section_returns_skills(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills={"skill1"})
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"skill1"})
assert "skill1" in result
assert "skill2" not in result
assert "[built-in]" in result
@@ -47,56 +60,41 @@ def test_get_skills_prompt_section_returns_skills(monkeypatch):
def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills=None)
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=None)
assert "skill1" in result
assert "skill2" in result
def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
skills = [_make_skill("skill1")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr(
"deerflow.config.get_app_config",
lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills=None)
result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
assert "Skill Self-Evolution" in result
def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [])
monkeypatch.setattr(
"deerflow.config.get_app_config",
lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: [])
result = get_skills_prompt_section(available_skills=None)
result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
assert "Skill Self-Evolution" in result
def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeypatch):
skills = [_make_skill("skill1")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
config = _evolution_enabled_config()
enabled_result = get_skills_prompt_section(available_skills=None)
enabled_result = get_skills_prompt_section(config, available_skills=None)
assert "Skill Self-Evolution" in enabled_result
config.skill_evolution.enabled = False
disabled_result = get_skills_prompt_section(available_skills=None)
disabled_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
disabled_result = get_skills_prompt_section(disabled_config, available_skills=None)
assert "Skill Self-Evolution" not in disabled_result
@@ -106,8 +104,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from deerflow.agents.lead_agent import agent as lead_agent_module
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda app_config=None, x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
@@ -118,11 +115,10 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = MockModelConfig()
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
captured_skills = []
def mock_apply_prompt_template(**kwargs):
def mock_apply_prompt_template(_app_config, *args, **kwargs):
captured_skills.append(kwargs.get("available_skills"))
return "mock_prompt"
@@ -130,15 +126,15 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
# Case 1: Empty skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
assert captured_skills[-1] == set()
# Case 2: None skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
assert captured_skills[-1] is None
# Case 3: Some skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
assert captured_skills[-1] == {"skill1"}
@@ -22,26 +22,26 @@ def _make_config(*, allow_host_bash: bool, sandbox_use: str = "deerflow.sandbox.
def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=False))
app_config = _make_config(allow_host_bash=False)
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=app_config)]
assert "bash" not in names
assert "ls" in names
def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=True))
app_config = _make_config(allow_host_bash=True)
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=app_config)]
assert "bash" in names
assert "ls" in names
@@ -52,13 +52,12 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
allow_host_bash=False,
extra_tools=[SimpleNamespace(name="shell", group="bash", use="deerflow.sandbox.tools:bash_tool")],
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=config)]
assert "bash" not in names
assert "shell" not in names
@@ -70,13 +69,12 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
allow_host_bash=False,
sandbox_use="deerflow.community.aio_sandbox:AioSandboxProvider",
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=config)]
assert "bash" in names
assert "ls" in names
@@ -1,6 +1,5 @@
import errno
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@@ -314,8 +313,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
@@ -336,8 +334,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@@ -360,8 +357,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@@ -476,7 +472,6 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
@@ -10,12 +10,22 @@ from deerflow.agents.middlewares.loop_detection_middleware import (
LoopDetectionMiddleware,
_hash_tool_calls,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime(thread_id="test-thread"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
runtime.context = _make_context(thread_id)
return runtime
@@ -293,10 +303,10 @@ class TestLoopDetection:
assert isinstance(mw._lock, type(mw._lock))
def test_fallback_thread_id_when_missing(self):
"""When runtime context has no thread_id, should use 'default'."""
"""When runtime context has empty thread_id, should use 'default'."""
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime = MagicMock()
runtime.context = {}
runtime.context = _make_context("")
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
+22 -17
View File
@@ -3,23 +3,31 @@ import time
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
def _make_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
with patch.object(queue, "_reset_timer"):
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
@@ -29,7 +37,7 @@ def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
def test_process_queue_forwards_correction_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
queue._queue = [
ConversationContext(
thread_id="thread-1",
@@ -55,12 +63,9 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
with patch.object(queue, "_reset_timer"):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
@@ -70,7 +75,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
queue._queue = [
ConversationContext(
thread_id="thread-1",
@@ -1,8 +1,30 @@
"""Tests for user_id propagation through memory queue."""
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
"""Tests for user_id propagation through memory queue."""
from unittest.mock import MagicMock, patch
import pytest
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
@pytest.fixture(autouse=True)
def _enable_memory(monkeypatch):
"""Ensure MemoryUpdateQueue.add() doesn't early-return on disabled memory."""
config = MagicMock(spec=AppConfig)
config.memory = MemoryConfig(enabled=True)
def test_conversation_context_has_user_id():
@@ -16,7 +38,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
q = MemoryUpdateQueue(_TEST_APP_CONFIG)
with patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
@@ -25,7 +47,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
q = MemoryUpdateQueue(_TEST_APP_CONFIG)
with patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
+24 -24
View File
@@ -4,6 +4,18 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.routers import memory
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
_TEST_APP_CONFIG = AppConfig(sandbox=SandboxConfig(use="test"))
def _make_app() -> FastAPI:
"""Build a memory-router app pre-populated with a minimal AppConfig."""
app = FastAPI()
app.state.config = _TEST_APP_CONFIG
app.include_router(memory.router)
return app
def _sample_memory(facts: list[dict] | None = None) -> dict:
@@ -25,8 +37,7 @@ def _sample_memory(facts: list[dict] | None = None) -> dict:
def test_export_memory_route_returns_current_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
exported_memory = _sample_memory(
facts=[
{
@@ -49,8 +60,7 @@ def test_export_memory_route_returns_current_memory() -> None:
def test_import_memory_route_returns_imported_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
imported_memory = _sample_memory(
facts=[
{
@@ -73,8 +83,7 @@ def test_import_memory_route_returns_imported_memory() -> None:
def test_export_memory_route_preserves_source_error() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
exported_memory = _sample_memory(
facts=[
{
@@ -98,8 +107,7 @@ def test_export_memory_route_preserves_source_error() -> None:
def test_import_memory_route_preserves_source_error() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
imported_memory = _sample_memory(
facts=[
{
@@ -123,8 +131,7 @@ def test_import_memory_route_preserves_source_error() -> None:
def test_clear_memory_route_returns_cleared_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.clear_memory_data", return_value=_sample_memory()):
with TestClient(app) as client:
@@ -135,8 +142,7 @@ def test_clear_memory_route_returns_cleared_memory() -> None:
def test_create_memory_fact_route_returns_updated_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@@ -166,8 +172,7 @@ def test_create_memory_fact_route_returns_updated_memory() -> None:
def test_delete_memory_fact_route_returns_updated_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@@ -190,8 +195,7 @@ def test_delete_memory_fact_route_returns_updated_memory() -> None:
def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.delete_memory_fact", side_effect=KeyError("fact_missing")):
with TestClient(app) as client:
@@ -202,8 +206,7 @@ def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
def test_update_memory_fact_route_returns_updated_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@@ -233,8 +236,7 @@ def test_update_memory_fact_route_returns_updated_memory() -> None:
def test_update_memory_fact_route_preserves_omitted_fields() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@@ -269,8 +271,7 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=KeyError("fact_missing")):
with TestClient(app) as client:
@@ -288,8 +289,7 @@ def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
def test_update_memory_fact_route_returns_specific_error_for_invalid_confidence() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=ValueError("confidence")):
with TestClient(app) as client:
+59 -51
View File
@@ -1,3 +1,15 @@
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
"""Tests for memory storage providers."""
import threading
@@ -11,7 +23,13 @@ from deerflow.agents.memory.storage import (
create_empty_memory,
get_memory_storage,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _app_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
class TestCreateEmptyMemory:
@@ -53,10 +71,9 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
def test_get_memory_file_path_agent(self, tmp_path):
"""Should return per-agent memory file path when agent_name is provided."""
@@ -67,14 +84,14 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path("test-agent")
assert path == tmp_path / "agents" / "test-agent" / "memory.json"
@pytest.mark.parametrize("invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"])
def test_validate_agent_name_invalid(self, invalid_name):
"""Should raise ValueError for invalid agent names that don't match the pattern."""
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
with pytest.raises(ValueError, match="Invalid agent name|Agent name must be a non-empty string"):
storage._validate_agent_name(invalid_name)
@@ -87,11 +104,10 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
memory = storage.load()
assert isinstance(memory, dict)
assert memory["version"] == "1.0"
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = storage.load()
assert isinstance(memory, dict)
assert memory["version"] == "1.0"
def test_save_writes_to_file(self, tmp_path):
"""Should save memory data to file."""
@@ -103,12 +119,11 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
result = storage.save(test_memory)
assert result is True
assert memory_file.exists()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
result = storage.save(test_memory)
assert result is True
assert memory_file.exists()
def test_save_does_not_mutate_caller_dict(self, tmp_path):
"""save() must not mutate the caller's dict (lastUpdated side-effect)."""
@@ -209,18 +224,17 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
# First load
memory1 = storage.load()
assert memory1["facts"][0]["content"] == "initial fact"
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
# First load
memory1 = storage.load()
assert memory1["facts"][0]["content"] == "initial fact"
# Update file directly
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
# Update file directly
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
# Reload should get updated data
memory2 = storage.reload()
assert memory2["facts"][0]["content"] == "updated fact"
# Reload should get updated data
memory2 = storage.reload()
assert memory2["facts"][0]["content"] == "updated fact"
class TestGetMemoryStorage:
@@ -237,22 +251,19 @@ class TestGetMemoryStorage:
def test_returns_file_memory_storage_by_default(self):
"""Should return FileMemoryStorage by default."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
def test_falls_back_to_file_memory_storage_on_error(self):
"""Should fall back to FileMemoryStorage if configured storage fails to load."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="non.existent.StorageClass")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
def test_returns_singleton_instance(self):
"""Should return the same instance on subsequent calls."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage1 = get_memory_storage()
storage2 = get_memory_storage()
assert storage1 is storage2
storage1 = get_memory_storage(_TEST_MEMORY_CONFIG)
storage2 = get_memory_storage(_TEST_MEMORY_CONFIG)
assert storage1 is storage2
def test_get_memory_storage_thread_safety(self):
"""Should safely initialize the singleton even with concurrent calls."""
@@ -260,16 +271,15 @@ class TestGetMemoryStorage:
def get_storage():
# get_memory_storage is called concurrently from multiple threads while
# get_memory_config is patched once around thread creation. This verifies
# AppConfig.get is patched once around thread creation. This verifies
# that the singleton initialization remains thread-safe.
results.append(get_memory_storage())
results.append(get_memory_storage(_TEST_MEMORY_CONFIG))
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
threads = [threading.Thread(target=get_storage) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
threads = [threading.Thread(target=get_storage) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All results should be the exact same instance
assert len(results) == 10
@@ -278,13 +288,11 @@ class TestGetMemoryStorage:
def test_get_memory_storage_invalid_class_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
# Using a built-in function instead of a class
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="os.path.join")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
def test_get_memory_storage_non_subclass_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
# Using 'dict' as a class that is not a MemoryStorage subclass
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="builtins.dict")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
@@ -1,11 +1,29 @@
"""Tests for per-user memory storage isolation."""
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
"""Tests for per-user memory storage isolation."""
import pytest
from pathlib import Path
from unittest.mock import patch
import pytest
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _mock_app_config() -> AppConfig:
"""Build a minimal AppConfig with default (empty) memory storage_path."""
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(storage_path=""))
@pytest.fixture
@@ -15,7 +33,9 @@ def base_dir(tmp_path: Path) -> Path:
@pytest.fixture
def storage() -> FileMemoryStorage:
return FileMemoryStorage()
return FileMemoryStorage(_TEST_MEMORY_CONFIG)
class TestUserIsolatedStorage:
@@ -43,7 +63,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
s.save(memory, user_id="alice")
expected_path = base_dir / "users" / "alice" / "memory.json"
@@ -54,7 +74,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory_a = create_empty_memory()
memory_a["user"]["workContext"]["summary"] = "A"
s.save(memory_a, user_id="alice")
@@ -67,38 +87,34 @@ class TestUserIsolatedStorage:
assert loaded_a["user"]["workContext"]["summary"] == "A"
def test_no_user_id_uses_legacy_path(self, base_dir: Path):
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.paths import Paths
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
s = FileMemoryStorage()
memory = create_empty_memory()
s.save(memory, user_id=None)
expected_path = base_dir / "memory.json"
assert expected_path.exists()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
s.save(memory, user_id=None)
expected_path = base_dir / "memory.json"
assert expected_path.exists()
def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
"""user_id=None (legacy) and user_id='alice' must use different files and caches."""
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.paths import Paths
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
legacy_mem = create_empty_memory()
legacy_mem["user"]["workContext"]["summary"] = "legacy"
s.save(legacy_mem, user_id=None)
legacy_mem = create_empty_memory()
legacy_mem["user"]["workContext"]["summary"] = "legacy"
s.save(legacy_mem, user_id=None)
user_mem = create_empty_memory()
user_mem["user"]["workContext"]["summary"] = "alice"
s.save(user_mem, user_id="alice")
user_mem = create_empty_memory()
user_mem["user"]["workContext"]["summary"] = "alice"
s.save(user_mem, user_id="alice")
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
def test_user_agent_memory_file_location(self, base_dir: Path):
"""Per-user per-agent memory uses the user_agent_memory_file path."""
@@ -106,7 +122,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
memory["user"]["workContext"]["summary"] = "agent scoped"
s.save(memory, "test-agent", user_id="alice")
@@ -119,7 +135,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
s.save(memory, user_id="alice")
# After save, cache should have tuple key
@@ -131,7 +147,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
memory["user"]["workContext"]["summary"] = "initial"
s.save(memory, user_id="alice")
@@ -6,6 +6,17 @@ the in-memory LangGraph Store backend used when database.backend=memory.
from __future__ import annotations
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
from types import SimpleNamespace
import pytest
+57 -248
View File
@@ -1,22 +1,32 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.prompt import format_conversation_for_update
from deerflow.agents.memory.updater import (
MemoryUpdater,
_extract_text,
_run_async_update_sync,
clear_memory_data,
create_memory_fact,
delete_memory_fact,
import_memory_data,
update_memory_fact,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
return {
"version": "1.0",
@@ -35,15 +45,12 @@ def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, obje
}
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def _memory_config(**overrides: object) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig().model_copy(update=overrides))
def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@@ -70,19 +77,14 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
{"content": "User likes Python", "category": "preference", "confidence": 0.95},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert [fact["content"] for fact in result["facts"]] == ["User likes Python"]
assert all(fact["id"] != "fact_remove" for fact in result["facts"])
def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory()
update_data = {
"newFacts": [
@@ -91,12 +93,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
{"content": "User works on DeerFlow", "category": "context", "confidence": 0.87},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
assert [fact["content"] for fact in result["facts"]] == [
"User prefers dark mode",
@@ -107,7 +104,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=2, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@@ -135,12 +132,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
{"content": "User likes noisy logs", "category": "behavior", "confidence": 0.6},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
assert [fact["content"] for fact in result["facts"]] == [
"User likes Python",
@@ -151,7 +143,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
def test_apply_updates_preserves_source_error() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory()
update_data = {
"newFacts": [
@@ -163,19 +155,14 @@ def test_apply_updates_preserves_source_error() -> None:
}
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
assert result["facts"][0]["category"] == "correction"
def test_apply_updates_ignores_empty_source_error() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory()
update_data = {
"newFacts": [
@@ -187,19 +174,14 @@ def test_apply_updates_ignores_empty_source_error() -> None:
}
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
assert "sourceError" not in result["facts"][0]
def test_clear_memory_data_resets_all_sections() -> None:
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
result = clear_memory_data()
result = clear_memory_data(_TEST_MEMORY_CONFIG)
assert result["version"] == "1.0"
assert result["facts"] == []
@@ -233,7 +215,7 @@ def test_delete_memory_fact_removes_only_matching_fact() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = delete_memory_fact("fact_delete")
result = delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_delete")
assert [fact["id"] for fact in result["facts"]] == ["fact_keep"]
@@ -243,7 +225,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = create_memory_fact(
result = create_memory_fact(_TEST_MEMORY_CONFIG,
content=" User prefers concise code reviews. ",
category="preference",
confidence=0.88,
@@ -258,7 +240,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
def test_create_memory_fact_rejects_empty_content() -> None:
try:
create_memory_fact(content=" ")
create_memory_fact(_TEST_MEMORY_CONFIG, content=" ")
except ValueError as exc:
assert exc.args == ("content",)
else:
@@ -268,7 +250,7 @@ def test_create_memory_fact_rejects_empty_content() -> None:
def test_create_memory_fact_rejects_invalid_confidence() -> None:
for confidence in (-0.1, 1.1, float("nan"), float("inf"), float("-inf")):
try:
create_memory_fact(content="User likes tests", confidence=confidence)
create_memory_fact(_TEST_MEMORY_CONFIG, content="User likes tests", confidence=confidence)
except ValueError as exc:
assert exc.args == ("confidence",)
else:
@@ -278,7 +260,7 @@ def test_create_memory_fact_rejects_invalid_confidence() -> None:
def test_delete_memory_fact_raises_for_unknown_id() -> None:
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
try:
delete_memory_fact("fact_missing")
delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_missing")
except KeyError as exc:
assert exc.args == ("fact_missing",)
else:
@@ -303,7 +285,7 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
mock_storage.load.return_value = imported_memory
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
result = import_memory_data(imported_memory)
result = import_memory_data(_TEST_MEMORY_CONFIG, imported_memory)
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
mock_storage.load.assert_called_once_with(None, user_id=None)
@@ -336,7 +318,7 @@ def test_update_memory_fact_updates_only_matching_fact() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = update_memory_fact(
result = update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_edit",
content="User prefers spaces",
category="workflow",
@@ -369,7 +351,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = update_memory_fact(
result = update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_edit",
content="User prefers spaces",
)
@@ -382,7 +364,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
def test_update_memory_fact_raises_for_unknown_id() -> None:
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
try:
update_memory_fact(
update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_missing",
content="User prefers concise code reviews.",
category="preference",
@@ -414,7 +396,7 @@ def test_update_memory_fact_rejects_invalid_confidence() -> None:
return_value=current_memory,
):
try:
update_memory_fact(
update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_edit",
content="User prefers spaces",
confidence=confidence,
@@ -527,17 +509,15 @@ class TestUpdateMemoryStructuredResponse:
model = MagicMock()
response = MagicMock()
response.content = content
model.ainvoke = AsyncMock(return_value=response)
model.invoke.return_value = response
return model
def test_string_response_parses(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -551,17 +531,15 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg])
assert result is True
model.ainvoke.assert_awaited_once()
def test_list_content_response_parses(self):
"""LLM response as list-of-blocks should be extracted, not repr'd."""
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
list_content = [{"type": "text", "text": valid_json}]
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -576,38 +554,13 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
def test_async_update_memory_uses_ainvoke(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi there"
ai_msg.tool_calls = []
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
assert result is True
model.ainvoke.assert_awaited_once()
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"}
def test_correction_hint_injected_when_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -622,17 +575,16 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=True)
assert result is True
prompt = model.ainvoke.await_args.args[0]
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -647,95 +599,15 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=False)
assert result is True
prompt = model.ainvoke.await_args.args[0]
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" not in prompt
def test_sync_update_memory_wrapper_works_in_running_loop(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is True
model.ainvoke.assert_awaited_once()
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
updater = MemoryUpdater()
with (
patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is False
class TestRunAsyncUpdateSync:
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
class CloseableAwaitable:
def __init__(self):
self.closed = False
def __await__(self):
pytest.fail("awaitable should not have been awaited")
yield
def close(self):
self.closed = True
awaitable = CloseableAwaitable()
with patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
):
async def run_in_loop():
return _run_async_update_sync(awaitable)
result = asyncio.run(run_in_loop())
assert result is False
assert awaitable.closed is True
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
def test_duplicate_fact_different_case_not_stored(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@@ -755,19 +627,14 @@ class TestFactDeduplicationCaseInsensitive:
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
# Should still have only 1 fact (duplicate rejected)
assert len(result["facts"]) == 1
assert result["facts"][0]["content"] == "User prefers Python"
def test_unique_fact_different_case_and_content_stored(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@@ -786,12 +653,7 @@ class TestFactDeduplicationCaseInsensitive:
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert len(result["facts"]) == 2
@@ -804,17 +666,16 @@ class TestReinforcementHint:
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.ainvoke = AsyncMock(return_value=response)
model.invoke.return_value = response
return model
def test_reinforcement_hint_injected_when_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -829,17 +690,16 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.ainvoke.await_args.args[0]
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -854,17 +714,16 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.ainvoke.await_args.args[0]
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@@ -879,56 +738,6 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.ainvoke.await_args.args[0]
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt
class TestFinalizeCacheIsolation:
"""_finalize_update must not mutate the cached memory object."""
def test_deepcopy_prevents_cache_corruption_on_save_failure(self):
"""If save() fails, the in-memory snapshot used by _finalize_update
must remain independent of any object the storage layer may still hold in
its cache. The deepcopy in _finalize_update achieves this — the object
passed to _apply_updates is always a fresh copy, never the cache reference.
"""
updater = MemoryUpdater()
original_memory = _make_memory(facts=[{"id": "fact_orig", "content": "original", "category": "context", "confidence": 0.9, "createdAt": "2024-01-01T00:00:00Z", "source": "t1"}])
import json as _json
new_fact_json = _json.dumps(
{
"user": {},
"history": {},
"newFacts": [{"content": "new fact", "category": "context", "confidence": 0.9}],
"factsToRemove": [],
}
)
mock_response = MagicMock()
mock_response.content = new_fact_json
mock_model = AsyncMock()
mock_model.ainvoke = AsyncMock(return_value=mock_response)
saved_objects: list[dict] = []
save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails
with (
patch.object(updater, "_get_model", return_value=mock_model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=original_memory),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=save_mock)),
):
msg = MagicMock()
msg.type = "human"
msg.content = "hello"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "world"
ai_msg.tool_calls = []
updater.update_memory([msg, ai_msg], thread_id="t1")
# original_memory must not have been mutated — deepcopy isolates the mutation
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
assert original_memory["facts"][0]["content"] == "original"
@@ -1,15 +1,26 @@
"""Tests for user_id propagation in memory updater."""
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
"""Tests for user_id propagation in memory updater."""
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.updater import _save_memory_to_file, clear_memory_data, get_memory_data
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
def test_get_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.load.return_value = {"version": "1.0"}
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
get_memory_data(user_id="alice")
get_memory_data(_TEST_MEMORY_CONFIG, user_id="alice")
mock_storage.load.assert_called_once_with(None, user_id="alice")
@@ -17,7 +28,7 @@ def test_save_memory_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
_save_memory_to_file({"version": "1.0"}, user_id="bob")
_save_memory_to_file(_TEST_MEMORY_CONFIG, {"version": "1.0"}, user_id="bob")
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
@@ -25,6 +36,6 @@ def test_clear_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
clear_memory_data(user_id="charlie")
clear_memory_data(_TEST_MEMORY_CONFIG, user_id="charlie")
# Verify save was called with user_id
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
+1 -12
View File
@@ -1,9 +1,7 @@
"""Tests for per-user data migration."""
import json
from pathlib import Path
import pytest
from pathlib import Path
from deerflow.config.paths import Paths
@@ -25,7 +23,6 @@ class TestMigrateThreadDirs:
(legacy / "file.txt").write_text("hello")
from scripts.migrate_user_isolation import migrate_thread_dirs
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
@@ -38,7 +35,6 @@ class TestMigrateThreadDirs:
legacy.mkdir(parents=True)
from scripts.migrate_user_isolation import migrate_thread_dirs
migrate_thread_dirs(paths, thread_owner_map={})
expected = base_dir / "users" / "default" / "threads" / "t2"
@@ -49,7 +45,6 @@ class TestMigrateThreadDirs:
new_dir.mkdir(parents=True)
from scripts.migrate_user_isolation import migrate_thread_dirs
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
assert new_dir.exists()
@@ -63,7 +58,6 @@ class TestMigrateThreadDirs:
(dest / "new.txt").write_text("new")
from scripts.migrate_user_isolation import migrate_thread_dirs
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
assert (dest / "new.txt").read_text() == "new"
@@ -75,7 +69,6 @@ class TestMigrateThreadDirs:
legacy.mkdir(parents=True)
from scripts.migrate_user_isolation import migrate_thread_dirs
migrate_thread_dirs(paths, thread_owner_map={})
assert not (base_dir / "threads").exists()
@@ -85,7 +78,6 @@ class TestMigrateThreadDirs:
legacy.mkdir(parents=True)
from scripts.migrate_user_isolation import migrate_thread_dirs
report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
assert len(report) == 1
@@ -99,7 +91,6 @@ class TestMigrateMemory:
legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
from scripts.migrate_user_isolation import migrate_memory
migrate_memory(paths, user_id="default")
expected = base_dir / "users" / "default" / "memory.json"
@@ -115,7 +106,6 @@ class TestMigrateMemory:
dest.write_text(json.dumps({"version": "new"}))
from scripts.migrate_user_isolation import migrate_memory
migrate_memory(paths, user_id="default")
assert json.loads(dest.read_text())["version"] == "new"
@@ -123,5 +113,4 @@ class TestMigrateMemory:
def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
from scripts.migrate_user_isolation import migrate_memory
migrate_memory(paths, user_id="default") # should not raise
+48 -49
View File
@@ -72,8 +72,7 @@ class FakeChatModel(BaseChatModel):
def _patch_factory(monkeypatch, app_config: AppConfig, model_class=FakeChatModel):
"""Patch get_app_config, resolve_class, and tracing for isolated unit tests."""
monkeypatch.setattr(factory_module, "get_app_config", lambda: app_config)
"""Patch resolve_class and tracing for isolated unit tests."""
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: model_class)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
@@ -88,7 +87,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
_patch_factory(monkeypatch, cfg)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name=None)
factory_module.create_chat_model(name=None, app_config=cfg)
# resolve_class is called — if we reach here without ValueError, the correct model was used
assert FakeChatModel.captured_kwargs.get("model") == "alpha"
@@ -96,11 +95,10 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
def test_raises_when_model_not_found(monkeypatch):
cfg = _make_app_config([_make_model("only-model")])
monkeypatch.setattr(factory_module, "get_app_config", lambda: cfg)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
with pytest.raises(ValueError, match="ghost-model"):
factory_module.create_chat_model(name="ghost-model")
factory_module.create_chat_model(name="ghost-model", app_config=cfg)
def test_appends_all_tracing_callbacks(monkeypatch):
@@ -109,7 +107,7 @@ def test_appends_all_tracing_callbacks(monkeypatch):
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: ["smith-callback", "langfuse-callback"])
FakeChatModel.captured_kwargs = {}
model = factory_module.create_chat_model(name="alpha")
model = factory_module.create_chat_model(name="alpha", app_config=cfg)
assert model.callbacks == ["smith-callback", "langfuse-callback"]
@@ -127,7 +125,7 @@ def test_thinking_enabled_raises_when_not_supported_but_when_thinking_enabled_is
_patch_factory(monkeypatch, cfg)
with pytest.raises(ValueError, match="does not support thinking"):
factory_module.create_chat_model(name="no-think", thinking_enabled=True)
factory_module.create_chat_model(name="no-think", thinking_enabled=True, app_config=cfg)
def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(monkeypatch):
@@ -138,7 +136,7 @@ def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(
_patch_factory(monkeypatch, cfg)
with pytest.raises(ValueError, match="does not support thinking"):
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True)
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True, app_config=cfg)
def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
@@ -147,7 +145,7 @@ def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
_patch_factory(monkeypatch, cfg)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="thinker", thinking_enabled=True)
factory_module.create_chat_model(name="thinker", thinking_enabled=True, app_config=cfg)
assert FakeChatModel.captured_kwargs.get("temperature") == 1.0
assert FakeChatModel.captured_kwargs.get("max_tokens") == 16000
@@ -183,7 +181,7 @@ def test_thinking_disabled_openai_gateway_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False)
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False, app_config=cfg)
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
assert captured.get("reasoning_effort") == "minimal"
@@ -216,7 +214,7 @@ def test_thinking_disabled_langchain_anthropic_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False)
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False, app_config=cfg)
assert captured.get("thinking") == {"type": "disabled"}
assert "extra_body" not in captured
@@ -238,7 +236,7 @@ def test_thinking_disabled_no_when_thinking_enabled_does_nothing(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="plain", thinking_enabled=False)
factory_module.create_chat_model(name="plain", thinking_enabled=False, app_config=cfg)
assert "extra_body" not in captured
assert "thinking" not in captured
@@ -278,7 +276,7 @@ def test_when_thinking_disabled_takes_precedence_over_hardcoded_disable(monkeypa
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False)
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False, app_config=cfg)
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
# User overrode the hardcoded "minimal" with "low"
@@ -310,7 +308,7 @@ def test_when_thinking_disabled_not_used_when_thinking_enabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True)
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True, app_config=cfg)
# when_thinking_enabled should apply, NOT when_thinking_disabled
assert captured.get("extra_body") == {"thinking": {"type": "enabled"}}
@@ -339,7 +337,7 @@ def test_when_thinking_disabled_without_when_thinking_enabled_still_applies(monk
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False)
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False, app_config=cfg)
# when_thinking_disabled is now gated independently of has_thinking_settings
assert captured.get("reasoning_effort") == "low"
@@ -370,7 +368,7 @@ def test_when_thinking_disabled_excluded_from_model_dump(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True)
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True, app_config=cfg)
# when_thinking_disabled value must NOT appear as a raw key
assert "when_thinking_disabled" not in captured
@@ -394,7 +392,7 @@ def test_reasoning_effort_cleared_when_not_supported(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="no-effort", thinking_enabled=False)
factory_module.create_chat_model(name="no-effort", thinking_enabled=False, app_config=cfg)
assert captured.get("reasoning_effort") is None
@@ -422,7 +420,7 @@ def test_reasoning_effort_preserved_when_supported(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="effort-model", thinking_enabled=False)
factory_module.create_chat_model(name="effort-model", thinking_enabled=False, app_config=cfg)
# When supports_reasoning_effort=True, it should NOT be cleared to None
# The disable path sets it to "minimal"; supports_reasoning_effort=True keeps it
@@ -458,7 +456,7 @@ def test_thinking_shortcut_enables_thinking_when_thinking_enabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True)
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True, app_config=cfg)
assert captured.get("thinking") == thinking_settings
@@ -488,7 +486,7 @@ def test_thinking_shortcut_disables_thinking_when_thinking_disabled(monkeypatch)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False)
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False, app_config=cfg)
assert captured.get("thinking") == {"type": "disabled"}
assert "extra_body" not in captured
@@ -520,7 +518,7 @@ def test_thinking_shortcut_merges_with_when_thinking_enabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="merge-model", thinking_enabled=True)
factory_module.create_chat_model(name="merge-model", thinking_enabled=True, app_config=cfg)
# Both the thinking shortcut and when_thinking_enabled settings should be applied
assert captured.get("thinking") == thinking_settings
@@ -552,7 +550,7 @@ def test_thinking_shortcut_not_leaked_into_model_when_disabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="no-leak", thinking_enabled=False)
factory_module.create_chat_model(name="no-leak", thinking_enabled=False, app_config=cfg)
# The disable path should have set thinking to disabled (not the raw enabled shortcut)
assert captured.get("thinking") == {"type": "disabled"}
@@ -590,7 +588,7 @@ def test_openai_compatible_provider_passes_base_url(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="minimax-m2.5")
factory_module.create_chat_model(name="minimax-m2.5", app_config=cfg)
assert captured.get("model") == "MiniMax-M2.5"
assert captured.get("base_url") == "https://api.minimax.io/v1"
@@ -731,11 +729,11 @@ def test_openai_compatible_provider_multiple_models(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
# Create first model
factory_module.create_chat_model(name="minimax-m2.5")
factory_module.create_chat_model(name="minimax-m2.5", app_config=cfg)
assert captured.get("model") == "MiniMax-M2.5"
# Create second model
factory_module.create_chat_model(name="minimax-m2.5-highspeed")
factory_module.create_chat_model(name="minimax-m2.5-highspeed", app_config=cfg)
assert captured.get("model") == "MiniMax-M2.5-highspeed"
@@ -763,7 +761,7 @@ def test_codex_provider_disables_reasoning_when_thinking_disabled(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=False)
factory_module.create_chat_model(name="codex", thinking_enabled=False, app_config=cfg)
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "none"
@@ -783,7 +781,7 @@ def test_codex_provider_preserves_explicit_reasoning_effort(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high")
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high", app_config=cfg)
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "high"
@@ -803,7 +801,7 @@ def test_codex_provider_defaults_reasoning_effort_to_medium(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=True)
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=cfg)
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "medium"
@@ -824,7 +822,7 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=True)
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=cfg)
assert "max_tokens" not in FakeChatModel.captured_kwargs
@@ -837,7 +835,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
supports_thinking=True,
when_thinking_enabled=wte,
)
model.extra_body = {"top_k": 20}
model = model.model_copy(update={"extra_body": {"top_k": 20}})
cfg = _make_app_config([model])
_patch_factory(monkeypatch, cfg)
@@ -850,7 +848,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False, app_config=cfg)
assert captured.get("extra_body") == {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
assert captured.get("reasoning_effort") is None
@@ -864,7 +862,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
supports_thinking=True,
when_thinking_enabled=wte,
)
model.extra_body = {"top_k": 20}
model = model.model_copy(update={"extra_body": {"top_k": 20}})
cfg = _make_app_config([model])
_patch_factory(monkeypatch, cfg)
@@ -877,7 +875,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False, app_config=cfg)
assert captured.get("extra_body") == {
"top_k": 20,
@@ -911,7 +909,7 @@ def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="deepseek")
factory_module.create_chat_model(name="deepseek", app_config=cfg)
assert captured.get("stream_usage") is True
@@ -930,14 +928,25 @@ def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="claude")
factory_module.create_chat_model(name="claude", app_config=cfg)
assert "stream_usage" not in captured
def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
"""If config dumps stream_usage=False, factory should respect it."""
cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
# Build a ModelConfig with stream_usage=False as an extra field (extra="allow").
model_with_stream_usage = ModelConfig(
name="deepseek",
display_name="deepseek",
description=None,
use="langchain_deepseek:ChatDeepSeek",
model="deepseek",
supports_thinking=False,
supports_vision=False,
stream_usage=False,
)
cfg = _make_app_config([model_with_stream_usage])
_patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
captured: dict = {}
@@ -949,17 +958,7 @@ def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
# Simulate config having stream_usage explicitly set by patching model_dump
original_get_model_config = cfg.get_model_config
def patched_get_model_config(name):
mc = original_get_model_config(name)
mc.stream_usage = False # type: ignore[attr-defined]
return mc
monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)
factory_module.create_chat_model(name="deepseek")
factory_module.create_chat_model(name="deepseek", app_config=cfg)
assert captured.get("stream_usage") is False
@@ -989,7 +988,7 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="gpt-5-responses")
factory_module.create_chat_model(name="gpt-5-responses", app_config=cfg)
assert captured.get("use_responses_api") is True
assert captured.get("output_version") == "responses/v1"
@@ -1030,7 +1029,7 @@ def test_no_duplicate_kwarg_when_reasoning_effort_in_config_and_thinking_disable
_patch_factory(monkeypatch, cfg, model_class=CapturingModel)
# Must not raise TypeError
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False)
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False, app_config=cfg)
# kwargs (runtime) takes precedence: thinking-disabled path sets reasoning_effort=minimal
assert captured.get("reasoning_effort") == "minimal"
+1 -3
View File
@@ -1,8 +1,6 @@
"""Tests for user-scoped path resolution in Paths."""
from pathlib import Path
import pytest
from pathlib import Path
from deerflow.config.paths import Paths
@@ -3,14 +3,24 @@
import importlib
from types import SimpleNamespace
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
present_file_tool_module = importlib.import_module("deerflow.tools.builtins.present_file_tool")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime(outputs_path: str) -> SimpleNamespace:
return SimpleNamespace(
state={"thread_data": {"outputs_path": outputs_path}},
context={"thread_id": "thread-1"},
config={},
context=_make_context("thread-1"),
)
@@ -51,34 +61,6 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
def test_present_files_uses_config_thread_id_when_context_missing(tmp_path, monkeypatch):
outputs_dir = tmp_path / "threads" / "thread-from-config" / "user-data" / "outputs"
outputs_dir.mkdir(parents=True)
artifact_path = outputs_dir / "summary.json"
artifact_path.write_text("{}")
monkeypatch.setattr(
present_file_tool_module,
"get_paths",
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
)
runtime = SimpleNamespace(
state={"thread_data": {"outputs_path": str(outputs_dir)}},
context={},
config={"configurable": {"thread_id": "thread-from-config"}},
)
result = present_file_tool_module.present_file_tool.func(
runtime=runtime,
filepaths=["/mnt/user-data/outputs/summary.json"],
tool_call_id="tc-config",
)
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
assert result.update["messages"][0].content == "Successfully presented files"
def test_present_files_rejects_paths_outside_outputs(tmp_path):
outputs_dir = tmp_path / "threads" / "thread-1" / "user-data" / "outputs"
workspace_dir = tmp_path / "threads" / "thread-1" / "user-data" / "workspace"
@@ -1,5 +1,4 @@
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
@@ -15,19 +14,14 @@ async def test_list_messages_by_run_default_returns_all(base_store):
store = base_store
for i in range(7):
await store.put(
thread_id="t1",
run_id="run-a",
thread_id="t1", run_id="run-a",
event_type="human_message" if i % 2 == 0 else "ai_message",
category="message",
content=f"msg-a-{i}",
category="message", content=f"msg-a-{i}",
)
for i in range(3):
await store.put(
thread_id="t1",
run_id="run-b",
event_type="human_message",
category="message",
content=f"msg-b-{i}",
thread_id="t1", run_id="run-b",
event_type="human_message", category="message", content=f"msg-b-{i}",
)
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
@@ -42,11 +36,9 @@ async def test_list_messages_by_run_with_limit(base_store):
store = base_store
for i in range(7):
await store.put(
thread_id="t1",
run_id="run-a",
thread_id="t1", run_id="run-a",
event_type="human_message" if i % 2 == 0 else "ai_message",
category="message",
content=f"msg-a-{i}",
category="message", content=f"msg-a-{i}",
)
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
@@ -60,11 +52,9 @@ async def test_list_messages_by_run_after_seq(base_store):
store = base_store
for i in range(7):
await store.put(
thread_id="t1",
run_id="run-a",
thread_id="t1", run_id="run-a",
event_type="human_message" if i % 2 == 0 else "ai_message",
category="message",
content=f"msg-a-{i}",
category="message", content=f"msg-a-{i}",
)
all_msgs = await store.list_messages_by_run("t1", "run-a")
@@ -79,11 +69,9 @@ async def test_list_messages_by_run_before_seq(base_store):
store = base_store
for i in range(7):
await store.put(
thread_id="t1",
run_id="run-a",
thread_id="t1", run_id="run-a",
event_type="human_message" if i % 2 == 0 else "ai_message",
category="message",
content=f"msg-a-{i}",
category="message", content=f"msg-a-{i}",
)
all_msgs = await store.list_messages_by_run("t1", "run-a")
@@ -98,19 +86,13 @@ async def test_list_messages_by_run_does_not_include_other_run(base_store):
store = base_store
for i in range(7):
await store.put(
thread_id="t1",
run_id="run-a",
event_type="human_message",
category="message",
content=f"msg-a-{i}",
thread_id="t1", run_id="run-a",
event_type="human_message", category="message", content=f"msg-a-{i}",
)
for i in range(3):
await store.put(
thread_id="t1",
run_id="run-b",
event_type="human_message",
category="message",
content=f"msg-b-{i}",
thread_id="t1", run_id="run-b",
event_type="human_message", category="message", content=f"msg-b-{i}",
)
msgs = await store.list_messages_by_run("t1", "run-b")
File diff suppressed because it is too large Load Diff
+6 -8
View File
@@ -1,14 +1,15 @@
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.gateway.routers import runs
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -112,8 +113,7 @@ def test_run_messages_passes_after_seq_to_event_store():
response = client.get("/api/runs/run-3/messages?after_seq=5")
assert response.status_code == 200
event_store.list_messages_by_run.assert_awaited_once_with(
"thread-3",
"run-3",
"thread-3", "run-3",
limit=51, # default limit(50) + 1
before_seq=None,
after_seq=5,
@@ -133,8 +133,7 @@ def test_run_messages_respects_custom_limit():
response = client.get("/api/runs/run-4/messages?limit=10")
assert response.status_code == 200
event_store.list_messages_by_run.assert_awaited_once_with(
"thread-4",
"run-4",
"thread-4", "run-4",
limit=11, # 10 + 1
before_seq=None,
after_seq=None,
@@ -154,8 +153,7 @@ def test_run_messages_passes_before_seq_to_event_store():
response = client.get("/api/runs/run-5/messages?before_seq=10")
assert response.status_code == 200
event_store.list_messages_by_run.assert_awaited_once_with(
"thread-5",
"run-5",
"thread-5", "run-5",
limit=51,
before_seq=10,
after_seq=None,
+8 -7
View File
@@ -14,6 +14,10 @@ def _make_runtime(tmp_path):
workspace.mkdir()
uploads.mkdir()
outputs.mkdir()
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
return SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
@@ -23,7 +27,10 @@ def _make_runtime(tmp_path):
"outputs_path": str(outputs),
},
},
context={"thread_id": "thread-1"},
context=DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id="thread-1",
),
)
@@ -103,8 +110,6 @@ def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
result = grep_tool.func(
runtime=runtime,
@@ -324,10 +329,6 @@ def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
monkeypatch.setattr(
"deerflow.sandbox.tools.get_app_config",
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
)
result = glob_tool.func(
runtime=runtime,
+199 -195
View File
@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
@@ -34,6 +35,53 @@ _THREAD_DATA = {
}
def _make_app_config(
*,
skills_container_path: str = "/mnt/skills",
skills_host_path: str | None = None,
mounts=None,
mcp_servers=None,
tool_config_map=None,
) -> SimpleNamespace:
"""Build a lightweight AppConfig stand-in used by tests.
Only the attributes accessed by the helpers under test are populated;
everything else is omitted to keep the fake minimal and explicit.
"""
skills_path = Path(skills_host_path) if skills_host_path is not None else None
skills_cfg = SimpleNamespace(
container_path=skills_container_path,
get_skills_path=lambda: skills_path if skills_path is not None else Path("/nonexistent-skills-root-12345"),
)
sandbox_cfg = SimpleNamespace(mounts=list(mounts) if mounts else [], bash_output_max_chars=20000)
extensions_cfg = SimpleNamespace(mcp_servers=dict(mcp_servers) if mcp_servers else {})
tool_config_map = dict(tool_config_map or {})
return SimpleNamespace(
skills=skills_cfg,
sandbox=sandbox_cfg,
extensions=extensions_cfg,
get_tool_config=lambda name: tool_config_map.get(name),
)
_DEFAULT_APP_CONFIG = _make_app_config()
def _make_ctx(thread_id: str = "thread-1", *, app_config=_DEFAULT_APP_CONFIG, sandbox_key: str | None = None):
"""Build a DeerFlowContext-like object with extra attributes allowed.
``resolve_context`` only checks ``isinstance(ctx, DeerFlowContext)``; for
tests that need additional attributes (``sandbox_key``) we use a subclass
created at runtime.
"""
from deerflow.config.deer_flow_context import DeerFlowContext as _DFC
ctx = _DFC(app_config=app_config, thread_id=thread_id)
if sandbox_key is not None:
object.__setattr__(ctx, "sandbox_key", sandbox_key)
return ctx
# ---------- replace_virtual_path ----------
@@ -85,7 +133,7 @@ def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
"""Trailing slash on a virtual path inside a command must be preserved."""
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
@@ -94,7 +142,7 @@ def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
def test_mask_local_paths_in_output_hides_host_paths() -> None:
output = "Created: /tmp/deer-flow/threads/t1/user-data/workspace/result.txt"
masked = mask_local_paths_in_output(output, _THREAD_DATA)
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/tmp/deer-flow/threads/t1/user-data" not in masked
assert "/mnt/user-data/workspace/result.txt" in masked
@@ -107,7 +155,7 @@ def test_mask_local_paths_in_output_hides_skills_host_paths() -> None:
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
output = "Reading: /home/user/deer-flow/skills/public/bootstrap/SKILL.md"
masked = mask_local_paths_in_output(output, _THREAD_DATA)
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/home/user/deer-flow/skills" not in masked
assert "/mnt/skills/public/bootstrap/SKILL.md" in masked
@@ -143,12 +191,12 @@ def test_reject_path_traversal_allows_normal_paths() -> None:
def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
with pytest.raises(PermissionError, match="configured mount paths"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
@@ -158,42 +206,41 @@ def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA)
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_allows_user_data_paths() -> None:
# Should not raise — user-data paths are always allowed
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA, _DEFAULT_APP_CONFIG)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_allows_user_data_write() -> None:
# read_only=False (default) should still work for user-data paths
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=False)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
def test_validate_local_tool_path_rejects_traversal_in_user_data() -> None:
"""Path traversal via .. in user-data paths must be rejected."""
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_rejects_traversal_in_skills() -> None:
"""Path traversal via .. in skills paths must be rejected."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, read_only=True)
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_tool_path_rejects_none_thread_data() -> None:
@@ -201,7 +248,7 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
from deerflow.sandbox.exceptions import SandboxRuntimeError
with pytest.raises(SandboxRuntimeError):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None, _DEFAULT_APP_CONFIG)
# ---------- _resolve_skills_path ----------
@@ -209,32 +256,26 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
def test_resolve_skills_path_resolves_correctly() -> None:
"""Skills virtual path should resolve to host path."""
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
# Force get_skills_path().exists() to be True without touching the FS
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", cfg)
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
def test_resolve_skills_path_resolves_root() -> None:
"""Skills container root should resolve to host skills directory."""
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
resolved = _resolve_skills_path("/mnt/skills")
assert resolved == "/home/user/deer-flow/skills"
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
resolved = _resolve_skills_path("/mnt/skills", cfg)
assert resolved == "/home/user/deer-flow/skills"
def test_resolve_skills_path_raises_when_not_configured() -> None:
"""Should raise FileNotFoundError when skills directory is not available."""
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=None),
):
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
# Default app config has no host path configured → _get_skills_host_path returns None
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG)
# ---------- _resolve_and_validate_user_data_path ----------
@@ -249,7 +290,7 @@ def test_resolve_and_validate_user_data_path_resolves_correctly(tmp_path: Path)
"uploads_path": str(tmp_path / "uploads"),
"outputs_path": str(tmp_path / "outputs"),
}
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data)
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data, _DEFAULT_APP_CONFIG)
assert resolved == str(workspace / "hello.txt")
@@ -264,7 +305,7 @@ def test_resolve_and_validate_user_data_path_blocks_traversal(tmp_path: Path) ->
}
# This path resolves outside the allowed roots
with pytest.raises(PermissionError):
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data)
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data, _DEFAULT_APP_CONFIG)
# ---------- replace_virtual_paths_in_command ----------
@@ -277,7 +318,7 @@ def test_replace_virtual_paths_in_command_replaces_skills_paths() -> None:
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
cmd = "cat /mnt/skills/public/bootstrap/SKILL.md"
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/mnt/skills" not in result
assert "/home/user/deer-flow/skills/public/bootstrap/SKILL.md" in result
@@ -289,7 +330,7 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/skills"),
):
cmd = "cat /mnt/skills/public/SKILL.md > /mnt/user-data/workspace/out.txt"
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/mnt/skills" not in result
assert "/mnt/user-data" not in result
assert "/home/user/skills/public/SKILL.md" in result
@@ -301,30 +342,27 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
"""HTTP URLs must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
validate_local_bash_command_paths(
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> None:
@@ -332,8 +370,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> No
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths(
"cat /mnt/user-data/workspace/../../etc/passwd",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
@@ -342,21 +379,20 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths(
"cat /mnt/skills/../../etc/passwd",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) -> None:
runtime = SimpleNamespace(
state={"sandbox": {"sandbox_id": "local"}, "thread_data": _THREAD_DATA.copy()},
context={"thread_id": "thread-1"},
context=_make_ctx("thread-1"),
)
monkeypatch.setattr(
"deerflow.sandbox.tools.ensure_sandbox_initialized",
lambda runtime: SimpleNamespace(execute_command=lambda command: pytest.fail("host bash should not execute")),
)
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda: False)
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda *a, **k: False)
result = bash_tool.func(
runtime=runtime,
@@ -371,33 +407,32 @@ def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) ->
def test_is_skills_path_recognises_default_prefix() -> None:
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
assert _is_skills_path("/mnt/skills") is True
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md") is True
assert _is_skills_path("/mnt/skills-extra/foo") is False
assert _is_skills_path("/mnt/user-data/workspace") is False
assert _is_skills_path("/mnt/skills", _DEFAULT_APP_CONFIG) is True
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG) is True
assert _is_skills_path("/mnt/skills-extra/foo", _DEFAULT_APP_CONFIG) is False
assert _is_skills_path("/mnt/user-data/workspace", _DEFAULT_APP_CONFIG) is False
def test_validate_local_tool_path_allows_skills_read_only() -> None:
"""read_file / ls should be able to access /mnt/skills paths."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
# Should not raise
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
read_only=True,
)
# Should not raise — default app config uses /mnt/skills as container path
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=True,
)
def test_validate_local_tool_path_blocks_skills_write() -> None:
"""write_file / str_replace must NOT write to skills paths."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
read_only=False,
)
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=False,
)
def test_validate_local_bash_command_paths_allows_skills_path() -> None:
@@ -405,8 +440,7 @@ def test_validate_local_bash_command_paths_allows_skills_path() -> None:
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
validate_local_bash_command_paths(
"cat /mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_urls() -> None:
@@ -414,40 +448,35 @@ def test_validate_local_bash_command_paths_allows_urls() -> None:
# HTTPS URLs
validate_local_bash_command_paths(
"curl -X POST https://example.com/api/v1/risk/check",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# HTTP URLs
validate_local_bash_command_paths(
"curl http://localhost:8080/health",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# URLs with query strings
validate_local_bash_command_paths(
"curl https://api.example.com/v2/search?q=test",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# FTP URLs
validate_local_bash_command_paths(
"curl ftp://ftp.example.com/pub/file.tar.gz",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# URL mixed with valid virtual path
validate_local_bash_command_paths(
"curl https://example.com/data -o /mnt/user-data/workspace/data.json",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_file_urls() -> None:
"""file:// URLs should be treated as unsafe and blocked."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA)
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_file_urls_case_insensitive() -> None:
"""file:// URL detection should be case-insensitive."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA)
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -> None:
@@ -455,35 +484,36 @@ def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -
with pytest.raises(PermissionError):
validate_local_bash_command_paths(
"curl file:///etc/passwd -o /mnt/user-data/workspace/out.txt",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_still_blocks_other_paths() -> None:
"""Paths outside virtual and system prefixes must still be blocked."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_skills_custom_container_path() -> None:
"""Skills with a custom container_path in config should also work."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/custom/skills"):
# Should not raise
custom_config = _make_app_config(skills_container_path="/custom/skills")
# Should not raise
validate_local_tool_path(
"/custom/skills/public/my-skill/SKILL.md",
_THREAD_DATA,
custom_config,
read_only=True,
)
# The default /mnt/skills should not match since container path is /custom/skills
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path(
"/custom/skills/public/my-skill/SKILL.md",
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
custom_config,
read_only=True,
)
# The default /mnt/skills should not match since container path is /custom/skills
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
read_only=True,
)
# ---------- ACP workspace path tests ----------
@@ -500,6 +530,7 @@ def test_validate_local_tool_path_allows_acp_workspace_read_only() -> None:
validate_local_tool_path(
"/mnt/acp-workspace/hello_world.py",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=True,
)
@@ -510,6 +541,7 @@ def test_validate_local_tool_path_blocks_acp_workspace_write() -> None:
validate_local_tool_path(
"/mnt/acp-workspace/hello_world.py",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=False,
)
@@ -518,8 +550,7 @@ def test_validate_local_bash_command_paths_allows_acp_workspace() -> None:
"""bash commands referencing /mnt/acp-workspace should be allowed."""
validate_local_bash_command_paths(
"cp /mnt/acp-workspace/hello_world.py /mnt/user-data/outputs/hello_world.py",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -> None:
@@ -527,8 +558,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths(
"cat /mnt/acp-workspace/../../etc/passwd",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_resolve_acp_workspace_path_resolves_correctly(tmp_path: Path) -> None:
@@ -570,7 +600,7 @@ def test_replace_virtual_paths_in_command_replaces_acp_workspace() -> None:
acp_host = "/home/user/.deer-flow/acp-workspace"
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
cmd = "cp /mnt/acp-workspace/hello.py /mnt/user-data/outputs/hello.py"
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/mnt/acp-workspace" not in result
assert f"{acp_host}/hello.py" in result
assert "/tmp/deer-flow/threads/t1/user-data/outputs/hello.py" in result
@@ -581,7 +611,7 @@ def test_mask_local_paths_in_output_hides_acp_workspace_host_paths() -> None:
acp_host = "/home/user/.deer-flow/acp-workspace"
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
output = f"Copied: {acp_host}/hello.py"
masked = mask_local_paths_in_output(output, _THREAD_DATA)
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert acp_host not in masked
assert "/mnt/acp-workspace/hello.py" in masked
@@ -617,39 +647,37 @@ def test_apply_cwd_prefix_quotes_path_with_spaces() -> None:
def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None:
"""Bash commands referencing MCP filesystem server paths should be allowed."""
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.config.sandbox_config import SandboxConfig
mock_config = ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=True,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
)
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=mock_config):
# Should not raise - MCP filesystem paths are allowed
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
# Path traversal should still be blocked
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
# Disabled servers should not expose paths
disabled_config = ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=False,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
def _mcp_app_config(enabled: bool) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
extensions=ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=enabled,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
),
)
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
enabled_cfg = _mcp_app_config(True)
# Should not raise - MCP filesystem paths are allowed
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, enabled_cfg)
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA, enabled_cfg)
# Path traversal should still be blocked
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA, enabled_cfg)
# Disabled servers should not expose paths
disabled_cfg = _mcp_app_config(False)
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, disabled_cfg)
# ---------- Custom mount path tests ----------
@@ -667,12 +695,12 @@ def _mock_custom_mounts():
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
assert _is_custom_mount_path("/mnt/code-read") is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
assert _is_custom_mount_path("/mnt/data") is True
assert _is_custom_mount_path("/mnt/data/file.txt") is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
assert _is_custom_mount_path("/mnt/other") is False
assert _is_custom_mount_path("/mnt/code-read", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/data", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/data/file.txt", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG) is False
assert _is_custom_mount_path("/mnt/other", _DEFAULT_APP_CONFIG) is False
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
@@ -683,7 +711,7 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
mount = _get_custom_mount_for_path("/mnt/code/file.py")
mount = _get_custom_mount_for_path("/mnt/code/file.py", _DEFAULT_APP_CONFIG)
assert mount is not None
assert mount.container_path == "/mnt/code"
@@ -691,90 +719,72 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
"""read_file / ls should be able to access custom mount paths."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
"""write_file / str_replace must NOT write to read-only custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
"""write_file / str_replace should succeed on writable custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
"""Path traversal via .. in custom mount paths must be rejected."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
"""bash commands referencing custom mount paths should be allowed."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
"""Bash commands with traversal in custom mount paths should be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
"""Paths not matching any custom mount should still be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should cache after first successful load."""
# Clear any existing cache
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
# Use real directories so host_path.exists() filtering passes
def test_get_custom_mounts_reads_from_app_config(tmp_path) -> None:
"""_get_custom_mounts should read directly from the supplied AppConfig."""
dir_a = tmp_path / "code-read"
dir_a.mkdir()
dir_b = tmp_path / "data"
dir_b.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
# After caching, should return cached value even without mock
assert hasattr(_get_custom_mounts, "_cached")
assert len(_get_custom_mounts()) == 2
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
cfg = _make_app_config(mounts=mounts)
result = _get_custom_mounts(cfg)
assert len(result) == 2
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
def test_get_custom_mounts_filters_nonexistent_host_path(tmp_path) -> None:
"""_get_custom_mounts should only return mounts whose host_path exists."""
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
from deerflow.config.sandbox_config import VolumeMountConfig
existing_dir = tmp_path / "existing"
existing_dir.mkdir()
@@ -783,22 +793,16 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
cfg = _make_app_config(mounts=mounts)
result = _get_custom_mounts(cfg)
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG)
assert mount is None
@@ -829,8 +833,8 @@ def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) ->
sandbox = SharedSandbox()
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
]
failures: list[BaseException] = []
@@ -905,14 +909,14 @@ def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_pat
"sandbox-b": IsolatedSandbox("sandbox-b", shared_state),
}
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1", sandbox_key="sandbox-a"), config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-2", sandbox_key="sandbox-b"), config={}),
]
failures: list[BaseException] = []
monkeypatch.setattr(
"deerflow.sandbox.tools.ensure_sandbox_initialized",
lambda runtime: sandboxes[runtime.context["sandbox_key"]],
lambda runtime: sandboxes[runtime.context.sandbox_key],
)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
@@ -972,8 +976,8 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
sandbox = SharedSandbox()
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
]
failures: list[BaseException] = []
+1 -2
View File
@@ -29,10 +29,9 @@ async def test_scan_skill_content_passes_run_name_to_model(monkeypatch):
@pytest.mark.anyio
async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None))
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
result = await scan_skill_content(config, "---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
assert result.decision == "block"
assert "manual review required" in result.reason
+20 -22
View File
@@ -4,9 +4,20 @@ from types import SimpleNamespace
import anyio
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
def _make_context(thread_id: str, app_config: object | None = None) -> DeerFlowContext:
return DeerFlowContext(
app_config=app_config if app_config is not None else AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _skill_content(name: str, description: str = "Demo skill") -> str:
return f"---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"
@@ -23,18 +34,15 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
result = anyio.run(
skill_manage_module.skill_manage_tool.coroutine,
@@ -67,17 +75,14 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
async def _refresh():
async def _refresh(*a, **k):
return None
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
@@ -107,10 +112,8 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
runtime = SimpleNamespace(context={}, config={"configurable": {}})
runtime = SimpleNamespace(context=_make_context("", config), config={"configurable": {}})
with pytest.raises(ValueError, match="built-in skill"):
anyio.run(
@@ -131,17 +134,15 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-sync"}, config={"configurable": {"thread_id": "thread-sync"}})
runtime = SimpleNamespace(context=_make_context("thread-sync", config), config={"configurable": {"thread_id": "thread-sync"}})
result = skill_manage_module.skill_manage_tool.func(
runtime=runtime,
action="create",
@@ -159,17 +160,14 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
async def _refresh():
async def _refresh(*a, **k):
return None
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):
+16 -14
View File
@@ -7,6 +7,9 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.routers import skills as skills_router
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.skills.manager import get_skill_history_file
from deerflow.skills.types import Skill
@@ -44,17 +47,16 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
with TestClient(app) as client:
@@ -94,14 +96,12 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
get_skill_history_file("demo-skill").write_text(
get_skill_history_file("demo-skill", config).write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
)
async def _refresh():
async def _refresh(*a, **k):
return None
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
@@ -114,6 +114,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
with TestClient(app) as client:
@@ -136,17 +137,16 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
with TestClient(app) as client:
@@ -238,23 +238,25 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
enabled_state = {"value": True}
refresh_calls = []
def _load_skills(*, enabled_only: bool):
def _load_skills(*a, enabled_only: bool = False, **k):
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
if enabled_only and not skill.enabled:
return []
return [skill]
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
enabled_state["value"] = False
_app_cfg = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ExtensionsConfig(mcp_servers={}, skills={}))
monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
monkeypatch.setattr("app.gateway.routers.skills.get_extensions_config", lambda: SimpleNamespace(mcp_servers={}, skills={}))
monkeypatch.setattr("app.gateway.routers.skills.reload_extensions_config", lambda: None)
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.state.config = _app_cfg
app.include_router(skills_router.router)
with TestClient(app) as client:
+3 -3
View File
@@ -27,7 +27,7 @@ def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path:
_write_skill(skills_root / "public" / "parent" / "child-skill", "child-skill", "Child skill")
_write_skill(skills_root / "custom" / "team" / "helper", "team-helper", "Team helper")
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
skills = load_skills(skills_path=skills_root, enabled_only=False)
by_name = {skill.name: skill for skill in skills}
assert {"root-skill", "child-skill", "team-helper"} <= set(by_name)
@@ -57,7 +57,7 @@ def test_load_skills_skips_hidden_directories(tmp_path: Path):
"Hidden skill",
)
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
skills = load_skills(skills_path=skills_root, enabled_only=False)
names = {skill.name for skill in skills}
assert "ok-skill" in names
@@ -69,7 +69,7 @@ def test_load_skills_prefers_custom_over_public_with_same_name(tmp_path: Path):
_write_skill(skills_root / "public" / "shared-skill", "shared-skill", "Public version")
_write_skill(skills_root / "custom" / "shared-skill", "shared-skill", "Custom version")
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
skills = load_skills(skills_path=skills_root, enabled_only=False)
shared = next(skill for skill in skills if skill.name == "shared-skill")
assert shared.category == "custom"
+6 -2
View File
@@ -6,6 +6,7 @@ import re
import anyio
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
# ---------------------------------------------------------------------------
@@ -331,6 +332,9 @@ async def test_concurrent_tasks_end_sentinel():
@pytest.mark.anyio
async def test_make_stream_bridge_defaults():
"""make_stream_bridge() with no config yields a MemoryStreamBridge."""
async with make_stream_bridge() as bridge:
"""make_stream_bridge with a config lacking stream_bridge yields a MemoryStreamBridge."""
from deerflow.config.sandbox_config import SandboxConfig
config = AppConfig(sandbox=SandboxConfig(use="test"))
async with make_stream_bridge(config) as bridge:
assert isinstance(bridge, MemoryStreamBridge)
+20 -18
View File
@@ -21,6 +21,8 @@ from unittest.mock import MagicMock, patch
import pytest
_TEST_APP_CONFIG = MagicMock(name="TestAppConfig")
# Module names that need to be mocked to break circular imports
_MOCKED_MODULE_NAMES = [
"deerflow.agents",
@@ -203,7 +205,7 @@ class TestAsyncExecutionPath:
config=base_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -232,7 +234,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -259,7 +261,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -285,7 +287,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -306,7 +308,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -327,7 +329,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -348,7 +350,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -384,7 +386,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -419,7 +421,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -456,7 +458,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -477,7 +479,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_aexecute") as mock_aexecute:
@@ -511,7 +513,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -565,7 +567,7 @@ class TestAsyncToolSupport:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -602,7 +604,7 @@ class TestAsyncToolSupport:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -648,7 +650,7 @@ class TestThreadSafety:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id=f"thread-{task_id}",
thread_id=f"thread-{task_id}", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -858,7 +860,7 @@ class TestCooperativeCancellation:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -898,7 +900,7 @@ class TestCooperativeCancellation:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -977,7 +979,7 @@ class TestCooperativeCancellation:
config=short_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
)
# Wrap _scheduler_pool.submit so we know when run_task finishes
+15 -9
View File
@@ -1,29 +1,35 @@
"""Tests for subagent availability and prompt exposure under local bash hardening."""
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.subagents import registry as registry_module
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: False)
def _config() -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"))
names = registry_module.get_available_subagent_names()
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: False)
names = registry_module.get_available_subagent_names(_config())
assert names == ["general-purpose"]
def test_get_available_subagent_names_keeps_bash_when_allowed(monkeypatch) -> None:
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: True)
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: True)
names = registry_module.get_available_subagent_names()
names = registry_module.get_available_subagent_names(_config())
assert names == ["general-purpose", "bash"]
def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch) -> None:
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose"])
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
section = prompt_module._build_subagent_section(3)
section = prompt_module._build_subagent_section(3, _config())
# When bash is not available, it should not appear at all (aligned with Codex:
# unavailable roles are omitted, not listed as disabled)
@@ -34,9 +40,9 @@ def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch
def test_build_subagent_section_includes_bash_when_available(monkeypatch) -> None:
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose", "bash"])
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose", "bash"])
section = prompt_module._build_subagent_section(3)
section = prompt_module._build_subagent_section(3, _config())
assert "For command execution (git, build, test, deploy operations)" in section
assert 'bash("npm test")' in section
@@ -1,596 +0,0 @@
"""Tests for subagent per-agent skill configuration and custom subagent types.
Covers:
- SubagentConfig.skills field
- SubagentOverrideConfig.skills field
- CustomSubagentConfig model validation
- SubagentsAppConfig.custom_agents and get_skills_for()
- Registry: custom agent lookup, skills override, merged available names
- Skills filter passthrough in task_tool config assembly
"""
import pytest
from deerflow.config.subagents_config import (
CustomSubagentConfig,
SubagentOverrideConfig,
SubagentsAppConfig,
get_subagents_app_config,
load_subagents_config_from_dict,
)
from deerflow.subagents.config import SubagentConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _reset_subagents_config(**kwargs) -> None:
"""Reset global subagents config to a known state."""
load_subagents_config_from_dict(kwargs)
# ---------------------------------------------------------------------------
# SubagentConfig.skills field
# ---------------------------------------------------------------------------
class TestSubagentConfigSkills:
def test_default_skills_is_none(self):
config = SubagentConfig(name="test", description="test", system_prompt="test")
assert config.skills is None
def test_skills_whitelist(self):
config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
skills=["data-analysis", "visualization"],
)
assert config.skills == ["data-analysis", "visualization"]
def test_skills_empty_list_means_no_skills(self):
config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
skills=[],
)
assert config.skills == []
# ---------------------------------------------------------------------------
# SubagentOverrideConfig.skills field
# ---------------------------------------------------------------------------
class TestSubagentOverrideConfigSkills:
def test_default_skills_is_none(self):
override = SubagentOverrideConfig()
assert override.skills is None
def test_skills_whitelist(self):
override = SubagentOverrideConfig(skills=["web-search", "data-analysis"])
assert override.skills == ["web-search", "data-analysis"]
def test_skills_empty_list(self):
override = SubagentOverrideConfig(skills=[])
assert override.skills == []
def test_skills_coexists_with_other_fields(self):
override = SubagentOverrideConfig(
timeout_seconds=300,
model="gpt-5",
skills=["my-skill"],
)
assert override.timeout_seconds == 300
assert override.model == "gpt-5"
assert override.skills == ["my-skill"]
# ---------------------------------------------------------------------------
# CustomSubagentConfig model
# ---------------------------------------------------------------------------
class TestCustomSubagentConfig:
def test_minimal_valid(self):
config = CustomSubagentConfig(
description="A test agent",
system_prompt="You are a test agent.",
)
assert config.description == "A test agent"
assert config.system_prompt == "You are a test agent."
assert config.tools is None
assert config.disallowed_tools == ["task", "ask_clarification", "present_files"]
assert config.skills is None
assert config.model == "inherit"
assert config.max_turns == 50
assert config.timeout_seconds == 900
def test_full_configuration(self):
config = CustomSubagentConfig(
description="Data analysis specialist",
system_prompt="You are a data analysis subagent.",
tools=["bash", "read_file", "write_file"],
disallowed_tools=["task"],
skills=["data-analysis", "visualization"],
model="qwen3:32b",
max_turns=80,
timeout_seconds=600,
)
assert config.tools == ["bash", "read_file", "write_file"]
assert config.skills == ["data-analysis", "visualization"]
assert config.model == "qwen3:32b"
assert config.max_turns == 80
assert config.timeout_seconds == 600
def test_skills_empty_list_no_skills(self):
config = CustomSubagentConfig(
description="test",
system_prompt="test",
skills=[],
)
assert config.skills == []
def test_rejects_zero_max_turns(self):
with pytest.raises(ValueError):
CustomSubagentConfig(
description="test",
system_prompt="test",
max_turns=0,
)
def test_rejects_zero_timeout(self):
with pytest.raises(ValueError):
CustomSubagentConfig(
description="test",
system_prompt="test",
timeout_seconds=0,
)
# ---------------------------------------------------------------------------
# SubagentsAppConfig.custom_agents and get_skills_for()
# ---------------------------------------------------------------------------
class TestSubagentsAppConfigCustomAgents:
def test_default_custom_agents_empty(self):
config = SubagentsAppConfig()
assert config.custom_agents == {}
def test_custom_agents_loaded(self):
config = SubagentsAppConfig(
custom_agents={
"analysis": CustomSubagentConfig(
description="Analysis agent",
system_prompt="You analyze data.",
skills=["data-analysis"],
),
}
)
assert "analysis" in config.custom_agents
assert config.custom_agents["analysis"].skills == ["data-analysis"]
def test_multiple_custom_agents(self):
config = SubagentsAppConfig(
custom_agents={
"analysis": CustomSubagentConfig(
description="Analysis",
system_prompt="analyze",
skills=["data-analysis"],
),
"researcher": CustomSubagentConfig(
description="Research",
system_prompt="research",
skills=["web-search"],
),
}
)
assert len(config.custom_agents) == 2
class TestGetSkillsFor:
def test_returns_none_when_no_override(self):
config = SubagentsAppConfig()
assert config.get_skills_for("general-purpose") is None
assert config.get_skills_for("unknown") is None
def test_returns_skills_whitelist(self):
config = SubagentsAppConfig(
agents={
"general-purpose": SubagentOverrideConfig(skills=["web-search", "coding"]),
}
)
assert config.get_skills_for("general-purpose") == ["web-search", "coding"]
def test_returns_empty_list_for_no_skills(self):
config = SubagentsAppConfig(
agents={
"bash": SubagentOverrideConfig(skills=[]),
}
)
assert config.get_skills_for("bash") == []
def test_returns_none_for_unrelated_agent(self):
config = SubagentsAppConfig(
agents={
"bash": SubagentOverrideConfig(skills=["web-search"]),
}
)
assert config.get_skills_for("general-purpose") is None
def test_returns_none_when_skills_not_set(self):
config = SubagentsAppConfig(
agents={
"bash": SubagentOverrideConfig(timeout_seconds=300),
}
)
assert config.get_skills_for("bash") is None
# ---------------------------------------------------------------------------
# load_subagents_config_from_dict with skills and custom_agents
# ---------------------------------------------------------------------------
class TestLoadSubagentsConfigWithSkills:
def teardown_method(self):
_reset_subagents_config()
def test_load_with_skills_override(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {
"general-purpose": {"skills": ["web-search", "data-analysis"]},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_skills_for("general-purpose") == ["web-search", "data-analysis"]
def test_load_with_empty_skills(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {
"bash": {"skills": []},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_skills_for("bash") == []
def test_load_with_custom_agents(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"custom_agents": {
"analysis": {
"description": "Data analysis specialist",
"system_prompt": "You are a data analysis subagent.",
"skills": ["data-analysis", "visualization"],
"tools": ["bash", "read_file"],
"max_turns": 80,
"timeout_seconds": 600,
},
},
}
)
cfg = get_subagents_app_config()
assert "analysis" in cfg.custom_agents
custom = cfg.custom_agents["analysis"]
assert custom.skills == ["data-analysis", "visualization"]
assert custom.tools == ["bash", "read_file"]
assert custom.max_turns == 80
assert custom.timeout_seconds == 600
def test_load_with_both_overrides_and_custom(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {
"general-purpose": {"skills": ["web-search"]},
},
"custom_agents": {
"analysis": {
"description": "Analysis",
"system_prompt": "Analyze.",
"skills": ["data-analysis"],
},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_skills_for("general-purpose") == ["web-search"]
assert cfg.custom_agents["analysis"].skills == ["data-analysis"]
# ---------------------------------------------------------------------------
# Registry: custom agent lookup
# ---------------------------------------------------------------------------
class TestRegistryCustomAgentLookup:
def teardown_method(self):
_reset_subagents_config()
def test_custom_agent_found(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"custom_agents": {
"analysis": {
"description": "Data analysis specialist",
"system_prompt": "You are a data analysis subagent.",
"skills": ["data-analysis"],
"tools": ["bash", "read_file"],
"max_turns": 80,
"timeout_seconds": 600,
},
},
}
)
config = get_subagent_config("analysis")
assert config is not None
assert config.name == "analysis"
assert config.skills == ["data-analysis"]
assert config.tools == ["bash", "read_file"]
assert config.max_turns == 80
assert config.timeout_seconds == 600
assert config.model == "inherit"
def test_custom_agent_not_found(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config()
assert get_subagent_config("nonexistent") is None
def test_builtin_takes_priority_over_custom(self):
"""If a custom agent has the same name as a builtin, builtin wins."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"custom_agents": {
"general-purpose": {
"description": "Custom override attempt",
"system_prompt": "Should not be used",
},
},
}
)
config = get_subagent_config("general-purpose")
# Should get the builtin description, not the custom one
assert config.description == BUILTIN_SUBAGENTS["general-purpose"].description
def test_custom_agent_with_override(self):
"""Per-agent overrides also apply to custom agents."""
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"custom_agents": {
"analysis": {
"description": "Analysis",
"system_prompt": "Analyze.",
"timeout_seconds": 600,
},
},
"agents": {
"analysis": {"timeout_seconds": 300, "skills": ["overridden-skill"]},
},
}
)
config = get_subagent_config("analysis")
assert config is not None
assert config.timeout_seconds == 300 # Override applied
assert config.skills == ["overridden-skill"] # Override applied
# ---------------------------------------------------------------------------
# Registry: skills override on builtin agents
# ---------------------------------------------------------------------------
class TestRegistrySkillsOverride:
def teardown_method(self):
_reset_subagents_config()
def test_skills_override_applied_to_builtin(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"agents": {
"general-purpose": {"skills": ["web-search", "data-analysis"]},
},
}
)
config = get_subagent_config("general-purpose")
assert config.skills == ["web-search", "data-analysis"]
def test_empty_skills_override(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"agents": {
"bash": {"skills": []},
},
}
)
config = get_subagent_config("bash")
assert config.skills == []
def test_no_skills_override_keeps_default(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config()
config = get_subagent_config("general-purpose")
assert config.skills is None # Default: inherit all
def test_skills_override_does_not_mutate_builtin(self):
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"agents": {
"general-purpose": {"skills": ["web-search"]},
},
}
)
_ = get_subagent_config("general-purpose")
assert BUILTIN_SUBAGENTS["general-purpose"].skills is None
# ---------------------------------------------------------------------------
# Registry: get_available_subagent_names merges custom types
# ---------------------------------------------------------------------------
class TestRegistryAvailableNames:
def teardown_method(self):
_reset_subagents_config()
def test_includes_builtin_names(self):
from deerflow.subagents.registry import get_subagent_names
_reset_subagents_config()
names = get_subagent_names()
assert "general-purpose" in names
assert "bash" in names
def test_includes_custom_names(self):
from deerflow.subagents.registry import get_subagent_names
load_subagents_config_from_dict(
{
"custom_agents": {
"analysis": {
"description": "Analysis",
"system_prompt": "Analyze.",
},
"researcher": {
"description": "Research",
"system_prompt": "Research.",
},
},
}
)
names = get_subagent_names()
assert "general-purpose" in names
assert "bash" in names
assert "analysis" in names
assert "researcher" in names
def test_no_duplicates_when_custom_name_matches_builtin(self):
from deerflow.subagents.registry import get_subagent_names
load_subagents_config_from_dict(
{
"custom_agents": {
"general-purpose": {
"description": "Duplicate name",
"system_prompt": "test",
},
},
}
)
names = get_subagent_names()
assert names.count("general-purpose") == 1
# ---------------------------------------------------------------------------
# Registry: list_subagents includes custom agents
# ---------------------------------------------------------------------------
class TestRegistryListSubagentsWithCustom:
def teardown_method(self):
_reset_subagents_config()
def test_list_includes_custom_agents(self):
from deerflow.subagents.registry import list_subagents
load_subagents_config_from_dict(
{
"custom_agents": {
"analysis": {
"description": "Analysis",
"system_prompt": "Analyze.",
"skills": ["data-analysis"],
},
},
}
)
configs = list_subagents()
names = {c.name for c in configs}
assert "general-purpose" in names
assert "bash" in names
assert "analysis" in names
def test_list_custom_agent_has_correct_skills(self):
from deerflow.subagents.registry import list_subagents
load_subagents_config_from_dict(
{
"custom_agents": {
"analysis": {
"description": "Analysis",
"system_prompt": "Analyze.",
"skills": ["data-analysis", "visualization"],
},
},
}
)
by_name = {c.name: c for c in list_subagents()}
assert by_name["analysis"].skills == ["data-analysis", "visualization"]
# ---------------------------------------------------------------------------
# Skills filter passthrough: verify config.skills is used in task_tool assembly
# ---------------------------------------------------------------------------
class TestSkillsFilterPassthrough:
"""Test that SubagentConfig.skills is correctly passed to get_skills_prompt_section."""
def test_none_skills_passes_none_to_prompt(self):
"""When config.skills is None, available_skills=None should be passed (inherit all)."""
config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
skills=None,
)
# Verify: set(None) would raise, so the code must check for None first
available = set(config.skills) if config.skills is not None else None
assert available is None
def test_empty_skills_passes_empty_set(self):
"""When config.skills is [], available_skills=set() should be passed (no skills)."""
config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
skills=[],
)
available = set(config.skills) if config.skills is not None else None
assert available == set()
def test_skills_whitelist_passes_correct_set(self):
"""When config.skills has values, those should be passed as available_skills."""
config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
skills=["data-analysis", "web-search"],
)
available = set(config.skills) if config.skills is not None else None
assert available == {"data-analysis", "web-search"}
+107 -503
View File
@@ -3,7 +3,7 @@
Covers:
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
- get_timeout_for() / get_max_turns_for() resolution logic
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
- AppConfig.subagents field access
- registry.get_subagent_config() applies config overrides
- registry.list_subagents() applies overrides for all agents
- Polling timeout calculation in task_tool is consistent with config
@@ -11,32 +11,28 @@ Covers:
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.subagents_config import (
SubagentOverrideConfig,
SubagentsAppConfig,
get_subagents_app_config,
load_subagents_config_from_dict,
)
from deerflow.subagents.config import SubagentConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _reset_subagents_config(
def _make_config(
timeout_seconds: int = 900,
*,
max_turns: int | None = None,
agents: dict | None = None,
) -> None:
"""Reset global subagents config to a known state."""
load_subagents_config_from_dict(
{
"timeout_seconds": timeout_seconds,
"max_turns": max_turns,
"agents": agents or {},
}
) -> AppConfig:
"""Build an AppConfig with the given subagents settings."""
return AppConfig(
sandbox=SandboxConfig(use="test"),
subagents=SubagentsAppConfig(
timeout_seconds=timeout_seconds,
max_turns=max_turns,
agents={k: SubagentOverrideConfig(**v) for k, v in (agents or {}).items()},
),
)
@@ -50,523 +46,131 @@ class TestSubagentOverrideConfig:
override = SubagentOverrideConfig()
assert override.timeout_seconds is None
assert override.max_turns is None
assert override.model is None
def test_explicit_value(self):
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42, model="gpt-5.4")
assert override.timeout_seconds == 300
assert override.max_turns == 42
assert override.model == "gpt-5.4"
def test_model_accepts_any_non_empty_string(self):
"""Model name is a free-form non-empty string; cross-reference validation
against the `models:` section happens at registry lookup time."""
override = SubagentOverrideConfig(model="any-arbitrary-model-name")
assert override.model == "any-arbitrary-model-name"
def test_rejects_zero(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=0)
def test_rejects_negative(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=-1)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=-1)
def test_rejects_empty_model(self):
"""Empty-string model would silently bypass the `is not None` check and
reach `create_chat_model(name="")` as a runtime error. Reject at load time
instead, symmetric with the `ge=1` guard on timeout_seconds / max_turns."""
with pytest.raises(ValueError):
SubagentOverrideConfig(model="")
def test_minimum_valid_value(self):
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
assert override.timeout_seconds == 1
assert override.max_turns == 1
# ---------------------------------------------------------------------------
# SubagentsAppConfig defaults and validation
# ---------------------------------------------------------------------------
class TestSubagentsAppConfigDefaults:
def test_default_timeout(self):
config = SubagentsAppConfig()
assert config.timeout_seconds == 900
def test_default_max_turns_override_is_none(self):
config = SubagentsAppConfig()
assert config.max_turns is None
def test_default_agents_empty(self):
config = SubagentsAppConfig()
assert config.agents == {}
def test_custom_global_runtime_overrides(self):
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
assert config.timeout_seconds == 1800
assert config.max_turns == 120
def test_rejects_zero_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=0)
def test_explicit_values(self):
override = SubagentOverrideConfig(timeout_seconds=120, max_turns=50)
assert override.timeout_seconds == 120
assert override.max_turns == 50
def test_rejects_negative_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=-60)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=-60)
with pytest.raises(Exception):
SubagentOverrideConfig(timeout_seconds=-1)
def test_rejects_zero_timeout(self):
with pytest.raises(Exception):
SubagentOverrideConfig(timeout_seconds=0)
# ---------------------------------------------------------------------------
# SubagentsAppConfig resolution helpers
# SubagentsAppConfig model
# ---------------------------------------------------------------------------
class TestRuntimeResolution:
def test_returns_global_default_when_no_override(self):
class TestSubagentsAppConfig:
def test_default_timeout_is_900(self):
config = SubagentsAppConfig()
assert config.timeout_seconds == 900
assert config.max_turns is None
assert config.agents == {}
def test_custom_defaults(self):
config = SubagentsAppConfig(timeout_seconds=300, max_turns=50)
assert config.timeout_seconds == 300
assert config.max_turns == 50
# ---------------------------------------------------------------------------
# get_timeout_for / get_max_turns_for
# ---------------------------------------------------------------------------
class TestTimeoutResolution:
def test_global_timeout_for_unknown_agent(self):
config = SubagentsAppConfig(timeout_seconds=600)
assert config.get_timeout_for("unknown") == 600
def test_per_agent_timeout_overrides_global(self):
config = SubagentsAppConfig(
timeout_seconds=600,
agents={"bash": SubagentOverrideConfig(timeout_seconds=120)},
)
assert config.get_timeout_for("bash") == 120
assert config.get_timeout_for("general-purpose") == 600
def test_per_agent_override_none_falls_back_to_global(self):
config = SubagentsAppConfig(
timeout_seconds=600,
agents={"bash": SubagentOverrideConfig(timeout_seconds=None)},
)
assert config.get_timeout_for("bash") == 600
assert config.get_timeout_for("unknown-agent") == 600
assert config.get_max_turns_for("general-purpose", 100) == 100
class TestMaxTurnsResolution:
def test_builtin_default_when_no_override(self):
config = SubagentsAppConfig()
assert config.get_max_turns_for("bash", 60) == 60
def test_returns_per_agent_override_when_set(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=120,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
)
assert config.get_timeout_for("bash") == 300
assert config.get_max_turns_for("bash", 60) == 80
def test_global_max_turns_overrides_builtin(self):
config = SubagentsAppConfig(max_turns=100)
assert config.get_max_turns_for("bash", 60) == 100
def test_other_agents_still_use_global_default(self):
def test_per_agent_max_turns_overrides_global(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=140,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
max_turns=100,
agents={"bash": SubagentOverrideConfig(max_turns=30)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 140
assert config.get_max_turns_for("bash", 60) == 30
assert config.get_max_turns_for("general-purpose", 60) == 100
def test_agent_with_none_override_falls_back_to_global(self):
def test_per_agent_override_none_falls_back(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=150,
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
max_turns=100,
agents={"bash": SubagentOverrideConfig(max_turns=None)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 150
def test_multiple_per_agent_overrides(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=120,
agents={
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
},
)
assert config.get_timeout_for("general-purpose") == 1800
assert config.get_timeout_for("bash") == 120
assert config.get_max_turns_for("general-purpose", 100) == 200
assert config.get_max_turns_for("bash", 60) == 80
def test_get_model_for_returns_none_when_no_override(self):
"""No per-agent model override -> returns None so callers fall back to builtin/parent."""
config = SubagentsAppConfig(timeout_seconds=900)
assert config.get_model_for("general-purpose") is None
assert config.get_model_for("bash") is None
assert config.get_model_for("unknown-agent") is None
def test_get_model_for_returns_override_when_set(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={
"general-purpose": SubagentOverrideConfig(model="qwen3.5-35b-a3b"),
"bash": SubagentOverrideConfig(model="gpt-5.4"),
},
)
assert config.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
assert config.get_model_for("bash") == "gpt-5.4"
def test_get_model_for_returns_none_for_omitted_agent(self):
"""An agent not listed in overrides returns None even when other agents have model overrides."""
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"bash": SubagentOverrideConfig(model="gpt-5.4")},
)
assert config.get_model_for("general-purpose") is None
def test_get_model_for_handles_explicit_none(self):
"""Explicit model=None in the override is equivalent to no override."""
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, model=None)},
)
assert config.get_model_for("bash") is None
# Timeout override is still applied even when model is None.
assert config.get_timeout_for("bash") == 300
assert config.get_max_turns_for("bash", 60) == 100
# ---------------------------------------------------------------------------
# load_subagents_config_from_dict / get_subagents_app_config singleton
# AppConfig.subagents
# ---------------------------------------------------------------------------
class TestLoadSubagentsConfig:
def teardown_method(self):
"""Restore defaults after each test."""
_reset_subagents_config()
class TestAppConfigSubagents:
def test_load_global_timeout(self):
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
assert get_subagents_app_config().timeout_seconds == 300
assert get_subagents_app_config().max_turns == 120
cfg = _make_config(timeout_seconds=300, max_turns=120)
sub = cfg.subagents
assert sub.timeout_seconds == 300
assert sub.max_turns == 120
def test_load_with_per_agent_overrides(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
cfg = _make_config(
timeout_seconds=900,
max_turns=120,
agents={
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 1800
assert cfg.get_timeout_for("bash") == 60
assert cfg.get_max_turns_for("general-purpose", 100) == 200
assert cfg.get_max_turns_for("bash", 60) == 80
sub = cfg.subagents
assert sub.get_timeout_for("general-purpose") == 1800
assert sub.get_timeout_for("bash") == 60
assert sub.get_max_turns_for("general-purpose", 100) == 200
assert sub.get_max_turns_for("bash", 60) == 80
def test_load_partial_override(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 600,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
}
cfg = _make_config(
timeout_seconds=600,
agents={"bash": {"timeout_seconds": 120, "max_turns": 70}},
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 600
assert cfg.get_timeout_for("bash") == 120
assert cfg.get_max_turns_for("general-purpose", 100) == 100
assert cfg.get_max_turns_for("bash", 60) == 70
sub = cfg.subagents
assert sub.get_timeout_for("general-purpose") == 600
assert sub.get_timeout_for("bash") == 120
assert sub.get_max_turns_for("general-purpose", 100) == 100
assert sub.get_max_turns_for("bash", 60) == 70
def test_load_with_model_overrides(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {
"general-purpose": {"model": "qwen3.5-35b-a3b"},
"bash": {"model": "gpt-5.4", "timeout_seconds": 300},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
assert cfg.get_model_for("bash") == "gpt-5.4"
# Other override fields on the same agent must still load correctly.
assert cfg.get_timeout_for("bash") == 300
def test_load_empty_dict_uses_defaults(self):
load_subagents_config_from_dict({})
cfg = get_subagents_app_config()
assert cfg.timeout_seconds == 900
assert cfg.max_turns is None
assert cfg.agents == {}
def test_load_replaces_previous_config(self):
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
assert get_subagents_app_config().timeout_seconds == 100
assert get_subagents_app_config().max_turns == 90
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
assert get_subagents_app_config().timeout_seconds == 200
assert get_subagents_app_config().max_turns == 110
def test_singleton_returns_same_instance_between_calls(self):
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
assert get_subagents_app_config() is get_subagents_app_config()
# ---------------------------------------------------------------------------
# registry.get_subagent_config runtime overrides applied
# ---------------------------------------------------------------------------
class TestRegistryGetSubagentConfig:
def teardown_method(self):
_reset_subagents_config()
def test_returns_none_for_unknown_agent(self):
from deerflow.subagents.registry import get_subagent_config
assert get_subagent_config("nonexistent") is None
def test_returns_config_for_builtin_agents(self):
from deerflow.subagents.registry import get_subagent_config
assert get_subagent_config("general-purpose") is not None
assert get_subagent_config("bash") is not None
def test_default_timeout_preserved_when_no_config(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=900)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 900
assert config.max_turns == 100
def test_global_timeout_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 1800
assert config.max_turns == 140
def test_per_agent_runtime_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.timeout_seconds == 120
assert bash_config.max_turns == 80
def test_per_agent_override_does_not_affect_other_agents(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
gp_config = get_subagent_config("general-purpose")
assert gp_config.timeout_seconds == 900
assert gp_config.max_turns == 120
def test_per_agent_model_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"model": "gpt-5.4-mini"}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.model == "gpt-5.4-mini"
def test_omitted_model_keeps_builtin_value(self):
"""When config.yaml has no `model` field for an agent, the builtin default must be preserved."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"timeout_seconds": 300}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.model == builtin_bash_model
def test_explicit_null_model_keeps_builtin_value(self):
"""An explicit `model: null` in config.yaml is equivalent to omission — builtin wins."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"model": None}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.model == builtin_bash_model
def test_model_override_does_not_affect_other_agents(self):
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
builtin_gp_model = BUILTIN_SUBAGENTS["general-purpose"].model
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"model": "gpt-5.4"}},
}
)
gp_config = get_subagent_config("general-purpose")
assert gp_config.model == builtin_gp_model
def test_model_override_preserves_other_fields(self):
"""Applying a model override must leave timeout_seconds / max_turns / name intact."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
original = BUILTIN_SUBAGENTS["bash"]
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"model": "gpt-5.4-mini"}},
}
)
overridden = get_subagent_config("bash")
assert overridden.model == "gpt-5.4-mini"
assert overridden.name == original.name
assert overridden.description == original.description
# No timeout / max_turns override was set, so they use global default / builtin.
assert overridden.timeout_seconds == 900
assert overridden.max_turns == original.max_turns
def test_model_override_does_not_mutate_builtin(self):
"""Registry must return a new object, leaving the builtin default intact."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
original_bash_model = BUILTIN_SUBAGENTS["bash"].model
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"model": "gpt-5.4-mini"}},
}
)
_ = get_subagent_config("bash")
assert BUILTIN_SUBAGENTS["bash"].model == original_bash_model
def test_builtin_config_object_is_not_mutated(self):
"""Registry must return a new object, leaving the builtin default intact."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
returned = get_subagent_config("bash")
assert returned.timeout_seconds == 42
assert returned.max_turns == 88
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
def test_config_preserves_other_fields(self):
"""Applying runtime overrides must not change other SubagentConfig fields."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=300, max_turns=140)
original = BUILTIN_SUBAGENTS["general-purpose"]
overridden = get_subagent_config("general-purpose")
assert overridden.name == original.name
assert overridden.description == original.description
assert overridden.max_turns == 140
assert overridden.model == original.model
assert overridden.tools == original.tools
assert overridden.disallowed_tools == original.disallowed_tools
# ---------------------------------------------------------------------------
# registry.list_subagents all agents get overrides
# ---------------------------------------------------------------------------
class TestRegistryListSubagents:
def teardown_method(self):
_reset_subagents_config()
def test_lists_both_builtin_agents(self):
from deerflow.subagents.registry import list_subagents
names = {cfg.name for cfg in list_subagents()}
assert "general-purpose" in names
assert "bash" in names
def test_all_returned_configs_get_global_override(self):
from deerflow.subagents.registry import list_subagents
_reset_subagents_config(timeout_seconds=123, max_turns=77)
for cfg in list_subagents():
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
def test_per_agent_overrides_reflected_in_list(self):
from deerflow.subagents.registry import list_subagents
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
by_name = {cfg.name: cfg for cfg in list_subagents()}
assert by_name["general-purpose"].timeout_seconds == 1800
assert by_name["bash"].timeout_seconds == 60
assert by_name["general-purpose"].max_turns == 200
assert by_name["bash"].max_turns == 80
# ---------------------------------------------------------------------------
# Polling timeout calculation (logic extracted from task_tool)
# ---------------------------------------------------------------------------
class TestPollingTimeoutCalculation:
"""Verify the formula (timeout_seconds + 60) // 5 is correct for various inputs."""
@pytest.mark.parametrize(
"timeout_seconds, expected_max_polls",
[
(900, 192), # default 15 min → (900+60)//5 = 192
(300, 72), # 5 min → (300+60)//5 = 72
(1800, 372), # 30 min → (1800+60)//5 = 372
(60, 24), # 1 min → (60+60)//5 = 24
(1, 12), # minimum → (1+60)//5 = 12
],
)
def test_polling_timeout_formula(self, timeout_seconds: int, expected_max_polls: int):
dummy_config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
timeout_seconds=timeout_seconds,
)
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
assert max_poll_count == expected_max_polls
def test_polling_timeout_exceeds_execution_timeout(self):
"""Safety-net polling window must always be longer than the execution timeout."""
for timeout_seconds in [60, 300, 900, 1800]:
dummy_config = SubagentConfig(
name="test",
description="test",
system_prompt="test",
timeout_seconds=timeout_seconds,
)
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
polling_window_seconds = max_poll_count * 5
assert polling_window_seconds > timeout_seconds
def test_load_empty_uses_defaults(self):
cfg = _make_config()
sub = cfg.subagents
assert sub.timeout_seconds == 900
assert sub.max_turns is None
assert sub.agents == {}
+46 -253
View File
@@ -8,6 +8,9 @@ from unittest.mock import MagicMock
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.subagents.config import SubagentConfig
# Use module import so tests can patch the exact symbols referenced inside task_tool().
@@ -24,6 +27,13 @@ class FakeSubagentStatus(Enum):
TIMED_OUT = "timed_out"
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime() -> SimpleNamespace:
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
return SimpleNamespace(
@@ -35,7 +45,7 @@ def _make_runtime() -> SimpleNamespace:
"outputs_path": "/tmp/outputs",
},
},
context={"thread_id": "thread-1"},
context=_make_context("thread-1"),
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
)
@@ -83,11 +93,11 @@ class _DummyScheduledTask:
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: None)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
result = _run_task_tool(
runtime=None,
runtime=_make_runtime(),
description="执行任务",
prompt="do work",
subagent_type="general-purpose",
@@ -98,8 +108,8 @@ def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: _make_subagent_config())
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda: False)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: _make_subagent_config())
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda *a, **k: False)
result = _run_task_tool(
runtime=_make_runtime(),
@@ -142,9 +152,8 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
@@ -165,225 +174,20 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
assert captured["executor_kwargs"]["config"].max_turns == 7
# Skills are no longer appended to system_prompt; they are loaded per-session
# by SubagentExecutor and injected as conversation items (Codex pattern).
assert captured["executor_kwargs"]["config"].system_prompt == "Base system prompt"
# Skills are now loaded per-session by SubagentExecutor (mirroring Codex's pattern);
# task_tool no longer appends them to ``system_prompt`` here.
assert "Skills Appendix" not in captured["executor_kwargs"]["config"].system_prompt
get_available_tools.assert_called_once_with(model_name="ark-model", groups=None, subagent_enabled=False)
from unittest.mock import ANY
get_available_tools.assert_called_once_with(model_name="ark-model", groups=ANY, subagent_enabled=False, app_config=ANY)
event_types = [e["type"] for e in events]
assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
assert events[-1]["result"] == "all done"
def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch):
"""Verify tool_groups from parent metadata are passed to get_available_tools(groups=...)."""
config = _make_subagent_config()
parent_tool_groups = ["file:read", "file:write", "bash"]
runtime = SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
"thread_data": {"workspace_path": "/tmp/workspace"},
},
context={"thread_id": "thread-1"},
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1", "tool_groups": parent_tool_groups}},
)
events = []
get_available_tools = MagicMock(return_value=["tool-a"])
class DummyExecutor:
def __init__(self, **kwargs):
pass
def execute_async(self, prompt, task_id=None):
return task_id or "generated-task-id"
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
output = _run_task_tool(
runtime=runtime,
description="执行任务",
prompt="file work only",
subagent_type="general-purpose",
tool_call_id="tc-groups",
)
assert output == "Task Succeeded. Result: done"
# The key assertion: groups should be propagated from parent metadata
get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)
def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
config = _make_subagent_config()
runtime = _make_runtime()
runtime.config["metadata"]["available_skills"] = ["safe-skill"]
events = []
captured = {}
class DummyExecutor:
def __init__(self, **kwargs):
captured["config"] = kwargs["config"]
def execute_async(self, prompt, task_id=None):
return task_id or "generated-task-id"
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
output = _run_task_tool(
runtime=runtime,
description="执行任务",
prompt="use skills",
subagent_type="general-purpose",
tool_call_id="tc-skills",
)
assert output == "Task Succeeded. Result: done"
assert captured["config"].skills == ["safe-skill"]
def test_task_tool_intersects_parent_and_subagent_skill_allowlists(monkeypatch):
config = _make_subagent_config()
config = SubagentConfig(
name=config.name,
description=config.description,
system_prompt=config.system_prompt,
max_turns=config.max_turns,
timeout_seconds=config.timeout_seconds,
skills=["safe-skill", "other-skill"],
)
runtime = _make_runtime()
runtime.config["metadata"]["available_skills"] = ["safe-skill"]
events = []
captured = {}
class DummyExecutor:
def __init__(self, **kwargs):
captured["config"] = kwargs["config"]
def execute_async(self, prompt, task_id=None):
return task_id or "generated-task-id"
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
output = _run_task_tool(
runtime=runtime,
description="执行任务",
prompt="use skills",
subagent_type="general-purpose",
tool_call_id="tc-skills-intersection",
)
assert output == "Task Succeeded. Result: done"
assert captured["config"].skills == ["safe-skill"]
def test_task_tool_no_tool_groups_passes_none(monkeypatch):
"""Verify that when metadata has no tool_groups, groups=None is passed (backward compat)."""
config = _make_subagent_config()
# Default _make_runtime() has no tool_groups in metadata
runtime = _make_runtime()
events = []
get_available_tools = MagicMock(return_value=[])
class DummyExecutor:
def __init__(self, **kwargs):
pass
def execute_async(self, prompt, task_id=None):
return task_id or "generated-task-id"
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="ok"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
output = _run_task_tool(
runtime=runtime,
description="执行任务",
prompt="normal work",
subagent_type="general-purpose",
tool_call_id="tc-no-groups",
)
assert output == "Task Succeeded. Result: ok"
# No tool_groups in metadata → groups=None (default behavior preserved)
get_available_tools.assert_called_once_with(model_name="ark-model", groups=None, subagent_enabled=False)
def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
"""Verify that when runtime is None, groups=None is passed (e.g., unknown subagent path exits early, but tools still load correctly)."""
config = _make_subagent_config()
events = []
get_available_tools = MagicMock(return_value=[])
class DummyExecutor:
def __init__(self, **kwargs):
pass
def execute_async(self, prompt, task_id=None):
return task_id or "generated-task-id"
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="ok"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
output = _run_task_tool(
runtime=None,
description="执行任务",
prompt="no runtime",
subagent_type="general-purpose",
tool_call_id="tc-no-runtime",
)
assert output == "Task Succeeded. Result: ok"
# runtime is None → metadata is empty dict → groups=None
get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False)
def test_task_tool_returns_failed_message(monkeypatch):
config = _make_subagent_config()
events = []
@@ -393,12 +197,11 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -427,12 +230,11 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -463,12 +265,11 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -499,12 +300,11 @@ def test_cleanup_called_on_completed(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
lambda *a, **k: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -539,12 +339,11 @@ def test_cleanup_called_on_failed(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="error"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -579,12 +378,11 @@ def test_cleanup_called_on_timed_out(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -626,12 +424,11 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -679,8 +476,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -730,12 +526,11 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -785,12 +580,11 @@ def test_cancellation_calls_request_cancel(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -843,9 +637,8 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
+23 -26
View File
@@ -1,58 +1,55 @@
import pytest
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
def _as_posix(path: str) -> str:
return path.replace("\\", "/")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
class TestThreadDataMiddleware:
def test_before_agent_returns_paths_when_thread_id_present_in_context(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
from langgraph.runtime import Runtime
result = middleware.before_agent(state={}, runtime=Runtime(context={"thread_id": "thread-123"}))
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-123")))
assert result is not None
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-123/user-data/workspace")
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-123/user-data/uploads")
assert _as_posix(result["thread_data"]["outputs_path"]).endswith("threads/thread-123/user-data/outputs")
def test_before_agent_uses_thread_id_from_configurable_when_context_is_none(self, tmp_path, monkeypatch):
def test_before_agent_uses_thread_id_from_context(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
runtime = Runtime(context=None)
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
lambda: {"configurable": {"thread_id": "thread-from-config"}},
)
from langgraph.runtime import Runtime
result = middleware.before_agent(state={}, runtime=runtime)
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-config")))
assert result is not None
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-from-config/user-data/workspace")
assert runtime.context is None
def test_before_agent_uses_thread_id_from_configurable_when_context_missing_thread_id(self, tmp_path, monkeypatch):
def test_before_agent_uses_thread_id_from_typed_context(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
runtime = Runtime(context={})
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
lambda: {"configurable": {"thread_id": "thread-from-config"}},
)
from langgraph.runtime import Runtime
result = middleware.before_agent(state={}, runtime=runtime)
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-dict")))
assert result is not None
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-config/user-data/uploads")
assert runtime.context == {}
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-dict/user-data/uploads")
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
def test_before_agent_raises_clear_error_when_thread_id_missing(self, tmp_path):
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
monkeypatch.setattr(
"deerflow.agents.middlewares.thread_data_middleware.get_config",
lambda: {"configurable": {}},
)
from langgraph.runtime import Runtime
with pytest.raises(ValueError, match="Thread ID is required in runtime context or config.configurable"):
middleware.before_agent(state={}, runtime=Runtime(context=None))
with pytest.raises(ValueError, match="Thread ID is required"):
middleware.before_agent(state={}, runtime=Runtime(context=_make_context("")))
@@ -1,14 +1,15 @@
"""Tests for paginated GET /api/threads/{thread_id}/runs/{run_id}/messages endpoint."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.gateway.routers import thread_runs
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -77,8 +78,7 @@ def test_after_seq_forwarded_to_event_store():
response = client.get("/api/threads/thread-3/runs/run-3/messages?after_seq=5")
assert response.status_code == 200
event_store.list_messages_by_run.assert_awaited_once_with(
"thread-3",
"run-3",
"thread-3", "run-3",
limit=51, # default limit(50) + 1
before_seq=None,
after_seq=5,
@@ -94,8 +94,7 @@ def test_before_seq_forwarded_to_event_store():
response = client.get("/api/threads/thread-4/runs/run-4/messages?before_seq=10")
assert response.status_code == 200
event_store.list_messages_by_run.assert_awaited_once_with(
"thread-4",
"run-4",
"thread-4", "run-4",
limit=51,
before_seq=10,
after_seq=None,
@@ -111,8 +110,7 @@ def test_custom_limit_forwarded_to_event_store():
response = client.get("/api/threads/thread-5/runs/run-5/messages?limit=10")
assert response.status_code == 200
event_store.list_messages_by_run.assert_awaited_once_with(
"thread-5",
"run-5",
"thread-5", "run-5",
limit=11, # 10 + 1
before_seq=None,
after_seq=None,
+1 -36
View File
@@ -3,7 +3,7 @@
import pytest
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
from deerflow.config.title_config import TitleConfig
class TestTitleConfig:
@@ -44,21 +44,6 @@ class TestTitleConfig:
with pytest.raises(ValueError):
TitleConfig(max_chars=201)
def test_get_set_config(self):
"""Test global config getter and setter."""
original_config = get_title_config()
# Set new config
new_config = TitleConfig(enabled=False, max_words=10)
set_title_config(new_config)
# Verify it was set
assert get_title_config().enabled is False
assert get_title_config().max_words == 10
# Restore original config
set_title_config(original_config)
class TestTitleMiddleware:
"""Tests for TitleMiddleware."""
@@ -68,23 +53,3 @@ class TestTitleMiddleware:
middleware = TitleMiddleware()
assert middleware is not None
assert middleware.state_schema is not None
# TODO: Add integration tests with mock Runtime
# def test_should_generate_title(self):
# """Test title generation trigger logic."""
# pass
# def test_generate_title(self):
# """Test title generation."""
# pass
# def test_after_agent_hook(self):
# """Test after_agent hook."""
# pass
# TODO: Add integration tests
# - Test with real LangGraph runtime
# - Test title persistence with checkpointer
# - Test fallback behavior when LLM fails
# - Test concurrent title generation
@@ -1,38 +1,32 @@
"""Core behavior tests for TitleMiddleware."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.title_config import TitleConfig
def _clone_title_config(config: TitleConfig) -> TitleConfig:
# Avoid mutating shared global config objects across tests.
return TitleConfig(**config.model_dump())
def _make_title_config(**overrides) -> TitleConfig:
return TitleConfig(**overrides)
def _set_test_title_config(**overrides) -> TitleConfig:
config = _clone_title_config(get_title_config())
for key, value in overrides.items():
setattr(config, key, value)
set_title_config(config)
return config
def _make_runtime(**title_overrides) -> SimpleNamespace:
"""Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
return SimpleNamespace(context=ctx)
class TestTitleMiddlewareCoreLogic:
def setup_method(self):
# Title config is a global singleton; snapshot and restore for test isolation.
self._original = _clone_title_config(get_title_config())
def teardown_method(self):
set_title_config(self._original)
def test_should_generate_title_for_first_complete_exchange(self):
_set_test_title_config(enabled=True)
middleware = TitleMiddleware()
state = {
"messages": [
@@ -41,27 +35,24 @@ class TestTitleMiddlewareCoreLogic:
]
}
assert middleware._should_generate_title(state) is True
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
def test_should_not_generate_title_when_disabled_or_already_set(self):
middleware = TitleMiddleware()
_set_test_title_config(enabled=False)
disabled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": None,
}
assert middleware._should_generate_title(disabled_state) is False
assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
_set_test_title_config(enabled=True)
titled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": "Existing Title",
}
assert middleware._should_generate_title(titled_state) is False
assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
def test_should_not_generate_title_after_second_user_turn(self):
_set_test_title_config(enabled=True)
middleware = TitleMiddleware()
state = {
"messages": [
@@ -72,10 +63,9 @@ class TestTitleMiddlewareCoreLogic:
]
}
assert middleware._should_generate_title(state) is False
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
@@ -87,11 +77,13 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=12))))
title = result["title"]
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
from unittest.mock import ANY
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, app_config=ANY)
model.ainvoke.assert_awaited_once()
assert model.ainvoke.await_args.kwargs["config"] == {
"run_name": "title_agent",
@@ -99,7 +91,6 @@ class TestTitleMiddlewareCoreLogic:
}
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
@@ -112,13 +103,12 @@ class TestTitleMiddlewareCoreLogic:
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
title = result["title"]
assert title == "请帮我总结这段代码"
def test_generate_title_fallback_for_long_message(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
@@ -130,7 +120,7 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="收到"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
@@ -141,25 +131,24 @@ class TestTitleMiddlewareCoreLogic:
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
assert result == {"title": "异步标题"}
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
result = middleware.after_model({"messages": []}, runtime=MagicMock())
result = middleware.after_model({"messages": []}, runtime=_make_runtime())
assert result == {"title": "同步标题"}
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
def test_sync_generate_title_uses_fallback_without_model(self):
"""Sync path avoids LLM calls and derives a local fallback title."""
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
state = {
@@ -168,12 +157,11 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state)
result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
assert result == {"title": "请帮我写测试"}
def test_sync_generate_title_respects_fallback_truncation(self):
"""Sync fallback path still respects max_chars truncation rules."""
_set_test_title_config(max_chars=50)
middleware = TitleMiddleware()
state = {
@@ -182,7 +170,7 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")
+1 -2
View File
@@ -154,8 +154,7 @@ class TestStreamUsageIntegration:
"""Test that stream() emits usage_metadata in messages-tuple and end events."""
def _make_client(self):
with patch("deerflow.client.get_app_config", return_value=_mock_app_config()):
return DeerFlowClient()
return DeerFlowClient()
def test_stream_emits_usage_in_messages_tuple(self):
"""messages-tuple AI event should include usage_metadata when present."""
+9 -12
View File
@@ -56,27 +56,25 @@ def _make_minimal_config(tools):
return config
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
def test_no_duplicates_returned(mock_reset, mock_bash):
"""get_available_tools() never returns two tools with the same name."""
mock_cfg.return_value = _make_minimal_config([])
cfg = _make_minimal_config([])
# Patch the builtin tools so we control exactly what comes back.
with patch("deerflow.tools.tools.BUILTIN_TOOLS", [_tool_alpha, _tool_alpha_dup, _tool_beta]):
result = get_available_tools(include_mcp=False)
result = get_available_tools(include_mcp=False, app_config=cfg)
names = [t.name for t in result]
assert len(names) == len(set(names)), f"Duplicate names detected: {names}"
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
def test_first_occurrence_wins(mock_reset, mock_bash):
"""When duplicates exist, the first occurrence is kept."""
mock_cfg.return_value = _make_minimal_config([])
cfg = _make_minimal_config([])
sentinel_alpha = MagicMock(spec=BaseTool, name="_sentinel")
sentinel_alpha.name = _tool_alpha.name # same name
@@ -84,23 +82,22 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
sentinel_alpha_dup.name = _tool_alpha.name # same name — should be dropped
with patch("deerflow.tools.tools.BUILTIN_TOOLS", [sentinel_alpha, sentinel_alpha_dup, _tool_beta]):
result = get_available_tools(include_mcp=False)
result = get_available_tools(include_mcp=False, app_config=cfg)
returned_alpha = next(t for t in result if t.name == _tool_alpha.name)
assert returned_alpha is sentinel_alpha
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
def test_duplicate_triggers_warning(mock_reset, mock_bash, caplog):
"""A warning is logged for every skipped duplicate."""
import logging
mock_cfg.return_value = _make_minimal_config([])
cfg = _make_minimal_config([])
with patch("deerflow.tools.tools.BUILTIN_TOOLS", [_tool_alpha, _tool_alpha_dup]):
with caplog.at_level(logging.WARNING, logger="deerflow.tools.tools"):
get_available_tools(include_mcp=False)
get_available_tools(include_mcp=False, app_config=cfg)
assert any("Duplicate tool name" in r.message for r in caplog.records), "Expected a duplicate-tool warning in log output"
+21 -27
View File
@@ -8,7 +8,7 @@ import pytest
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool as langchain_tool
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
from deerflow.config.tool_search_config import ToolSearchConfig
from deerflow.tools.builtins.tool_search import (
DeferredToolRegistry,
get_deferred_registry,
@@ -64,12 +64,12 @@ class TestToolSearchConfig:
config = ToolSearchConfig(enabled=True)
assert config.enabled is True
def test_load_from_dict(self):
config = load_tool_search_config_from_dict({"enabled": True})
def test_validate_from_dict(self):
config = ToolSearchConfig.model_validate({"enabled": True})
assert config.enabled is True
def test_load_from_empty_dict(self):
config = load_tool_search_config_from_dict({})
def test_validate_from_empty_dict(self):
config = ToolSearchConfig.model_validate({})
assert config.enabled is False
@@ -266,48 +266,42 @@ class TestToolSearchTool:
class TestDeferredToolsPromptSection:
@pytest.fixture(autouse=True)
def _mock_app_config(self, monkeypatch):
@pytest.fixture
def mock_config(self):
"""Provide a minimal AppConfig mock so tests don't need config.yaml."""
from unittest.mock import MagicMock
from deerflow.config.tool_search_config import ToolSearchConfig
mock_config = MagicMock()
mock_config.tool_search = ToolSearchConfig() # disabled by default
monkeypatch.setattr("deerflow.config.get_app_config", lambda: mock_config)
config = MagicMock()
config.tool_search = ToolSearchConfig() # disabled by default
return config
def test_empty_when_disabled(self):
def test_empty_when_disabled(self, mock_config):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
# tool_search.enabled defaults to False
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(mock_config)
assert section == ""
def test_empty_when_enabled_but_no_registry(self, monkeypatch):
def test_empty_when_enabled_but_no_registry(self, mock_config):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
from deerflow.config import get_app_config
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
section = get_deferred_tools_prompt_section()
mock_config.tool_search = ToolSearchConfig(enabled=True)
section = get_deferred_tools_prompt_section(mock_config)
assert section == ""
def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
def test_empty_when_enabled_but_empty_registry(self, mock_config):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
from deerflow.config import get_app_config
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
mock_config.tool_search = ToolSearchConfig(enabled=True)
set_deferred_registry(DeferredToolRegistry())
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(mock_config)
assert section == ""
def test_lists_tool_names(self, registry, monkeypatch):
def test_lists_tool_names(self, registry, mock_config):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
from deerflow.config import get_app_config
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
mock_config.tool_search = ToolSearchConfig(enabled=True)
set_deferred_registry(registry)
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(mock_config)
assert "<available-deferred-tools>" in section
assert "</available-deferred-tools>" in section
assert "github_create_issue" in section
@@ -13,7 +13,10 @@ from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths
from deerflow.config.sandbox_config import SandboxConfig
THREAD_ID = "thread-abc123"
@@ -23,13 +26,20 @@ THREAD_ID = "thread-abc123"
# ---------------------------------------------------------------------------
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _middleware(tmp_path: Path) -> UploadsMiddleware:
return UploadsMiddleware(base_dir=str(tmp_path))
def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock:
rt = MagicMock()
rt.context = {"thread_id": thread_id}
rt.context = _make_context(thread_id or "")
return rt
+1 -2
View File
@@ -10,8 +10,8 @@ from types import SimpleNamespace
import pytest
from deerflow.runtime.user_context import (
DEFAULT_USER_ID,
CurrentUser,
DEFAULT_USER_ID,
get_current_user,
get_effective_user_id,
require_current_user,
@@ -100,7 +100,6 @@ def test_effective_user_id_returns_user_id_when_set():
def test_effective_user_id_coerces_to_str():
"""User.id might be a UUID object; must come back as str."""
import uuid
uid = uuid.uuid4()
user = SimpleNamespace(id=uid)