mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
Merge refactor/config-deerflow-context into release/2.0-rc
Cherry-pick PR #2271's config refactor onto release/2.0-rc. Used 'git merge -X theirs' to auto-resolve content conflicts in favor of the PR's design (frozen AppConfig + explicit-parameter passing). Limitations: - Release-only changes that overlapped with PR's refactor in 119 files are NOT preserved — those files reflect PR's version. Follow-up commits on this branch will need to re-apply release-only modifications where meaningful. - See PR #2271 for design rationale.
This commit is contained in:
@@ -29,6 +29,7 @@ apps with the real middleware — those should not use this module.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec, TypeVar
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -112,7 +113,11 @@ def make_authed_test_app(
|
||||
return app
|
||||
|
||||
|
||||
def call_unwrapped[*P, R](decorated: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
"""Invoke the underlying function of a ``@require_permission``-decorated route.
|
||||
|
||||
``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all
|
||||
|
||||
@@ -68,6 +68,33 @@ def provisioner_module():
|
||||
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _auto_app_config_from_file(monkeypatch, request):
|
||||
"""Replace ``AppConfig.from_file`` with a minimal factory so tests that
|
||||
(directly or indirectly, e.g. via the LangGraph Server bootstrap path in
|
||||
``make_lead_agent``) load AppConfig from disk do not need a real
|
||||
``config.yaml`` on the filesystem.
|
||||
|
||||
Tests that want to verify the real ``from_file`` behaviour should mark
|
||||
themselves with ``@pytest.mark.real_from_file``.
|
||||
"""
|
||||
if request.node.get_closest_marker("real_from_file"):
|
||||
yield
|
||||
return
|
||||
try:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
except ImportError:
|
||||
yield
|
||||
return
|
||||
|
||||
def _fake_from_file(config_path: str | None = None) -> AppConfig: # noqa: ARG001
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
|
||||
monkeypatch.setattr(AppConfig, "from_file", _fake_from_file)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _auto_user_context(request):
|
||||
"""Inject a default ``test-user-autouse`` into the contextvar.
|
||||
|
||||
@@ -2,21 +2,27 @@
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
pytestmark = pytest.mark.real_from_file
|
||||
from pydantic import ValidationError
|
||||
|
||||
from deerflow.config.acp_config import ACPAgentConfig, get_acp_agents, load_acp_config_from_dict
|
||||
from deerflow.config.acp_config import ACPAgentConfig
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def setup_function():
|
||||
"""Reset ACP config before each test."""
|
||||
load_acp_config_from_dict({})
|
||||
def _make_config(acp_agents: dict | None = None) -> AppConfig:
|
||||
return AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
acp_agents={name: ACPAgentConfig(**cfg) for name, cfg in (acp_agents or {}).items()},
|
||||
)
|
||||
|
||||
|
||||
def test_load_acp_config_sets_agents():
|
||||
load_acp_config_from_dict(
|
||||
def test_acp_agents_via_app_config():
|
||||
cfg = _make_config(
|
||||
{
|
||||
"claude_code": {
|
||||
"command": "claude-code-acp",
|
||||
@@ -26,39 +32,33 @@ def test_load_acp_config_sets_agents():
|
||||
}
|
||||
}
|
||||
)
|
||||
agents = get_acp_agents()
|
||||
agents = cfg.acp_agents
|
||||
assert "claude_code" in agents
|
||||
assert agents["claude_code"].command == "claude-code-acp"
|
||||
assert agents["claude_code"].description == "Claude Code for coding tasks"
|
||||
assert agents["claude_code"].model is None
|
||||
|
||||
|
||||
def test_load_acp_config_multiple_agents():
|
||||
load_acp_config_from_dict(
|
||||
def test_multiple_agents():
|
||||
cfg = _make_config(
|
||||
{
|
||||
"claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"},
|
||||
"codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"},
|
||||
}
|
||||
)
|
||||
agents = get_acp_agents()
|
||||
agents = cfg.acp_agents
|
||||
assert len(agents) == 2
|
||||
assert agents["codex"].args == ["--flag"]
|
||||
|
||||
|
||||
def test_load_acp_config_empty_clears_agents():
|
||||
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
|
||||
assert len(get_acp_agents()) == 1
|
||||
|
||||
load_acp_config_from_dict({})
|
||||
assert len(get_acp_agents()) == 0
|
||||
def test_empty_acp_agents():
|
||||
cfg = _make_config({})
|
||||
assert cfg.acp_agents == {}
|
||||
|
||||
|
||||
def test_load_acp_config_none_clears_agents():
|
||||
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
|
||||
assert len(get_acp_agents()) == 1
|
||||
|
||||
load_acp_config_from_dict(None)
|
||||
assert get_acp_agents() == {}
|
||||
def test_default_acp_agents_empty():
|
||||
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
assert cfg.acp_agents == {}
|
||||
|
||||
|
||||
def test_acp_agent_config_defaults():
|
||||
@@ -79,8 +79,8 @@ def test_acp_agent_config_env_default_is_empty():
|
||||
assert cfg.env == {}
|
||||
|
||||
|
||||
def test_load_acp_config_preserves_env():
|
||||
load_acp_config_from_dict(
|
||||
def test_acp_agent_preserves_env():
|
||||
cfg = _make_config(
|
||||
{
|
||||
"codex": {
|
||||
"command": "codex-acp",
|
||||
@@ -90,8 +90,7 @@ def test_load_acp_config_preserves_env():
|
||||
}
|
||||
}
|
||||
)
|
||||
cfg = get_acp_agents()["codex"]
|
||||
assert cfg.env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
|
||||
assert cfg.acp_agents["codex"].env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
|
||||
|
||||
|
||||
def test_acp_agent_config_with_model():
|
||||
@@ -115,13 +114,7 @@ def test_acp_agent_config_missing_description_raises():
|
||||
ACPAgentConfig(command="my-agent")
|
||||
|
||||
|
||||
def test_get_acp_agents_returns_empty_by_default():
|
||||
"""After clearing, should return empty dict."""
|
||||
load_acp_config_from_dict({})
|
||||
assert get_acp_agents() == {}
|
||||
|
||||
|
||||
def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, monkeypatch):
|
||||
def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
|
||||
@@ -157,9 +150,9 @@ def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, mo
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
|
||||
config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8")
|
||||
AppConfig.from_file(str(config_path))
|
||||
assert set(get_acp_agents()) == {"codex"}
|
||||
app = AppConfig.from_file(str(config_path))
|
||||
assert set(app.acp_agents) == {"codex"}
|
||||
|
||||
config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8")
|
||||
AppConfig.from_file(str(config_path))
|
||||
assert get_acp_agents() == {}
|
||||
app = AppConfig.from_file(str(config_path))
|
||||
assert app.acp_agents == {}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from deerflow.config.agents_api_config import get_agents_api_config
|
||||
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
pytestmark = pytest.mark.real_from_file
|
||||
|
||||
|
||||
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
|
||||
@@ -29,149 +30,66 @@ def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> No
|
||||
)
|
||||
|
||||
|
||||
def _write_config_with_agents_api(
|
||||
path: Path,
|
||||
*,
|
||||
model_name: str,
|
||||
supports_thinking: bool,
|
||||
agents_api: dict | None = None,
|
||||
) -> None:
|
||||
config = {
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"models": [
|
||||
{
|
||||
"name": model_name,
|
||||
"use": "langchain_openai:ChatOpenAI",
|
||||
"model": "gpt-test",
|
||||
"supports_thinking": supports_thinking,
|
||||
}
|
||||
],
|
||||
}
|
||||
if agents_api is not None:
|
||||
config["agents_api"] = agents_api
|
||||
|
||||
path.write_text(yaml.safe_dump(config), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_extensions_config(path: Path) -> None:
|
||||
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
|
||||
|
||||
|
||||
def test_app_config_defaults_missing_database_to_sqlite(tmp_path, monkeypatch):
|
||||
def test_from_file_reads_model_name(tmp_path, monkeypatch):
|
||||
"""``AppConfig.from_file`` is the only lifecycle method now; there is no
|
||||
process-global ``init/current``. Each consumer holds its own captured
|
||||
AppConfig instance.
|
||||
"""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config(config_path, model_name="first-model", supports_thinking=False)
|
||||
_write_config(config_path, model_name="test-model", supports_thinking=False)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
|
||||
config = AppConfig.from_file(str(config_path))
|
||||
|
||||
assert config.database.backend == "sqlite"
|
||||
assert config.database.sqlite_dir == ".deer-flow/data"
|
||||
assert config.models[0].name == "test-model"
|
||||
|
||||
|
||||
def test_app_config_defaults_empty_database_to_sqlite(tmp_path, monkeypatch):
|
||||
def test_from_file_each_call_returns_fresh_instance(tmp_path, monkeypatch):
|
||||
"""Two reads of the same file produce separate AppConfig instances —
|
||||
no hidden singleton, no memoization. Callers decide when to re-read.
|
||||
"""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config(config_path, model_name="model-a", supports_thinking=False)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
|
||||
config_a = AppConfig.from_file(str(config_path))
|
||||
assert config_a.models[0].name == "model-a"
|
||||
|
||||
_write_config(config_path, model_name="model-b", supports_thinking=True)
|
||||
config_b = AppConfig.from_file(str(config_path))
|
||||
assert config_b.models[0].name == "model-b"
|
||||
assert config_a is not config_b
|
||||
|
||||
|
||||
def test_config_version_check(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"database": {},
|
||||
"config_version": 1,
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"models": [],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
|
||||
config = AppConfig.from_file(str(config_path))
|
||||
|
||||
assert config.database.backend == "sqlite"
|
||||
assert config.database.sqlite_dir == ".deer-flow/data"
|
||||
|
||||
|
||||
def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config(config_path, model_name="first-model", supports_thinking=False)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
initial = get_app_config()
|
||||
assert initial.models[0].supports_thinking is False
|
||||
|
||||
_write_config(config_path, model_name="first-model", supports_thinking=True)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
reloaded = get_app_config()
|
||||
assert reloaded.models[0].supports_thinking is True
|
||||
assert reloaded is not initial
|
||||
finally:
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
|
||||
config_a = tmp_path / "config-a.yaml"
|
||||
config_b = tmp_path / "config-b.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config(config_a, model_name="model-a", supports_thinking=False)
|
||||
_write_config(config_b, model_name="model-b", supports_thinking=True)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_a))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
first = get_app_config()
|
||||
assert first.models[0].name == "model-a"
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_b))
|
||||
second = get_app_config()
|
||||
assert second.models[0].name == "model-b"
|
||||
assert second is not first
|
||||
finally:
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_agents_api(
|
||||
config_path,
|
||||
model_name="first-model",
|
||||
supports_thinking=False,
|
||||
agents_api={"enabled": True},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
initial = get_app_config()
|
||||
assert initial.models[0].name == "first-model"
|
||||
assert get_agents_api_config().enabled is True
|
||||
|
||||
_write_config_with_agents_api(
|
||||
config_path,
|
||||
model_name="first-model",
|
||||
supports_thinking=False,
|
||||
)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
reloaded = get_app_config()
|
||||
assert reloaded is not initial
|
||||
assert get_agents_api_config().enabled is False
|
||||
finally:
|
||||
reset_app_config()
|
||||
assert config is not None
|
||||
|
||||
@@ -174,20 +174,6 @@ def test_protected_post_no_cookie_returns_401(client):
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_post_with_internal_auth_header_passes():
|
||||
from app.gateway.internal_auth import create_internal_auth_headers
|
||||
|
||||
app = _make_app()
|
||||
client = TestClient(app)
|
||||
|
||||
res = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers=create_internal_auth_headers(),
|
||||
)
|
||||
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -5,25 +5,21 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
|
||||
|
||||
def _make_config(checkpointer: CheckpointerConfig | None = None) -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset singleton state before each test."""
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
yield
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
|
||||
|
||||
@@ -33,24 +29,18 @@ def reset_state():
|
||||
|
||||
|
||||
class TestCheckpointerConfig:
|
||||
def test_load_memory_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
def test_memory_config(self):
|
||||
config = CheckpointerConfig(type="memory")
|
||||
assert config.type == "memory"
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_load_sqlite_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
def test_sqlite_config(self):
|
||||
config = CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")
|
||||
assert config.type == "sqlite"
|
||||
assert config.connection_string == "/tmp/test.db"
|
||||
|
||||
def test_load_postgres_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
def test_postgres_config(self):
|
||||
config = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
|
||||
assert config.type == "postgres"
|
||||
assert config.connection_string == "postgresql://localhost/db"
|
||||
|
||||
@@ -58,14 +48,9 @@ class TestCheckpointerConfig:
|
||||
config = CheckpointerConfig(type="memory")
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_set_config_to_none(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
set_checkpointer_config(None)
|
||||
assert get_checkpointer_config() is None
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
CheckpointerConfig(type="unknown")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -78,58 +63,58 @@ class TestGetCheckpointer:
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
cfg = _make_config()
|
||||
cp = get_checkpointer(cfg)
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_returns_in_memory_saver(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
cp = get_checkpointer()
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
cp = get_checkpointer(cfg)
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
cp1 = get_checkpointer(cfg)
|
||||
cp2 = get_checkpointer(cfg)
|
||||
assert cp1 is cp2
|
||||
|
||||
def test_reset_clears_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
cp1 = get_checkpointer(cfg)
|
||||
reset_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
cp2 = get_checkpointer(cfg)
|
||||
assert cp1 is not cp2
|
||||
|
||||
def test_sqlite_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
|
||||
get_checkpointer()
|
||||
get_checkpointer(cfg)
|
||||
|
||||
def test_postgres_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
|
||||
get_checkpointer()
|
||||
get_checkpointer(cfg)
|
||||
|
||||
def test_postgres_raises_when_connection_string_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres"})
|
||||
cfg = _make_config(CheckpointerConfig(type="postgres"))
|
||||
mock_saver = MagicMock()
|
||||
mock_module = MagicMock()
|
||||
mock_module.PostgresSaver = mock_saver
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ValueError, match="connection_string is required"):
|
||||
get_checkpointer()
|
||||
get_checkpointer(cfg)
|
||||
|
||||
def test_sqlite_creates_saver(self):
|
||||
"""SQLite checkpointer is created when package is available."""
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
@@ -144,7 +129,7 @@ class TestGetCheckpointer:
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
cp = get_checkpointer(cfg)
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once()
|
||||
@@ -225,7 +210,7 @@ class TestGetCheckpointer:
|
||||
|
||||
def test_postgres_creates_saver(self):
|
||||
"""Postgres checkpointer is created when packages are available."""
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
@@ -240,7 +225,7 @@ class TestGetCheckpointer:
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
cp = get_checkpointer(cfg)
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
|
||||
@@ -268,7 +253,6 @@ class TestAsyncCheckpointer:
|
||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||
|
||||
with (
|
||||
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch(
|
||||
@@ -276,7 +260,7 @@ class TestAsyncCheckpointer:
|
||||
return_value="/tmp/resolved/test.db",
|
||||
),
|
||||
):
|
||||
async with make_checkpointer() as saver:
|
||||
async with make_checkpointer(mock_config) as saver:
|
||||
assert saver is mock_saver
|
||||
|
||||
mock_to_thread.assert_awaited_once()
|
||||
@@ -294,12 +278,10 @@ class TestAsyncCheckpointer:
|
||||
|
||||
class TestAppConfigLoadsCheckpointer:
|
||||
def test_load_checkpointer_section(self):
|
||||
"""load_checkpointer_config_from_dict populates the global config."""
|
||||
set_checkpointer_config(None)
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cfg = get_checkpointer_config()
|
||||
assert cfg is not None
|
||||
assert cfg.type == "memory"
|
||||
"""AppConfig with checkpointer section has the correct config."""
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
assert cfg.checkpointer is not None
|
||||
assert cfg.checkpointer.type == "memory"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -309,69 +291,7 @@ class TestAppConfigLoadsCheckpointer:
|
||||
|
||||
class TestClientCheckpointerFallback:
|
||||
def test_client_uses_config_checkpointer_when_none_provided(self):
|
||||
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("deerflow.client.get_app_config", return_value=config_mock),
|
||||
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value=""),
|
||||
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=None)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert "checkpointer" in captured_kwargs
|
||||
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
|
||||
|
||||
def test_client_explicit_checkpointer_takes_precedence(self):
|
||||
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
explicit_cp = MagicMock()
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("deerflow.client.get_app_config", return_value=config_mock),
|
||||
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value=""),
|
||||
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=explicit_cp)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert captured_kwargs["checkpointer"] is explicit_cp
|
||||
"""DeerFlowClient._ensure_agent falls back to get_checkpointer(app_config) when checkpointer=None."""
|
||||
# This is a structural test — verifying the fallback path exists.
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
assert cfg.checkpointer is not None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test for issue #1016: checkpointer should not return None."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
@@ -14,42 +14,38 @@ class TestCheckpointerNoneFix:
|
||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
mock_config.database = None
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
async with make_checkpointer() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
async with make_checkpointer(mock_config) as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
|
||||
# Should be able to call alist() without AttributeError
|
||||
# This is what LangGraph does and what was failing in issue #1016
|
||||
result = []
|
||||
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
|
||||
result.append(item)
|
||||
# Should be able to call alist() without AttributeError
|
||||
# This is what LangGraph does and what was failing in issue #1016
|
||||
result = []
|
||||
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
|
||||
result.append(item)
|
||||
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
|
||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with checkpointer_context() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
with checkpointer_context(mock_config) as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
|
||||
# Should be able to call list() without AttributeError
|
||||
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
|
||||
# Should be able to call list() without AttributeError
|
||||
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
|
||||
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.gateway.routers.models import ModelResponse, ModelsListResponse
|
||||
from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
|
||||
from app.gateway.routers.uploads import UploadResponse
|
||||
from deerflow.client import DeerFlowClient
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.uploads.manager import PathTraversalError
|
||||
|
||||
@@ -44,9 +45,12 @@ def mock_app_config():
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_app_config):
|
||||
"""Create a DeerFlowClient with mocked config loading."""
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
return DeerFlowClient()
|
||||
"""Create a DeerFlowClient holding the mocked config directly.
|
||||
|
||||
Passing ``config=`` is the documented post-refactor way to inject a
|
||||
test AppConfig; nothing relies on process-global state.
|
||||
"""
|
||||
return DeerFlowClient(config=mock_app_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -67,8 +71,7 @@ class TestClientInit:
|
||||
|
||||
def test_custom_params(self, mock_app_config):
|
||||
mock_middleware = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
@@ -78,24 +81,21 @@ class TestClientInit:
|
||||
assert c._middlewares == [mock_middleware]
|
||||
|
||||
def test_invalid_agent_name(self, mock_app_config):
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="invalid name with spaces!")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="../path/traversal")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="invalid name with spaces!")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="../path/traversal")
|
||||
|
||||
def test_custom_config_path(self, mock_app_config):
|
||||
with (
|
||||
patch("deerflow.client.reload_app_config") as mock_reload,
|
||||
patch("deerflow.client.get_app_config", return_value=mock_app_config),
|
||||
):
|
||||
DeerFlowClient(config_path="/tmp/custom.yaml")
|
||||
mock_reload.assert_called_once_with("/tmp/custom.yaml")
|
||||
# rather than touching AppConfig.init() / process-global state.
|
||||
with patch.object(AppConfig, "from_file", return_value=mock_app_config) as mock_from_file:
|
||||
client = DeerFlowClient(config_path="/tmp/custom.yaml")
|
||||
mock_from_file.assert_called_once_with("/tmp/custom.yaml")
|
||||
assert client._app_config is mock_app_config
|
||||
|
||||
def test_checkpointer_stored(self, mock_app_config):
|
||||
cp = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(checkpointer=cp)
|
||||
c = DeerFlowClient(checkpointer=cp)
|
||||
assert c._checkpointer is cp
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class TestConfigQueries:
|
||||
|
||||
with patch("deerflow.skills.loader.load_skills", return_value=[skill]) as mock_load:
|
||||
result = client.list_skills()
|
||||
mock_load.assert_called_once_with(enabled_only=False)
|
||||
mock_load.assert_called_once_with(client._app_config, enabled_only=False)
|
||||
|
||||
assert "skills" in result
|
||||
assert len(result["skills"]) == 1
|
||||
@@ -141,7 +141,7 @@ class TestConfigQueries:
|
||||
def test_list_skills_enabled_only(self, client):
|
||||
with patch("deerflow.skills.loader.load_skills", return_value=[]) as mock_load:
|
||||
client.list_skills(enabled_only=True)
|
||||
mock_load.assert_called_once_with(enabled_only=True)
|
||||
mock_load.assert_called_once_with(client._app_config, enabled_only=True)
|
||||
|
||||
def test_get_memory(self, client):
|
||||
memory = {"version": "1.0", "facts": []}
|
||||
@@ -251,8 +251,8 @@ class TestStream:
|
||||
# Verify context passed to agent.stream
|
||||
agent.stream.assert_called_once()
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert call_kwargs["context"]["thread_id"] == "t1"
|
||||
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
|
||||
ctx = call_kwargs["context"]
|
||||
assert ctx.app_config is client._app_config
|
||||
|
||||
def test_custom_mode_is_normalized_to_string(self, client):
|
||||
"""stream() forwards custom events even when the mode is not a plain string."""
|
||||
@@ -1091,8 +1091,8 @@ class TestMcpConfig:
|
||||
ext_config = MagicMock()
|
||||
ext_config.mcp_servers = {"github": server}
|
||||
|
||||
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
|
||||
result = client.get_mcp_config()
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
result = client.get_mcp_config()
|
||||
|
||||
assert "mcp_servers" in result
|
||||
assert "github" in result["mcp_servers"]
|
||||
@@ -1116,10 +1116,11 @@ class TestMcpConfig:
|
||||
# Pre-set agent to verify it gets invalidated
|
||||
client._agent = MagicMock()
|
||||
|
||||
client._app_config = MagicMock(extensions=current_config)
|
||||
|
||||
with (
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
|
||||
patch("deerflow.client.get_extensions_config", return_value=current_config),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
|
||||
):
|
||||
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
|
||||
|
||||
@@ -1177,12 +1178,12 @@ class TestSkillsManagement:
|
||||
try:
|
||||
# Pre-set agent to verify it gets invalidated
|
||||
client._agent = MagicMock()
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
result = client.update_skill("test-skill", enabled=False)
|
||||
assert result["enabled"] is False
|
||||
@@ -1245,7 +1246,7 @@ class TestMemoryManagement:
|
||||
|
||||
assert mock_import.call_count == 1
|
||||
call_args = mock_import.call_args
|
||||
assert call_args.args == (imported,)
|
||||
assert call_args.args == (client._app_config.memory, imported)
|
||||
assert "user_id" in call_args.kwargs
|
||||
assert result == imported
|
||||
|
||||
@@ -1270,6 +1271,7 @@ class TestMemoryManagement:
|
||||
confidence=0.88,
|
||||
)
|
||||
create_fact.assert_called_once_with(
|
||||
client._app_config.memory,
|
||||
content="User prefers concise code reviews.",
|
||||
category="preference",
|
||||
confidence=0.88,
|
||||
@@ -1280,7 +1282,7 @@ class TestMemoryManagement:
|
||||
data = {"version": "1.0", "facts": []}
|
||||
with patch("deerflow.agents.memory.updater.delete_memory_fact", return_value=data) as delete_fact:
|
||||
result = client.delete_memory_fact("fact_123")
|
||||
delete_fact.assert_called_once_with("fact_123")
|
||||
delete_fact.assert_called_once_with(client._app_config.memory, "fact_123")
|
||||
assert result == data
|
||||
|
||||
def test_update_memory_fact(self, client):
|
||||
@@ -1293,6 +1295,7 @@ class TestMemoryManagement:
|
||||
confidence=0.91,
|
||||
)
|
||||
update_fact.assert_called_once_with(
|
||||
client._app_config.memory,
|
||||
fact_id="fact_123",
|
||||
content="User prefers spaces",
|
||||
category="workflow",
|
||||
@@ -1308,6 +1311,7 @@ class TestMemoryManagement:
|
||||
"User prefers spaces",
|
||||
)
|
||||
update_fact.assert_called_once_with(
|
||||
client._app_config.memory,
|
||||
fact_id="fact_123",
|
||||
content="User prefers spaces",
|
||||
category=None,
|
||||
@@ -1316,37 +1320,40 @@ class TestMemoryManagement:
|
||||
assert result == data
|
||||
|
||||
def test_get_memory_config(self, client):
|
||||
config = MagicMock()
|
||||
config.enabled = True
|
||||
config.storage_path = ".deer-flow/memory.json"
|
||||
config.debounce_seconds = 30
|
||||
config.max_facts = 100
|
||||
config.fact_confidence_threshold = 0.7
|
||||
config.injection_enabled = True
|
||||
config.max_injection_tokens = 2000
|
||||
mem_config = MagicMock()
|
||||
mem_config.enabled = True
|
||||
mem_config.storage_path = ".deer-flow/memory.json"
|
||||
mem_config.debounce_seconds = 30
|
||||
mem_config.max_facts = 100
|
||||
mem_config.fact_confidence_threshold = 0.7
|
||||
mem_config.injection_enabled = True
|
||||
mem_config.max_injection_tokens = 2000
|
||||
|
||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=config):
|
||||
result = client.get_memory_config()
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_config
|
||||
|
||||
client._app_config = app_cfg
|
||||
result = client.get_memory_config()
|
||||
|
||||
assert result["enabled"] is True
|
||||
assert result["max_facts"] == 100
|
||||
|
||||
def test_get_memory_status(self, client):
|
||||
config = MagicMock()
|
||||
config.enabled = True
|
||||
config.storage_path = ".deer-flow/memory.json"
|
||||
config.debounce_seconds = 30
|
||||
config.max_facts = 100
|
||||
config.fact_confidence_threshold = 0.7
|
||||
config.injection_enabled = True
|
||||
config.max_injection_tokens = 2000
|
||||
mem_config = MagicMock()
|
||||
mem_config.enabled = True
|
||||
mem_config.storage_path = ".deer-flow/memory.json"
|
||||
mem_config.debounce_seconds = 30
|
||||
mem_config.max_facts = 100
|
||||
mem_config.fact_confidence_threshold = 0.7
|
||||
mem_config.injection_enabled = True
|
||||
mem_config.max_injection_tokens = 2000
|
||||
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_config
|
||||
data = {"version": "1.0", "facts": []}
|
||||
|
||||
with (
|
||||
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=data),
|
||||
):
|
||||
client._app_config = app_cfg
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=data):
|
||||
result = client.get_memory_status()
|
||||
|
||||
assert "config" in result
|
||||
@@ -1800,10 +1807,10 @@ class TestScenarioConfigManagement:
|
||||
reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
|
||||
|
||||
client._agent = MagicMock() # Simulate existing agent
|
||||
client._app_config = MagicMock(extensions=current_config)
|
||||
with (
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=current_config),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
|
||||
):
|
||||
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
|
||||
assert "my-mcp" in mcp_result["mcp_servers"]
|
||||
@@ -1832,8 +1839,7 @@ class TestScenarioConfigManagement:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
skill_result = client.update_skill("code-gen", enabled=False)
|
||||
assert skill_result["enabled"] is False
|
||||
@@ -2021,10 +2027,10 @@ class TestScenarioMemoryWorkflow:
|
||||
refreshed = client.reload_memory()
|
||||
assert len(refreshed["facts"]) == 2
|
||||
|
||||
with (
|
||||
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data),
|
||||
):
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = config
|
||||
client._app_config = app_cfg
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data):
|
||||
status = client.get_memory_status()
|
||||
assert status["config"]["enabled"] is True
|
||||
assert len(status["data"]["facts"]) == 2
|
||||
@@ -2085,8 +2091,7 @@ class TestScenarioSkillInstallAndUse:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
toggled = client.update_skill("my-analyzer", enabled=False)
|
||||
assert toggled["enabled"] is False
|
||||
@@ -2220,8 +2225,7 @@ class TestGatewayConformance:
|
||||
mock_app_config.models = [model]
|
||||
mock_app_config.token_usage.enabled = True
|
||||
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
client = DeerFlowClient()
|
||||
client = DeerFlowClient(config=mock_app_config)
|
||||
|
||||
result = client.list_models()
|
||||
parsed = ModelsListResponse(**result)
|
||||
@@ -2240,8 +2244,7 @@ class TestGatewayConformance:
|
||||
mock_app_config.models = [model]
|
||||
mock_app_config.get_model_config.return_value = model
|
||||
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
client = DeerFlowClient()
|
||||
client = DeerFlowClient(config=mock_app_config)
|
||||
|
||||
result = client.get_model("test-model")
|
||||
assert result is not None
|
||||
@@ -2310,8 +2313,8 @@ class TestGatewayConformance:
|
||||
ext_config = MagicMock()
|
||||
ext_config.mcp_servers = {"test": server}
|
||||
|
||||
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
|
||||
result = client.get_mcp_config()
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
result = client.get_mcp_config()
|
||||
|
||||
parsed = McpConfigResponse(**result)
|
||||
assert "test" in parsed.mcp_servers
|
||||
@@ -2335,10 +2338,10 @@ class TestGatewayConformance:
|
||||
config_file = tmp_path / "extensions_config.json"
|
||||
config_file.write_text("{}")
|
||||
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
with (
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=ext_config)),
|
||||
):
|
||||
result = client.update_mcp_config({"srv": server.model_dump.return_value})
|
||||
|
||||
@@ -2369,8 +2372,11 @@ class TestGatewayConformance:
|
||||
mem_cfg.injection_enabled = True
|
||||
mem_cfg.max_injection_tokens = 2000
|
||||
|
||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
||||
result = client.get_memory_config()
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_cfg
|
||||
|
||||
client._app_config = app_cfg
|
||||
result = client.get_memory_config()
|
||||
|
||||
parsed = MemoryConfigResponse(**result)
|
||||
assert parsed.enabled is True
|
||||
@@ -2386,6 +2392,8 @@ class TestGatewayConformance:
|
||||
mem_cfg.injection_enabled = True
|
||||
mem_cfg.max_injection_tokens = 2000
|
||||
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_cfg
|
||||
memory_data = {
|
||||
"version": "1.0",
|
||||
"lastUpdated": "",
|
||||
@@ -2402,10 +2410,8 @@ class TestGatewayConformance:
|
||||
"facts": [],
|
||||
}
|
||||
|
||||
with (
|
||||
patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data),
|
||||
):
|
||||
client._app_config = app_cfg
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data):
|
||||
result = client.get_memory_status()
|
||||
|
||||
parsed = MemoryStatusResponse(**result)
|
||||
@@ -2694,8 +2700,7 @@ class TestConfigUpdateErrors:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="disappeared"):
|
||||
client.update_skill("ghost-skill", enabled=False)
|
||||
@@ -3074,10 +3079,10 @@ class TestBugAgentInvalidationInconsistency:
|
||||
config_file = Path(tmp) / "ext.json"
|
||||
config_file.write_text("{}")
|
||||
|
||||
client._app_config = MagicMock(extensions=current_config)
|
||||
with (
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=current_config),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=reloaded),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),
|
||||
):
|
||||
client.update_mcp_config({})
|
||||
|
||||
@@ -3109,8 +3114,7 @@ class TestBugAgentInvalidationInconsistency:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
client.update_skill("s1", enabled=False)
|
||||
|
||||
|
||||
@@ -56,6 +56,10 @@ def _make_e2e_config() -> AppConfig:
|
||||
- ``E2E_BASE_URL`` (default: ``https://ark-cn-beijing.bytedance.net/api/v3``)
|
||||
- ``OPENAI_API_KEY`` (required for LLM tests)
|
||||
"""
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
return AppConfig(
|
||||
models=[
|
||||
ModelConfig(
|
||||
@@ -73,6 +77,9 @@ def _make_e2e_config() -> AppConfig:
|
||||
)
|
||||
],
|
||||
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", allow_host_bash=True),
|
||||
title=TitleConfig(enabled=False),
|
||||
memory=MemoryConfig(enabled=False),
|
||||
summarization=SummarizationConfig(enabled=False),
|
||||
)
|
||||
|
||||
|
||||
@@ -87,7 +94,7 @@ def e2e_env(tmp_path, monkeypatch):
|
||||
|
||||
- DEER_FLOW_HOME → tmp_path (all thread data lands in a temp dir)
|
||||
- Singletons reset so they pick up the new env
|
||||
- Title/memory/summarization disabled to avoid extra LLM calls
|
||||
- Title/memory/summarization disabled via AppConfig fields
|
||||
- AppConfig built programmatically (avoids config.yaml param-name issues)
|
||||
"""
|
||||
# 1. Filesystem isolation
|
||||
@@ -95,30 +102,12 @@ def e2e_env(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("deerflow.config.paths._paths", None)
|
||||
monkeypatch.setattr("deerflow.sandbox.sandbox_provider._default_sandbox_provider", None)
|
||||
|
||||
# 2. Inject a clean AppConfig via the global singleton.
|
||||
config = _make_e2e_config()
|
||||
monkeypatch.setattr("deerflow.config.app_config._app_config", config)
|
||||
monkeypatch.setattr("deerflow.config.app_config._app_config_is_custom", True)
|
||||
# 1b. Override the autouse ``AppConfig.from_file`` stub from conftest
|
||||
# (minimal test config) with the e2e-specific config that carries a
|
||||
# real model entry and disables title/memory/summarization.
|
||||
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda config_path=None: _make_e2e_config()))
|
||||
|
||||
# 3. Disable title generation (extra LLM call, non-deterministic)
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
monkeypatch.setattr("deerflow.config.title_config._title_config", TitleConfig(enabled=False))
|
||||
|
||||
# 4. Disable memory queueing (avoids background threads & file writes)
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.memory_middleware.get_memory_config",
|
||||
lambda: MemoryConfig(enabled=False),
|
||||
)
|
||||
|
||||
# 5. Ensure summarization is off (default, but be explicit)
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
|
||||
monkeypatch.setattr("deerflow.config.summarization_config._summarization_config", SummarizationConfig(enabled=False))
|
||||
|
||||
# 6. Exclude TitleMiddleware from the chain.
|
||||
# 2. Exclude TitleMiddleware from the chain.
|
||||
# It triggers an extra LLM call to generate a thread title, which adds
|
||||
# non-determinism and cost to E2E tests (title generation is already
|
||||
# disabled via TitleConfig above, but the middleware still participates
|
||||
@@ -666,10 +655,9 @@ class TestConfigManagement:
|
||||
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
|
||||
|
||||
# Force reload so the singleton picks up our test file
|
||||
from deerflow.config.extensions_config import reload_extensions_config
|
||||
|
||||
reload_extensions_config()
|
||||
# Mock from_file so update_mcp_config's internal reload works without config.yaml
|
||||
e2e_config = _make_e2e_config()
|
||||
monkeypatch.setattr(AppConfig, "from_file", classmethod(lambda cls, path=None: e2e_config))
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
# Simulate a cached agent
|
||||
@@ -693,9 +681,9 @@ class TestConfigManagement:
|
||||
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
|
||||
|
||||
from deerflow.config.extensions_config import reload_extensions_config
|
||||
|
||||
reload_extensions_config()
|
||||
# Mock from_file so update_skill's internal reload works without config.yaml
|
||||
e2e_config = _make_e2e_config()
|
||||
monkeypatch.setattr(AppConfig, "from_file", classmethod(lambda cls, path=None: e2e_config))
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
c._agent = "fake-agent-placeholder"
|
||||
@@ -721,10 +709,6 @@ class TestConfigManagement:
|
||||
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
|
||||
|
||||
from deerflow.config.extensions_config import reload_extensions_config
|
||||
|
||||
reload_extensions_config()
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
c.update_skill("nonexistent-skill-xyz", enabled=True)
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestLiveStreaming:
|
||||
class TestLiveToolUse:
|
||||
def test_agent_uses_bash_tool(self, client):
|
||||
"""Agent uses bash tool when asked to run a command."""
|
||||
if not is_host_bash_allowed():
|
||||
if not is_host_bash_allowed(client._app_config):
|
||||
pytest.skip("Host bash is disabled for LocalSandboxProvider in the active config")
|
||||
|
||||
events = list(client.stream("Use the bash tool to run: echo 'LIVE_TEST_OK'. Then tell me the output."))
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Multi-client isolation regression test.
|
||||
|
||||
Phase 2 Task P2-3: ``DeerFlowClient`` now captures its ``AppConfig`` in the
|
||||
constructor instead of going through a process-global config.
|
||||
This test pins the resulting invariant: two clients with different configs
|
||||
can coexist without contending over shared state.
|
||||
|
||||
Before P2-3, the shared ``AppConfig._global`` caused the second client's
|
||||
``init()`` to clobber the first client's config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_agent_creation(monkeypatch):
|
||||
"""Prevent lazy agent creation — we only care about config access."""
|
||||
monkeypatch.setattr(DeerFlowClient, "_get_or_create_agent", MagicMock(), raising=False)
|
||||
|
||||
|
||||
def test_two_clients_do_not_clobber_each_other(disable_agent_creation):
|
||||
"""Two clients with distinct configs keep their own AppConfig."""
|
||||
cfg_a = AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
memory=MemoryConfig(enabled=True),
|
||||
)
|
||||
cfg_b = AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
memory=MemoryConfig(enabled=False),
|
||||
)
|
||||
|
||||
client_a = DeerFlowClient(config=cfg_a)
|
||||
client_b = DeerFlowClient(config=cfg_b)
|
||||
|
||||
# Identity: each client retains its own instance, not a shared ref
|
||||
assert client_a._app_config is cfg_a
|
||||
assert client_b._app_config is cfg_b
|
||||
|
||||
# Semantic: memory flag differs
|
||||
assert client_a._app_config.memory.enabled is True
|
||||
assert client_b._app_config.memory.enabled is False
|
||||
|
||||
|
||||
def test_client_config_precedes_path(disable_agent_creation, tmp_path):
|
||||
"""When both config= and config_path= are given, config= wins."""
|
||||
cfg = AppConfig(sandbox=SandboxConfig(use="test"), log_level="debug")
|
||||
|
||||
# config_path points at a file that doesn't exist — proves it's unused
|
||||
bogus_path = str(tmp_path / "nope.yaml")
|
||||
client = DeerFlowClient(config_path=bogus_path, config=cfg)
|
||||
|
||||
assert client._app_config is cfg
|
||||
assert client._app_config.log_level == "debug"
|
||||
|
||||
|
||||
def test_multi_client_gateway_dict_returns_distinct(disable_agent_creation):
|
||||
"""get_mcp_config() reads from self._app_config, not process-global."""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
|
||||
ext_a = ExtensionsConfig(mcp_servers={"server-a": McpServerConfig(enabled=True)})
|
||||
ext_b = ExtensionsConfig(mcp_servers={"server-b": McpServerConfig(enabled=True)})
|
||||
|
||||
cfg_a = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ext_a)
|
||||
cfg_b = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ext_b)
|
||||
|
||||
client_a = DeerFlowClient(config=cfg_a)
|
||||
client_b = DeerFlowClient(config=cfg_b)
|
||||
|
||||
servers_a = client_a.get_mcp_config()["mcp_servers"]
|
||||
servers_b = client_b.get_mcp_config()["mcp_servers"]
|
||||
|
||||
assert set(servers_a.keys()) == {"server-a"}
|
||||
assert set(servers_b.keys()) == {"server-b"}
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Verify that all sub-config Pydantic models are frozen (immutable).
|
||||
|
||||
Frozen models reject attribute assignment after construction, raising
|
||||
pydantic.ValidationError. This test collects every BaseModel subclass
|
||||
defined in the deerflow.config package and asserts that mutation is
|
||||
blocked.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import pkgutil
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import deerflow.config as config_pkg
|
||||
|
||||
|
||||
def _collect_config_models() -> list[type[BaseModel]]:
|
||||
"""Walk deerflow.config.* and return all concrete BaseModel subclasses."""
|
||||
import importlib
|
||||
|
||||
models: list[type[BaseModel]] = []
|
||||
package_path = config_pkg.__path__
|
||||
package_prefix = config_pkg.__name__ + "."
|
||||
|
||||
for _importer, modname, _ispkg in pkgutil.walk_packages(package_path, prefix=package_prefix):
|
||||
try:
|
||||
mod = importlib.import_module(modname)
|
||||
except Exception:
|
||||
continue
|
||||
for _name, obj in inspect.getmembers(mod, inspect.isclass):
|
||||
if (
|
||||
issubclass(obj, BaseModel)
|
||||
and obj is not BaseModel
|
||||
and obj.__module__ == mod.__name__
|
||||
):
|
||||
models.append(obj)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
_EXCLUDED: set[str] = set()
|
||||
|
||||
_ALL_MODELS = [m for m in _collect_config_models() if m.__name__ not in _EXCLUDED]
|
||||
|
||||
# Sanity: make sure we actually collected a meaningful set.
|
||||
assert len(_ALL_MODELS) >= 15, f"Expected at least 15 config models, found {len(_ALL_MODELS)}: {[m.__name__ for m in _ALL_MODELS]}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
|
||||
def test_config_model_is_frozen(model_cls: type[BaseModel]):
|
||||
"""Every sub-config model must have frozen=True in its model_config."""
|
||||
cfg = model_cls.model_config
|
||||
assert cfg.get("frozen") is True, (
|
||||
f"{model_cls.__name__} is not frozen. "
|
||||
f"Add `model_config = ConfigDict(frozen=True)` or add `frozen=True` to the existing ConfigDict."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
|
||||
def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
|
||||
"""Constructing then mutating any field must raise ValidationError."""
|
||||
# Build a minimal instance -- use model_construct to skip validation for
|
||||
# required fields, then pick the first field to try mutating.
|
||||
fields = list(model_cls.model_fields.keys())
|
||||
if not fields:
|
||||
pytest.skip(f"{model_cls.__name__} has no fields")
|
||||
|
||||
instance = model_cls.model_construct()
|
||||
first_field = fields[0]
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
setattr(instance, first_field, "MUTATED")
|
||||
|
||||
|
||||
def test_extensions_nested_dict_mutation_is_not_blocked_by_pydantic():
|
||||
"""Regression guard: Pydantic `frozen=True` does NOT deep-freeze container fields.
|
||||
|
||||
This test documents the trap — callers MUST compose a new dict and persist
|
||||
it + reload AppConfig instead of reaching into `extensions.skills[x]`.
|
||||
If you need the dict to be truly immutable, wrap with Mapping/frozendict.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
|
||||
|
||||
ext = ExtensionsConfig(mcp_servers={}, skills={"a": SkillStateConfig(enabled=True)})
|
||||
|
||||
# This is the pre-refactor anti-pattern: Pydantic lets it through because
|
||||
# the outer model is frozen but the inner dict is a plain builtin. No error.
|
||||
ext.skills["a"] = SkillStateConfig(enabled=False)
|
||||
ext.skills["b"] = SkillStateConfig(enabled=True)
|
||||
|
||||
# The test asserts the leak exists so a future "add deep-freeze" change
|
||||
# flips this expectation and forces call-site review.
|
||||
assert ext.skills["a"].enabled is False
|
||||
assert "b" in ext.skills
|
||||
@@ -9,7 +9,9 @@ import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from deerflow.config.agents_api_config import AgentsApiConfig, get_agents_api_config, set_agents_api_config
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = MemoryConfig()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -329,38 +331,26 @@ class TestMemoryFilePath:
|
||||
def test_global_memory_path(self, tmp_path):
|
||||
"""None agent_name should return global memory file."""
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
|
||||
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
|
||||
):
|
||||
storage = FileMemoryStorage()
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path(None)
|
||||
assert path == tmp_path / "memory.json"
|
||||
|
||||
def test_agent_memory_path(self, tmp_path):
|
||||
"""Providing agent_name should return per-agent memory file."""
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
|
||||
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
|
||||
):
|
||||
storage = FileMemoryStorage()
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path("code-reviewer")
|
||||
assert path == tmp_path / "agents" / "code-reviewer" / "memory.json"
|
||||
|
||||
def test_different_paths_for_different_agents(self, tmp_path):
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
|
||||
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
|
||||
):
|
||||
storage = FileMemoryStorage()
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path_global = storage._get_memory_file_path(None)
|
||||
path_a = storage._get_memory_file_path("agent-a")
|
||||
path_b = storage._get_memory_file_path("agent-b")
|
||||
@@ -380,47 +370,32 @@ def _make_test_app(tmp_path: Path):
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.gateway.routers.agents import router
|
||||
from deerflow.config.agents_api_config import AgentsApiConfig
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
# The agents router gates every route through ``Depends(get_config)`` and
|
||||
# only allows access when ``agents_api.enabled`` is true. Wire a permissive
|
||||
# AppConfig onto ``app.state.config`` so the routes are reachable in tests.
|
||||
app.state.config = AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
agents_api=AgentsApiConfig(enabled=True),
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent_client(tmp_path):
|
||||
"""TestClient with agents router, using tmp_path as base_dir."""
|
||||
import app.gateway.routers.agents as agents_router
|
||||
|
||||
paths_instance = _make_paths(tmp_path)
|
||||
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
|
||||
set_agents_api_config(AgentsApiConfig(enabled=True))
|
||||
try:
|
||||
app = _make_test_app(tmp_path)
|
||||
with TestClient(app) as client:
|
||||
client._tmp_path = tmp_path # type: ignore[attr-defined]
|
||||
yield client
|
||||
finally:
|
||||
set_agents_api_config(previous_config)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def disabled_agent_client(tmp_path):
|
||||
"""TestClient with agents router while the management API is disabled."""
|
||||
import app.gateway.routers.agents as agents_router
|
||||
|
||||
paths_instance = _make_paths(tmp_path)
|
||||
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
|
||||
set_agents_api_config(AgentsApiConfig(enabled=False))
|
||||
try:
|
||||
app = _make_test_app(tmp_path)
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
finally:
|
||||
set_agents_api_config(previous_config)
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch("app.gateway.routers.agents.get_paths", return_value=paths_instance):
|
||||
app = _make_test_app(tmp_path)
|
||||
with TestClient(app) as client:
|
||||
client._tmp_path = tmp_path # type: ignore[attr-defined]
|
||||
yield client
|
||||
|
||||
|
||||
class TestAgentsAPI:
|
||||
@@ -586,37 +561,3 @@ class TestUserProfileAPI:
|
||||
response = agent_client.put("/api/user-profile", json={"content": ""})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["content"] is None
|
||||
|
||||
|
||||
class TestAgentsApiDisabled:
|
||||
def test_agents_list_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.get("/api/agents")
|
||||
assert response.status_code == 403
|
||||
assert "agents_api.enabled=true" in response.json()["detail"]
|
||||
|
||||
def test_agent_get_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.get("/api/agents/example-agent")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_name_check_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.get("/api/agents/check", params={"name": "example-agent"})
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_create_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.post("/api/agents", json={"name": "example-agent", "soul": "blocked"})
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_update_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.put("/api/agents/example-agent", json={"description": "blocked"})
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_delete_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.delete("/api/agents/example-agent")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_user_profile_routes_return_403(self, disabled_agent_client):
|
||||
get_response = disabled_agent_client.get("/api/user-profile")
|
||||
put_response = disabled_agent_client.put("/api/user-profile", json={"content": "blocked"})
|
||||
|
||||
assert get_response.status_code == 403
|
||||
assert put_response.status_code == 403
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Tests for DeerFlowContext and resolve_context()."""
|
||||
|
||||
from dataclasses import FrozenInstanceError
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _make_config(**overrides) -> AppConfig:
|
||||
defaults = {"sandbox": SandboxConfig(use="test")}
|
||||
defaults.update(overrides)
|
||||
return AppConfig(**defaults)
|
||||
|
||||
|
||||
class TestDeerFlowContext:
|
||||
def test_frozen(self):
|
||||
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
ctx.app_config = _make_config()
|
||||
|
||||
def test_fields(self):
|
||||
config = _make_config()
|
||||
ctx = DeerFlowContext(app_config=config, thread_id="t1", agent_name="test-agent")
|
||||
assert ctx.thread_id == "t1"
|
||||
assert ctx.agent_name == "test-agent"
|
||||
assert ctx.app_config is config
|
||||
|
||||
def test_agent_name_default(self):
|
||||
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
|
||||
assert ctx.agent_name is None
|
||||
|
||||
def test_thread_id_required(self):
|
||||
with pytest.raises(TypeError):
|
||||
DeerFlowContext(app_config=_make_config()) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class TestResolveContext:
|
||||
def test_returns_typed_context_directly(self):
|
||||
"""Gateway/Client path: runtime.context is DeerFlowContext → return as-is."""
|
||||
config = _make_config()
|
||||
ctx = DeerFlowContext(app_config=config, thread_id="t1")
|
||||
runtime = MagicMock()
|
||||
runtime.context = ctx
|
||||
assert resolve_context(runtime) is ctx
|
||||
|
||||
def test_raises_on_none_context(self):
|
||||
"""Without a typed DeerFlowContext, resolve_context refuses to guess."""
|
||||
runtime = MagicMock()
|
||||
runtime.context = None
|
||||
with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
|
||||
resolve_context(runtime)
|
||||
|
||||
def test_raises_on_dict_context(self):
|
||||
"""Legacy dict shape is no longer supported — we raise instead of lazily loading AppConfig."""
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
|
||||
with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
|
||||
resolve_context(runtime)
|
||||
@@ -5,20 +5,36 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
|
||||
|
||||
def _runtime_with_config(config):
|
||||
"""Build a runtime carrying a custom (possibly mocked) app_config.
|
||||
|
||||
``DeerFlowContext`` is a frozen dataclass typed as ``AppConfig`` but
|
||||
dataclasses don't enforce the type at runtime — handing a Mock through
|
||||
lets tests exercise the tool's ``get_tool_config`` lookup without going
|
||||
through a process-global config.
|
||||
"""
|
||||
ctx = _P2Ctx.__new__(_P2Ctx)
|
||||
object.__setattr__(ctx, "app_config", config)
|
||||
object.__setattr__(ctx, "thread_id", "test-thread")
|
||||
object.__setattr__(ctx, "agent_name", None)
|
||||
return _P2NS(context=ctx)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_config():
|
||||
"""Mock the app config to return tool configurations."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {
|
||||
"max_results": 5,
|
||||
"search_type": "auto",
|
||||
"contents_max_characters": 1000,
|
||||
"api_key": "test-api-key",
|
||||
}
|
||||
mock_config.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock_config
|
||||
"""Fixture retained as a pass-through: tests inject config via runtime directly."""
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -49,7 +65,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test query"})
|
||||
result = web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert len(parsed) == 2
|
||||
@@ -67,30 +83,30 @@ class TestWebSearchTool:
|
||||
|
||||
def test_search_with_custom_config(self, mock_exa_client):
|
||||
"""Test search respects custom configuration values."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {
|
||||
"max_results": 10,
|
||||
"search_type": "neural",
|
||||
"contents_max_characters": 2000,
|
||||
"api_key": "test-key",
|
||||
}
|
||||
mock_config.return_value.get_tool_config.return_value = tool_config
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {
|
||||
"max_results": 10,
|
||||
"search_type": "neural",
|
||||
"contents_max_characters": 2000,
|
||||
"api_key": "test-key",
|
||||
}
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.return_value = tool_config
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_exa_client.search.return_value = mock_response
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_exa_client.search.return_value = mock_response
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "neural search"})
|
||||
web_search_tool.func(query="neural search", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
mock_exa_client.search.assert_called_once_with(
|
||||
"neural search",
|
||||
type="neural",
|
||||
num_results=10,
|
||||
contents={"highlights": {"max_characters": 2000}},
|
||||
)
|
||||
mock_exa_client.search.assert_called_once_with(
|
||||
"neural search",
|
||||
type="neural",
|
||||
num_results=10,
|
||||
contents={"highlights": {"max_characters": 2000}},
|
||||
)
|
||||
|
||||
def test_search_with_no_highlights(self, mock_app_config, mock_exa_client):
|
||||
"""Test search handles results with no highlights."""
|
||||
@@ -105,7 +121,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
result = web_search_tool.func(query="test", runtime=_P2_RUNTIME)
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed[0]["snippet"] == ""
|
||||
@@ -118,7 +134,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "nothing"})
|
||||
result = web_search_tool.func(query="nothing", runtime=_P2_RUNTIME)
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed == []
|
||||
@@ -129,7 +145,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "error"})
|
||||
result = web_search_tool.func(query="error", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "Error: API rate limit exceeded"
|
||||
|
||||
@@ -147,7 +163,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "# Fetched Page\n\nThis is the page content."
|
||||
mock_exa_client.get_contents.assert_called_once_with(
|
||||
@@ -167,7 +183,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result.startswith("# Untitled\n\n")
|
||||
|
||||
@@ -179,7 +195,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com/404"})
|
||||
result = web_fetch_tool.func(url="https://example.com/404", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "Error: No results found"
|
||||
|
||||
@@ -189,16 +205,44 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "Error: Connection timeout"
|
||||
|
||||
def test_fetch_reads_web_fetch_config(self, mock_exa_client):
|
||||
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
mock_config.return_value.get_tool_config.return_value = tool_config
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.return_value = tool_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.title = "Page"
|
||||
mock_result.text = "Content."
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = [mock_result]
|
||||
mock_exa_client.get_contents.return_value = mock_response
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
fake_config.get_tool_config.assert_any_call("web_fetch")
|
||||
|
||||
def test_fetch_uses_independent_api_key(self, mock_exa_client):
|
||||
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
|
||||
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
|
||||
mock_exa_cls.return_value = mock_exa_client
|
||||
fetch_config = MagicMock()
|
||||
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
|
||||
def get_tool_config(name):
|
||||
if name == "web_fetch":
|
||||
return fetch_config
|
||||
return None
|
||||
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.side_effect = get_tool_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.title = "Page"
|
||||
@@ -209,37 +253,9 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
mock_config.return_value.get_tool_config.assert_any_call("web_fetch")
|
||||
|
||||
def test_fetch_uses_independent_api_key(self, mock_exa_client):
|
||||
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
|
||||
mock_exa_cls.return_value = mock_exa_client
|
||||
fetch_config = MagicMock()
|
||||
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
|
||||
def get_tool_config(name):
|
||||
if name == "web_fetch":
|
||||
return fetch_config
|
||||
return None
|
||||
|
||||
mock_config.return_value.get_tool_config.side_effect = get_tool_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.title = "Page"
|
||||
mock_result.text = "Content."
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = [mock_result]
|
||||
mock_exa_client.get_contents.return_value = mock_response
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
|
||||
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
|
||||
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
|
||||
|
||||
def test_fetch_truncates_long_content(self, mock_app_config, mock_exa_client):
|
||||
"""Test fetch truncates content to 4096 characters."""
|
||||
@@ -253,7 +269,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
# "# Long Page\n\n" is 14 chars, content truncated to 4096
|
||||
content_after_header = result.split("\n\n", 1)[1]
|
||||
|
||||
@@ -3,14 +3,31 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from types import SimpleNamespace as _P2NS
|
||||
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
|
||||
|
||||
def _runtime_with_config(config):
|
||||
ctx = _P2Ctx.__new__(_P2Ctx)
|
||||
object.__setattr__(ctx, "app_config", config)
|
||||
object.__setattr__(ctx, "thread_id", "test-thread")
|
||||
object.__setattr__(ctx, "agent_name", None)
|
||||
return _P2NS(context=ctx)
|
||||
|
||||
|
||||
class TestWebSearchTool:
|
||||
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
|
||||
@patch("deerflow.community.firecrawl.tools.get_app_config")
|
||||
def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
|
||||
def test_search_uses_web_search_config(self, mock_firecrawl_cls):
|
||||
search_config = MagicMock()
|
||||
search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
|
||||
mock_get_app_config.return_value.get_tool_config.return_value = search_config
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.return_value = search_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.web = [
|
||||
@@ -20,7 +37,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.firecrawl.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test query"})
|
||||
result = web_search_tool.func(query="test query", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
assert json.loads(result) == [
|
||||
{
|
||||
@@ -29,15 +46,14 @@ class TestWebSearchTool:
|
||||
"snippet": "Snippet",
|
||||
}
|
||||
]
|
||||
mock_get_app_config.return_value.get_tool_config.assert_called_with("web_search")
|
||||
fake_config.get_tool_config.assert_called_with("web_search")
|
||||
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-search-key")
|
||||
mock_firecrawl_cls.return_value.search.assert_called_once_with("test query", limit=7)
|
||||
|
||||
|
||||
class TestWebFetchTool:
|
||||
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
|
||||
@patch("deerflow.community.firecrawl.tools.get_app_config")
|
||||
def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
|
||||
def test_fetch_uses_web_fetch_config(self, mock_firecrawl_cls):
|
||||
fetch_config = MagicMock()
|
||||
fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}
|
||||
|
||||
@@ -46,7 +62,8 @@ class TestWebFetchTool:
|
||||
return fetch_config
|
||||
return None
|
||||
|
||||
mock_get_app_config.return_value.get_tool_config.side_effect = get_tool_config
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.side_effect = get_tool_config
|
||||
|
||||
mock_scrape_result = MagicMock()
|
||||
mock_scrape_result.markdown = "Fetched markdown"
|
||||
@@ -55,10 +72,10 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.firecrawl.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
assert result == "# Fetched Page\n\nFetched markdown"
|
||||
mock_get_app_config.return_value.get_tool_config.assert_any_call("web_fetch")
|
||||
fake_config.get_tool_config.assert_any_call("web_fetch")
|
||||
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-fetch-key")
|
||||
mock_firecrawl_cls.return_value.scrape.assert_called_once_with(
|
||||
"https://example.com",
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for the FastAPI get_config dependency.
|
||||
|
||||
Phase 2 step 1: introduces the new explicit-config primitive that
|
||||
resolves ``AppConfig`` from ``request.app.state.config``. After migration,
|
||||
it is the sole mechanism.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def test_get_config_returns_app_state_config():
|
||||
"""get_config returns the AppConfig stored on app.state.config."""
|
||||
app = FastAPI()
|
||||
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
app.state.config = cfg
|
||||
|
||||
@app.get("/probe")
|
||||
def probe(c: AppConfig = Depends(get_config)):
|
||||
# Identity check: FastAPI must hand us the exact object from app.state
|
||||
return {"same_identity": c is cfg, "log_level": c.log_level}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/probe")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["same_identity"] is True
|
||||
assert body["log_level"] == "info"
|
||||
|
||||
|
||||
def test_get_config_reads_updated_app_state():
|
||||
"""When app.state.config is swapped (config reload), get_config sees the new value."""
|
||||
app = FastAPI()
|
||||
original = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
|
||||
replacement = original.model_copy(update={"log_level": "debug"})
|
||||
|
||||
app.state.config = original
|
||||
|
||||
@app.get("/log-level")
|
||||
def log_level(c: AppConfig = Depends(get_config)):
|
||||
return {"level": c.log_level}
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.get("/log-level").json() == {"level": "info"}
|
||||
|
||||
# Simulate config reload (PUT /mcp/config, etc.)
|
||||
app.state.config = replacement
|
||||
assert client.get("/log-level").json() == {"level": "debug"}
|
||||
@@ -333,12 +333,14 @@ class TestGuardrailsConfig:
|
||||
assert config.provider.use == "deerflow.guardrails.builtin:AllowlistProvider"
|
||||
assert config.provider.config == {"denied_tools": ["bash"]}
|
||||
|
||||
def test_singleton_load_and_get(self):
|
||||
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict, reset_guardrails_config
|
||||
def test_guardrails_config_via_app_config(self):
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.guardrails_config import GuardrailProviderConfig, GuardrailsConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
try:
|
||||
load_guardrails_config_from_dict({"enabled": True, "provider": {"use": "test:Foo"}})
|
||||
config = get_guardrails_config()
|
||||
assert config.enabled is True
|
||||
finally:
|
||||
reset_guardrails_config()
|
||||
cfg = AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
guardrails=GuardrailsConfig(enabled=True, provider=GuardrailProviderConfig(use="test:Foo")),
|
||||
)
|
||||
config = cfg.guardrails
|
||||
assert config.enabled is True
|
||||
|
||||
@@ -6,6 +6,16 @@ from unittest.mock import MagicMock, patch
|
||||
from deerflow.community.infoquest import tools
|
||||
from deerflow.community.infoquest.infoquest_client import InfoQuestClient
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
class TestInfoQuestClient:
|
||||
def test_infoquest_client_initialization(self):
|
||||
@@ -130,7 +140,7 @@ class TestInfoQuestClient:
|
||||
mock_client.web_search.return_value = json.dumps([])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = tools.web_search_tool.run("test query")
|
||||
result = tools.web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == json.dumps([])
|
||||
mock_get_client.assert_called_once()
|
||||
@@ -143,14 +153,13 @@ class TestInfoQuestClient:
|
||||
mock_client.fetch.return_value = "<html><body>Test content</body></html>"
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = tools.web_fetch_tool.run("https://example.com")
|
||||
result = tools.web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "# Untitled\n\nTest content"
|
||||
mock_get_client.assert_called_once()
|
||||
mock_client.fetch.assert_called_once_with("https://example.com")
|
||||
|
||||
@patch("deerflow.community.infoquest.tools.get_app_config")
|
||||
def test_get_infoquest_client(self, mock_get_app_config):
|
||||
def test_get_infoquest_client(self):
|
||||
"""Test _get_infoquest_client function with config."""
|
||||
mock_config = MagicMock()
|
||||
# Add image_search config to the side_effect
|
||||
@@ -159,9 +168,8 @@ class TestInfoQuestClient:
|
||||
MagicMock(model_extra={"fetch_time": 10, "timeout": 30, "navigation_timeout": 60}), # web_fetch config
|
||||
MagicMock(model_extra={"image_search_time_range": 7, "image_size": "l"}), # image_search config
|
||||
]
|
||||
mock_get_app_config.return_value = mock_config
|
||||
|
||||
client = tools._get_infoquest_client()
|
||||
client = tools._get_infoquest_client(mock_config)
|
||||
|
||||
assert client.search_time_range == 24
|
||||
assert client.fetch_time == 10
|
||||
@@ -321,7 +329,7 @@ class TestImageSearch:
|
||||
mock_client.image_search.return_value = json.dumps([{"image_url": "https://example.com/image1.jpg"}])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = tools.image_search_tool.run({"query": "test query"})
|
||||
result = tools.image_search_tool.func(query="test query", runtime=_P2_RUNTIME)
|
||||
|
||||
# Check if result is a valid JSON string
|
||||
result_data = json.loads(result)
|
||||
@@ -340,7 +348,7 @@ class TestImageSearch:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Pass all parameters as a dictionary (extra parameters will be ignored)
|
||||
tools.image_search_tool.run({"query": "sunset", "time_range": 30, "site": "unsplash.com", "image_size": "l"})
|
||||
tools.image_search_tool.func(query="sunset", runtime=_P2_RUNTIME)
|
||||
|
||||
mock_get_client.assert_called_once()
|
||||
# image_search_tool only passes query to client.image_search
|
||||
|
||||
@@ -6,7 +6,7 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
from deerflow.config.acp_config import ACPAgentConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
from deerflow.tools.builtins.invoke_acp_agent_tool import (
|
||||
_build_acp_mcp_servers,
|
||||
_build_mcp_servers,
|
||||
@@ -18,7 +18,6 @@ from deerflow.tools.tools import get_available_tools
|
||||
|
||||
|
||||
def test_build_mcp_servers_filters_disabled_and_maps_transports():
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
|
||||
fresh_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"]),
|
||||
@@ -40,11 +39,9 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
|
||||
}
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
|
||||
def test_build_acp_mcp_servers_formats_list_payload():
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
|
||||
fresh_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
|
||||
@@ -77,7 +74,6 @@ def test_build_acp_mcp_servers_formats_list_payload():
|
||||
]
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
|
||||
def test_build_permission_response_prefers_allow_once():
|
||||
@@ -669,31 +665,23 @@ async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch,
|
||||
|
||||
|
||||
def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(monkeypatch):
|
||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
|
||||
load_acp_config_from_dict(
|
||||
{
|
||||
"codex": {
|
||||
"command": "codex-acp",
|
||||
"args": [],
|
||||
"description": "Codex CLI",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
tools=[],
|
||||
models=[],
|
||||
tool_search=SimpleNamespace(enabled=False),
|
||||
acp_agents={
|
||||
"codex": ACPAgentConfig(
|
||||
command="codex-acp",
|
||||
args=[],
|
||||
description="Codex CLI",
|
||||
)
|
||||
},
|
||||
get_model_config=lambda name: None,
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: fake_config)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
|
||||
)
|
||||
|
||||
tools = get_available_tools(include_mcp=True, subagent_enabled=False)
|
||||
tools = get_available_tools(include_mcp=True, subagent_enabled=False, app_config=fake_config)
|
||||
assert "invoke_acp_agent" in [tool.name for tool in tools]
|
||||
|
||||
load_acp_config_from_dict({})
|
||||
|
||||
@@ -10,6 +10,16 @@ import deerflow.community.jina_ai.jina_client as jina_client_module
|
||||
from deerflow.community.jina_ai.jina_client import JinaClient
|
||||
from deerflow.community.jina_ai.tools import web_fetch_tool
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jina_client():
|
||||
@@ -176,9 +186,8 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_tool_config.return_value = None
|
||||
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
assert result.startswith("Error:")
|
||||
assert "429" in result
|
||||
|
||||
@@ -192,9 +201,8 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_tool_config.return_value = None
|
||||
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
assert "Hello world" in result
|
||||
assert not result.startswith("Error:")
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
@@ -33,7 +32,7 @@ def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||
def test_resolve_model_name_falls_back_to_default(caplog):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
@@ -41,16 +40,14 @@ def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
resolved = lead_agent_module._resolve_model_name("missing-model")
|
||||
resolved = lead_agent_module._resolve_model_name(app_config, "missing-model")
|
||||
|
||||
assert resolved == "default-model"
|
||||
assert "fallback to default model 'default-model'" in caplog.text
|
||||
|
||||
|
||||
def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||
def test_resolve_model_name_uses_default_when_none():
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
@@ -58,23 +55,19 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
resolved = lead_agent_module._resolve_model_name(None)
|
||||
resolved = lead_agent_module._resolve_model_name(app_config, None)
|
||||
|
||||
assert resolved == "default-model"
|
||||
|
||||
|
||||
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
|
||||
def test_resolve_model_name_raises_when_no_models_configured():
|
||||
app_config = _make_app_config([])
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="No chat models are configured",
|
||||
):
|
||||
lead_agent_module._resolve_model_name("missing-model")
|
||||
lead_agent_module._resolve_model_name(app_config, "missing-model")
|
||||
|
||||
|
||||
def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch):
|
||||
@@ -82,13 +75,12 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
|
||||
import deerflow.tools as tools_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda app_config, config, model_name, agent_name=None: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
@@ -105,7 +97,8 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
"is_plan_mode": False,
|
||||
"subagent_enabled": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
assert captured["name"] == "safe-model"
|
||||
@@ -113,74 +106,6 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
assert result["model"] is not None
|
||||
|
||||
|
||||
def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
_make_model("context-model", supports_thinking=True),
|
||||
]
|
||||
)
|
||||
|
||||
import deerflow.tools as tools_module
|
||||
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
|
||||
result = lead_agent_module.make_lead_agent(
|
||||
{
|
||||
"context": {
|
||||
"model_name": "context-model",
|
||||
"thinking_enabled": False,
|
||||
"reasoning_effort": "high",
|
||||
"is_plan_mode": True,
|
||||
"subagent_enabled": True,
|
||||
"max_concurrent_subagents": 7,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert captured == {
|
||||
"name": "context-model",
|
||||
"thinking_enabled": False,
|
||||
"reasoning_effort": "high",
|
||||
}
|
||||
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
|
||||
assert result["model"] is not None
|
||||
|
||||
|
||||
def test_make_lead_agent_rejects_invalid_bootstrap_agent_name(monkeypatch):
|
||||
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
lead_agent_module.make_lead_agent(
|
||||
{
|
||||
"configurable": {
|
||||
"model_name": "safe-model",
|
||||
"thinking_enabled": False,
|
||||
"is_plan_mode": False,
|
||||
"subagent_enabled": False,
|
||||
"is_bootstrap": True,
|
||||
"agent_name": "../../../tmp/evil",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
@@ -197,11 +122,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda _ac: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
middlewares = lead_agent_module._build_middlewares(app_config, {"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
@@ -209,12 +133,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
lead_agent_module,
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||
app_config = _make_app_config([_make_model("default", supports_thinking=False)])
|
||||
patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -222,16 +144,16 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
fake_model = MagicMock()
|
||||
fake_model.with_config.return_value = fake_model
|
||||
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
return fake_model
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
middleware = lead_agent_module._create_summarization_middleware(patched)
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
|
||||
@@ -4,34 +4,23 @@ from types import SimpleNamespace
|
||||
import anyio
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def _set_skills_cache_state(*, skills=None, active=False, version=0):
|
||||
prompt_module._get_cached_skills_prompt_section.cache_clear()
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = skills
|
||||
prompt_module._enabled_skills_refresh_active = active
|
||||
prompt_module._enabled_skills_refresh_version = version
|
||||
prompt_module._enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts():
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
|
||||
assert prompt_module._build_custom_mounts_section() == ""
|
||||
assert prompt_module._build_custom_mounts_section(config) == ""
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
|
||||
def test_build_custom_mounts_section_lists_configured_mounts():
|
||||
mounts = [
|
||||
SimpleNamespace(container_path="/home/user/shared", read_only=False),
|
||||
SimpleNamespace(container_path="/mnt/reference", read_only=True),
|
||||
]
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
|
||||
section = prompt_module._build_custom_mounts_section()
|
||||
section = prompt_module._build_custom_mounts_section(config)
|
||||
|
||||
assert "**Custom Mounted Directories:**" in section
|
||||
assert "`/home/user/shared`" in section
|
||||
@@ -45,15 +34,15 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
|
||||
config = SimpleNamespace(
|
||||
sandbox=SimpleNamespace(mounts=mounts),
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
|
||||
|
||||
prompt = prompt_module.apply_prompt_template()
|
||||
prompt = prompt_module.apply_prompt_template(config)
|
||||
|
||||
assert "`/home/user/shared`" in prompt
|
||||
assert "Custom Mounted Directories" in prompt
|
||||
@@ -63,15 +52,15 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
|
||||
config = SimpleNamespace(
|
||||
sandbox=SimpleNamespace(mounts=[]),
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
|
||||
|
||||
prompt = prompt_module.apply_prompt_template()
|
||||
prompt = prompt_module.apply_prompt_template(config)
|
||||
|
||||
assert "Treat `/mnt/user-data/workspace` as your default current working directory" in prompt
|
||||
assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt
|
||||
@@ -92,8 +81,8 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
)
|
||||
|
||||
state = {"skills": [make_skill("first-skill")]}
|
||||
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
|
||||
_set_skills_cache_state()
|
||||
monkeypatch.setattr(prompt_module, "load_skills", lambda *a, **kwargs: list(state["skills"]))
|
||||
prompt_module._reset_skills_system_prompt_cache_state()
|
||||
|
||||
try:
|
||||
prompt_module.warm_enabled_skills_cache()
|
||||
@@ -128,7 +117,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
def fake_load_skills(enabled_only=True):
|
||||
def fake_load_skills(*a, **kwargs):
|
||||
nonlocal active_loads, max_active_loads, call_count
|
||||
with lock:
|
||||
active_loads += 1
|
||||
@@ -165,7 +154,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
|
||||
def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
|
||||
event = threading.Event()
|
||||
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda: event)
|
||||
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda *a, **k: event)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
warmed = prompt_module.warm_enabled_skills_cache(timeout_seconds=0.01)
|
||||
|
||||
@@ -19,27 +19,40 @@ def _make_skill(name: str) -> Skill:
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_SKILLS_CONFIG = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
|
||||
|
||||
def _evolution_enabled_config() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True),
|
||||
)
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch):
|
||||
skills = [_make_skill("skill1"), _make_skill("skill2")]
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
|
||||
|
||||
result = get_skills_prompt_section(available_skills={"non_existent_skill"})
|
||||
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"non_existent_skill"})
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch):
|
||||
skills = [_make_skill("skill1"), _make_skill("skill2")]
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
|
||||
|
||||
result = get_skills_prompt_section(available_skills=set())
|
||||
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=set())
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_returns_skills(monkeypatch):
|
||||
skills = [_make_skill("skill1"), _make_skill("skill2")]
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
|
||||
|
||||
result = get_skills_prompt_section(available_skills={"skill1"})
|
||||
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"skill1"})
|
||||
assert "skill1" in result
|
||||
assert "skill2" not in result
|
||||
assert "[built-in]" in result
|
||||
@@ -47,56 +60,41 @@ def test_get_skills_prompt_section_returns_skills(monkeypatch):
|
||||
|
||||
def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch):
|
||||
skills = [_make_skill("skill1"), _make_skill("skill2")]
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
|
||||
|
||||
result = get_skills_prompt_section(available_skills=None)
|
||||
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=None)
|
||||
assert "skill1" in result
|
||||
assert "skill2" in result
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
|
||||
skills = [_make_skill("skill1")]
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.get_app_config",
|
||||
lambda: SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True),
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
|
||||
|
||||
result = get_skills_prompt_section(available_skills=None)
|
||||
result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
|
||||
assert "Skill Self-Evolution" in result
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [])
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.get_app_config",
|
||||
lambda: SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True),
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: [])
|
||||
|
||||
result = get_skills_prompt_section(available_skills=None)
|
||||
result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
|
||||
assert "Skill Self-Evolution" in result
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeypatch):
|
||||
skills = [_make_skill("skill1")]
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
|
||||
config = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
|
||||
config = _evolution_enabled_config()
|
||||
|
||||
enabled_result = get_skills_prompt_section(available_skills=None)
|
||||
enabled_result = get_skills_prompt_section(config, available_skills=None)
|
||||
assert "Skill Self-Evolution" in enabled_result
|
||||
|
||||
config.skill_evolution.enabled = False
|
||||
disabled_result = get_skills_prompt_section(available_skills=None)
|
||||
disabled_config = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
disabled_result = get_skills_prompt_section(disabled_config, available_skills=None)
|
||||
assert "Skill Self-Evolution" not in disabled_result
|
||||
|
||||
|
||||
@@ -106,8 +104,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
# Mock dependencies
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda app_config=None, x=None: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
@@ -118,11 +115,10 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = MockModelConfig()
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
captured_skills = []
|
||||
|
||||
def mock_apply_prompt_template(**kwargs):
|
||||
def mock_apply_prompt_template(_app_config, *args, **kwargs):
|
||||
captured_skills.append(kwargs.get("available_skills"))
|
||||
return "mock_prompt"
|
||||
|
||||
@@ -130,15 +126,15 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
|
||||
# Case 1: Empty skills list
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
|
||||
assert captured_skills[-1] == set()
|
||||
|
||||
# Case 2: None skills list
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
|
||||
assert captured_skills[-1] is None
|
||||
|
||||
# Case 3: Some skills list
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
|
||||
assert captured_skills[-1] == {"skill1"}
|
||||
|
||||
@@ -22,26 +22,26 @@ def _make_config(*, allow_host_bash: bool, sandbox_use: str = "deerflow.sandbox.
|
||||
|
||||
|
||||
def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=False))
|
||||
app_config = _make_config(allow_host_bash=False)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.tools.resolve_variable",
|
||||
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
|
||||
)
|
||||
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=app_config)]
|
||||
|
||||
assert "bash" not in names
|
||||
assert "ls" in names
|
||||
|
||||
|
||||
def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=True))
|
||||
app_config = _make_config(allow_host_bash=True)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.tools.resolve_variable",
|
||||
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
|
||||
)
|
||||
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=app_config)]
|
||||
|
||||
assert "bash" in names
|
||||
assert "ls" in names
|
||||
@@ -52,13 +52,12 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
|
||||
allow_host_bash=False,
|
||||
extra_tools=[SimpleNamespace(name="shell", group="bash", use="deerflow.sandbox.tools:bash_tool")],
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.tools.resolve_variable",
|
||||
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
|
||||
)
|
||||
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=config)]
|
||||
|
||||
assert "bash" not in names
|
||||
assert "shell" not in names
|
||||
@@ -70,13 +69,12 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
|
||||
allow_host_bash=False,
|
||||
sandbox_use="deerflow.community.aio_sandbox:AioSandboxProvider",
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.tools.resolve_variable",
|
||||
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
|
||||
)
|
||||
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=config)]
|
||||
|
||||
assert "bash" in names
|
||||
assert "ls" in names
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import errno
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -314,8 +313,7 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
|
||||
|
||||
@@ -336,8 +334,7 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
@@ -360,8 +357,7 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
@@ -476,7 +472,6 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
|
||||
|
||||
@@ -10,12 +10,22 @@ from deerflow.agents.middlewares.loop_detection_middleware import (
|
||||
LoopDetectionMiddleware,
|
||||
_hash_tool_calls,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime(thread_id="test-thread"):
|
||||
"""Build a minimal Runtime mock with context."""
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
runtime.context = _make_context(thread_id)
|
||||
return runtime
|
||||
|
||||
|
||||
@@ -293,10 +303,10 @@ class TestLoopDetection:
|
||||
assert isinstance(mw._lock, type(mw._lock))
|
||||
|
||||
def test_fallback_thread_id_when_missing(self):
|
||||
"""When runtime context has no thread_id, should use 'default'."""
|
||||
"""When runtime context has empty thread_id, should use 'default'."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2)
|
||||
runtime = MagicMock()
|
||||
runtime.context = {}
|
||||
runtime.context = _make_context("")
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
@@ -3,23 +3,31 @@ import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _memory_config(**overrides: object) -> MemoryConfig:
|
||||
config = MemoryConfig()
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
return config
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
def _make_config(**memory_overrides) -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
with patch.object(queue, "_reset_timer"):
|
||||
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
|
||||
|
||||
@@ -29,7 +37,7 @@ def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
|
||||
|
||||
|
||||
def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
@@ -55,12 +63,9 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
with patch.object(queue, "_reset_timer"):
|
||||
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
|
||||
|
||||
@@ -70,7 +75,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No
|
||||
|
||||
|
||||
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
|
||||
@@ -1,8 +1,30 @@
|
||||
"""Tests for user_id propagation through memory queue."""
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
"""Tests for user_id propagation through memory queue."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_memory(monkeypatch):
|
||||
"""Ensure MemoryUpdateQueue.add() doesn't early-return on disabled memory."""
|
||||
config = MagicMock(spec=AppConfig)
|
||||
config.memory = MemoryConfig(enabled=True)
|
||||
|
||||
|
||||
def test_conversation_context_has_user_id():
|
||||
@@ -16,7 +38,7 @@ def test_conversation_context_user_id_default_none():
|
||||
|
||||
|
||||
def test_queue_add_stores_user_id():
|
||||
q = MemoryUpdateQueue()
|
||||
q = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
with patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
assert len(q._queue) == 1
|
||||
@@ -25,7 +47,7 @@ def test_queue_add_stores_user_id():
|
||||
|
||||
|
||||
def test_queue_process_passes_user_id_to_updater():
|
||||
q = MemoryUpdateQueue()
|
||||
q = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
with patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
|
||||
|
||||
@@ -4,6 +4,18 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import memory
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
_TEST_APP_CONFIG = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
|
||||
"""Build a memory-router app pre-populated with a minimal AppConfig."""
|
||||
app = FastAPI()
|
||||
app.state.config = _TEST_APP_CONFIG
|
||||
app.include_router(memory.router)
|
||||
return app
|
||||
|
||||
|
||||
def _sample_memory(facts: list[dict] | None = None) -> dict:
|
||||
@@ -25,8 +37,7 @@ def _sample_memory(facts: list[dict] | None = None) -> dict:
|
||||
|
||||
|
||||
def test_export_memory_route_returns_current_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
exported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -49,8 +60,7 @@ def test_export_memory_route_returns_current_memory() -> None:
|
||||
|
||||
|
||||
def test_import_memory_route_returns_imported_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
imported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -73,8 +83,7 @@ def test_import_memory_route_returns_imported_memory() -> None:
|
||||
|
||||
|
||||
def test_export_memory_route_preserves_source_error() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
exported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -98,8 +107,7 @@ def test_export_memory_route_preserves_source_error() -> None:
|
||||
|
||||
|
||||
def test_import_memory_route_preserves_source_error() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
imported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -123,8 +131,7 @@ def test_import_memory_route_preserves_source_error() -> None:
|
||||
|
||||
|
||||
def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.clear_memory_data", return_value=_sample_memory()):
|
||||
with TestClient(app) as client:
|
||||
@@ -135,8 +142,7 @@ def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||
|
||||
|
||||
def test_create_memory_fact_route_returns_updated_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -166,8 +172,7 @@ def test_create_memory_fact_route_returns_updated_memory() -> None:
|
||||
|
||||
|
||||
def test_delete_memory_fact_route_returns_updated_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -190,8 +195,7 @@ def test_delete_memory_fact_route_returns_updated_memory() -> None:
|
||||
|
||||
|
||||
def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.delete_memory_fact", side_effect=KeyError("fact_missing")):
|
||||
with TestClient(app) as client:
|
||||
@@ -202,8 +206,7 @@ def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_returns_updated_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -233,8 +236,7 @@ def test_update_memory_fact_route_returns_updated_memory() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -269,8 +271,7 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=KeyError("fact_missing")):
|
||||
with TestClient(app) as client:
|
||||
@@ -288,8 +289,7 @@ def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_returns_specific_error_for_invalid_confidence() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=ValueError("confidence")):
|
||||
with TestClient(app) as client:
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
"""Tests for memory storage providers."""
|
||||
|
||||
import threading
|
||||
@@ -11,7 +23,13 @@ from deerflow.agents.memory.storage import (
|
||||
create_empty_memory,
|
||||
get_memory_storage,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _app_config(**memory_overrides) -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
|
||||
|
||||
|
||||
class TestCreateEmptyMemory:
|
||||
@@ -53,10 +71,9 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
path = storage._get_memory_file_path(None)
|
||||
assert path == tmp_path / "memory.json"
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path(None)
|
||||
assert path == tmp_path / "memory.json"
|
||||
|
||||
def test_get_memory_file_path_agent(self, tmp_path):
|
||||
"""Should return per-agent memory file path when agent_name is provided."""
|
||||
@@ -67,14 +84,14 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
storage = FileMemoryStorage()
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path("test-agent")
|
||||
assert path == tmp_path / "agents" / "test-agent" / "memory.json"
|
||||
|
||||
@pytest.mark.parametrize("invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"])
|
||||
def test_validate_agent_name_invalid(self, invalid_name):
|
||||
"""Should raise ValueError for invalid agent names that don't match the pattern."""
|
||||
storage = FileMemoryStorage()
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
with pytest.raises(ValueError, match="Invalid agent name|Agent name must be a non-empty string"):
|
||||
storage._validate_agent_name(invalid_name)
|
||||
|
||||
@@ -87,11 +104,10 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
memory = storage.load()
|
||||
assert isinstance(memory, dict)
|
||||
assert memory["version"] == "1.0"
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = storage.load()
|
||||
assert isinstance(memory, dict)
|
||||
assert memory["version"] == "1.0"
|
||||
|
||||
def test_save_writes_to_file(self, tmp_path):
|
||||
"""Should save memory data to file."""
|
||||
@@ -103,12 +119,11 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
|
||||
result = storage.save(test_memory)
|
||||
assert result is True
|
||||
assert memory_file.exists()
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
|
||||
result = storage.save(test_memory)
|
||||
assert result is True
|
||||
assert memory_file.exists()
|
||||
|
||||
def test_save_does_not_mutate_caller_dict(self, tmp_path):
|
||||
"""save() must not mutate the caller's dict (lastUpdated side-effect)."""
|
||||
@@ -209,18 +224,17 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
# First load
|
||||
memory1 = storage.load()
|
||||
assert memory1["facts"][0]["content"] == "initial fact"
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
# First load
|
||||
memory1 = storage.load()
|
||||
assert memory1["facts"][0]["content"] == "initial fact"
|
||||
|
||||
# Update file directly
|
||||
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
|
||||
# Update file directly
|
||||
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
|
||||
|
||||
# Reload should get updated data
|
||||
memory2 = storage.reload()
|
||||
assert memory2["facts"][0]["content"] == "updated fact"
|
||||
# Reload should get updated data
|
||||
memory2 = storage.reload()
|
||||
assert memory2["facts"][0]["content"] == "updated fact"
|
||||
|
||||
|
||||
class TestGetMemoryStorage:
|
||||
@@ -237,22 +251,19 @@ class TestGetMemoryStorage:
|
||||
|
||||
def test_returns_file_memory_storage_by_default(self):
|
||||
"""Should return FileMemoryStorage by default."""
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_falls_back_to_file_memory_storage_on_error(self):
|
||||
"""Should fall back to FileMemoryStorage if configured storage fails to load."""
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="non.existent.StorageClass")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_returns_singleton_instance(self):
|
||||
"""Should return the same instance on subsequent calls."""
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
storage1 = get_memory_storage()
|
||||
storage2 = get_memory_storage()
|
||||
assert storage1 is storage2
|
||||
storage1 = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
storage2 = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert storage1 is storage2
|
||||
|
||||
def test_get_memory_storage_thread_safety(self):
|
||||
"""Should safely initialize the singleton even with concurrent calls."""
|
||||
@@ -260,16 +271,15 @@ class TestGetMemoryStorage:
|
||||
|
||||
def get_storage():
|
||||
# get_memory_storage is called concurrently from multiple threads while
|
||||
# get_memory_config is patched once around thread creation. This verifies
|
||||
# AppConfig.get is patched once around thread creation. This verifies
|
||||
# that the singleton initialization remains thread-safe.
|
||||
results.append(get_memory_storage())
|
||||
results.append(get_memory_storage(_TEST_MEMORY_CONFIG))
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
threads = [threading.Thread(target=get_storage) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
threads = [threading.Thread(target=get_storage) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# All results should be the exact same instance
|
||||
assert len(results) == 10
|
||||
@@ -278,13 +288,11 @@ class TestGetMemoryStorage:
|
||||
def test_get_memory_storage_invalid_class_fallback(self):
|
||||
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
|
||||
# Using a built-in function instead of a class
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="os.path.join")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_get_memory_storage_non_subclass_fallback(self):
|
||||
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
|
||||
# Using 'dict' as a class that is not a MemoryStorage subclass
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="builtins.dict")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
@@ -1,11 +1,29 @@
|
||||
"""Tests for per-user memory storage isolation."""
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
"""Tests for per-user memory storage isolation."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _mock_app_config() -> AppConfig:
|
||||
"""Build a minimal AppConfig with default (empty) memory storage_path."""
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(storage_path=""))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -15,7 +33,9 @@ def base_dir(tmp_path: Path) -> Path:
|
||||
|
||||
@pytest.fixture
|
||||
def storage() -> FileMemoryStorage:
|
||||
return FileMemoryStorage()
|
||||
return FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
|
||||
|
||||
|
||||
|
||||
class TestUserIsolatedStorage:
|
||||
@@ -43,7 +63,7 @@ class TestUserIsolatedStorage:
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id="alice")
|
||||
expected_path = base_dir / "users" / "alice" / "memory.json"
|
||||
@@ -54,7 +74,7 @@ class TestUserIsolatedStorage:
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory_a = create_empty_memory()
|
||||
memory_a["user"]["workContext"]["summary"] = "A"
|
||||
s.save(memory_a, user_id="alice")
|
||||
@@ -67,38 +87,34 @@ class TestUserIsolatedStorage:
|
||||
assert loaded_a["user"]["workContext"]["summary"] == "A"
|
||||
|
||||
def test_no_user_id_uses_legacy_path(self, base_dir: Path):
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id=None)
|
||||
expected_path = base_dir / "memory.json"
|
||||
assert expected_path.exists()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id=None)
|
||||
expected_path = base_dir / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
|
||||
"""user_id=None (legacy) and user_id='alice' must use different files and caches."""
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
s = FileMemoryStorage()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
|
||||
legacy_mem = create_empty_memory()
|
||||
legacy_mem["user"]["workContext"]["summary"] = "legacy"
|
||||
s.save(legacy_mem, user_id=None)
|
||||
legacy_mem = create_empty_memory()
|
||||
legacy_mem["user"]["workContext"]["summary"] = "legacy"
|
||||
s.save(legacy_mem, user_id=None)
|
||||
|
||||
user_mem = create_empty_memory()
|
||||
user_mem["user"]["workContext"]["summary"] = "alice"
|
||||
s.save(user_mem, user_id="alice")
|
||||
user_mem = create_empty_memory()
|
||||
user_mem["user"]["workContext"]["summary"] = "alice"
|
||||
s.save(user_mem, user_id="alice")
|
||||
|
||||
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
|
||||
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
|
||||
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
|
||||
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
|
||||
|
||||
def test_user_agent_memory_file_location(self, base_dir: Path):
|
||||
"""Per-user per-agent memory uses the user_agent_memory_file path."""
|
||||
@@ -106,7 +122,7 @@ class TestUserIsolatedStorage:
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = create_empty_memory()
|
||||
memory["user"]["workContext"]["summary"] = "agent scoped"
|
||||
s.save(memory, "test-agent", user_id="alice")
|
||||
@@ -119,7 +135,7 @@ class TestUserIsolatedStorage:
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id="alice")
|
||||
# After save, cache should have tuple key
|
||||
@@ -131,7 +147,7 @@ class TestUserIsolatedStorage:
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = create_empty_memory()
|
||||
memory["user"]["workContext"]["summary"] = "initial"
|
||||
s.save(memory, user_id="alice")
|
||||
|
||||
@@ -6,6 +6,17 @@ the in-memory LangGraph Store backend used when database.backend=memory.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -1,22 +1,32 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||
from deerflow.agents.memory.updater import (
|
||||
MemoryUpdater,
|
||||
_extract_text,
|
||||
_run_async_update_sync,
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
import_memory_data,
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
|
||||
return {
|
||||
"version": "1.0",
|
||||
@@ -35,15 +45,12 @@ def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, obje
|
||||
}
|
||||
|
||||
|
||||
def _memory_config(**overrides: object) -> MemoryConfig:
|
||||
config = MemoryConfig()
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
return config
|
||||
def _memory_config(**overrides: object) -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig().model_copy(update=overrides))
|
||||
|
||||
|
||||
def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -70,19 +77,14 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
|
||||
{"content": "User likes Python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert [fact["content"] for fact in result["facts"]] == ["User likes Python"]
|
||||
assert all(fact["id"] != "fact_remove" for fact in result["facts"])
|
||||
|
||||
|
||||
def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
@@ -91,12 +93,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
|
||||
{"content": "User works on DeerFlow", "category": "context", "confidence": 0.87},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
|
||||
|
||||
assert [fact["content"] for fact in result["facts"]] == [
|
||||
"User prefers dark mode",
|
||||
@@ -107,7 +104,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
|
||||
|
||||
|
||||
def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=2, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -135,12 +132,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
{"content": "User likes noisy logs", "category": "behavior", "confidence": 0.6},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
|
||||
|
||||
assert [fact["content"] for fact in result["facts"]] == [
|
||||
"User likes Python",
|
||||
@@ -151,7 +143,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
|
||||
|
||||
def test_apply_updates_preserves_source_error() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
@@ -163,19 +155,14 @@ def test_apply_updates_preserves_source_error() -> None:
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
|
||||
assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
|
||||
assert result["facts"][0]["category"] == "correction"
|
||||
|
||||
|
||||
def test_apply_updates_ignores_empty_source_error() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
@@ -187,19 +174,14 @@ def test_apply_updates_ignores_empty_source_error() -> None:
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
|
||||
assert "sourceError" not in result["facts"][0]
|
||||
|
||||
|
||||
def test_clear_memory_data_resets_all_sections() -> None:
|
||||
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
|
||||
result = clear_memory_data()
|
||||
result = clear_memory_data(_TEST_MEMORY_CONFIG)
|
||||
|
||||
assert result["version"] == "1.0"
|
||||
assert result["facts"] == []
|
||||
@@ -233,7 +215,7 @@ def test_delete_memory_fact_removes_only_matching_fact() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = delete_memory_fact("fact_delete")
|
||||
result = delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_delete")
|
||||
|
||||
assert [fact["id"] for fact in result["facts"]] == ["fact_keep"]
|
||||
|
||||
@@ -243,7 +225,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = create_memory_fact(
|
||||
result = create_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
content=" User prefers concise code reviews. ",
|
||||
category="preference",
|
||||
confidence=0.88,
|
||||
@@ -258,7 +240,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
|
||||
|
||||
def test_create_memory_fact_rejects_empty_content() -> None:
|
||||
try:
|
||||
create_memory_fact(content=" ")
|
||||
create_memory_fact(_TEST_MEMORY_CONFIG, content=" ")
|
||||
except ValueError as exc:
|
||||
assert exc.args == ("content",)
|
||||
else:
|
||||
@@ -268,7 +250,7 @@ def test_create_memory_fact_rejects_empty_content() -> None:
|
||||
def test_create_memory_fact_rejects_invalid_confidence() -> None:
|
||||
for confidence in (-0.1, 1.1, float("nan"), float("inf"), float("-inf")):
|
||||
try:
|
||||
create_memory_fact(content="User likes tests", confidence=confidence)
|
||||
create_memory_fact(_TEST_MEMORY_CONFIG, content="User likes tests", confidence=confidence)
|
||||
except ValueError as exc:
|
||||
assert exc.args == ("confidence",)
|
||||
else:
|
||||
@@ -278,7 +260,7 @@ def test_create_memory_fact_rejects_invalid_confidence() -> None:
|
||||
def test_delete_memory_fact_raises_for_unknown_id() -> None:
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
|
||||
try:
|
||||
delete_memory_fact("fact_missing")
|
||||
delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_missing")
|
||||
except KeyError as exc:
|
||||
assert exc.args == ("fact_missing",)
|
||||
else:
|
||||
@@ -303,7 +285,7 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
||||
mock_storage.load.return_value = imported_memory
|
||||
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
result = import_memory_data(imported_memory)
|
||||
result = import_memory_data(_TEST_MEMORY_CONFIG, imported_memory)
|
||||
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
||||
mock_storage.load.assert_called_once_with(None, user_id=None)
|
||||
@@ -336,7 +318,7 @@ def test_update_memory_fact_updates_only_matching_fact() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = update_memory_fact(
|
||||
result = update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
category="workflow",
|
||||
@@ -369,7 +351,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = update_memory_fact(
|
||||
result = update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
)
|
||||
@@ -382,7 +364,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
|
||||
def test_update_memory_fact_raises_for_unknown_id() -> None:
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
|
||||
try:
|
||||
update_memory_fact(
|
||||
update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_missing",
|
||||
content="User prefers concise code reviews.",
|
||||
category="preference",
|
||||
@@ -414,7 +396,7 @@ def test_update_memory_fact_rejects_invalid_confidence() -> None:
|
||||
return_value=current_memory,
|
||||
):
|
||||
try:
|
||||
update_memory_fact(
|
||||
update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
confidence=confidence,
|
||||
@@ -527,17 +509,15 @@ class TestUpdateMemoryStructuredResponse:
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = content
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_string_response_parses(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -551,17 +531,15 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_list_content_response_parses(self):
|
||||
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
list_content = [{"type": "text", "text": valid_json}]
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -576,38 +554,13 @@ class TestUpdateMemoryStructuredResponse:
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_async_update_memory_uses_ainvoke(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi there"
|
||||
ai_msg.tool_calls = []
|
||||
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"}
|
||||
|
||||
def test_correction_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -622,17 +575,16 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
|
||||
def test_correction_hint_empty_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -647,95 +599,15 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
def test_sync_update_memory_wrapper_works_in_running_loop(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello from loop"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
async def run_in_loop():
|
||||
return updater.update_memory([msg, ai_msg])
|
||||
|
||||
result = asyncio.run(run_in_loop())
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
|
||||
updater = MemoryUpdater()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
|
||||
side_effect=RuntimeError("executor down"),
|
||||
),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello from loop"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
async def run_in_loop():
|
||||
return updater.update_memory([msg, ai_msg])
|
||||
|
||||
result = asyncio.run(run_in_loop())
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRunAsyncUpdateSync:
|
||||
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
|
||||
class CloseableAwaitable:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def __await__(self):
|
||||
pytest.fail("awaitable should not have been awaited")
|
||||
yield
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
awaitable = CloseableAwaitable()
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
|
||||
side_effect=RuntimeError("executor down"),
|
||||
):
|
||||
|
||||
async def run_in_loop():
|
||||
return _run_async_update_sync(awaitable)
|
||||
|
||||
result = asyncio.run(run_in_loop())
|
||||
|
||||
assert result is False
|
||||
assert awaitable.closed is True
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
"""Tests that fact deduplication is case-insensitive."""
|
||||
|
||||
def test_duplicate_fact_different_case_not_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -755,19 +627,14 @@ class TestFactDeduplicationCaseInsensitive:
|
||||
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
# Should still have only 1 fact (duplicate rejected)
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["content"] == "User prefers Python"
|
||||
|
||||
def test_unique_fact_different_case_and_content_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -786,12 +653,7 @@ class TestFactDeduplicationCaseInsensitive:
|
||||
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert len(result["facts"]) == 2
|
||||
|
||||
@@ -804,17 +666,16 @@ class TestReinforcementHint:
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -829,17 +690,16 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -854,17 +714,16 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -879,56 +738,6 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
|
||||
class TestFinalizeCacheIsolation:
|
||||
"""_finalize_update must not mutate the cached memory object."""
|
||||
|
||||
def test_deepcopy_prevents_cache_corruption_on_save_failure(self):
|
||||
"""If save() fails, the in-memory snapshot used by _finalize_update
|
||||
must remain independent of any object the storage layer may still hold in
|
||||
its cache. The deepcopy in _finalize_update achieves this — the object
|
||||
passed to _apply_updates is always a fresh copy, never the cache reference.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
original_memory = _make_memory(facts=[{"id": "fact_orig", "content": "original", "category": "context", "confidence": 0.9, "createdAt": "2024-01-01T00:00:00Z", "source": "t1"}])
|
||||
|
||||
import json as _json
|
||||
|
||||
new_fact_json = _json.dumps(
|
||||
{
|
||||
"user": {},
|
||||
"history": {},
|
||||
"newFacts": [{"content": "new fact", "category": "context", "confidence": 0.9}],
|
||||
"factsToRemove": [],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = new_fact_json
|
||||
mock_model = AsyncMock()
|
||||
mock_model.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
saved_objects: list[dict] = []
|
||||
save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=mock_model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=original_memory),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=save_mock)),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "world"
|
||||
ai_msg.tool_calls = []
|
||||
updater.update_memory([msg, ai_msg], thread_id="t1")
|
||||
|
||||
# original_memory must not have been mutated — deepcopy isolates the mutation
|
||||
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
|
||||
assert original_memory["facts"][0]["content"] == "original"
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
"""Tests for user_id propagation in memory updater."""
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
"""Tests for user_id propagation in memory updater."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.updater import _save_memory_to_file, clear_memory_data, get_memory_data
|
||||
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
|
||||
|
||||
|
||||
def test_get_memory_data_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.load.return_value = {"version": "1.0"}
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
get_memory_data(user_id="alice")
|
||||
get_memory_data(_TEST_MEMORY_CONFIG, user_id="alice")
|
||||
mock_storage.load.assert_called_once_with(None, user_id="alice")
|
||||
|
||||
|
||||
@@ -17,7 +28,7 @@ def test_save_memory_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
_save_memory_to_file({"version": "1.0"}, user_id="bob")
|
||||
_save_memory_to_file(_TEST_MEMORY_CONFIG, {"version": "1.0"}, user_id="bob")
|
||||
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
||||
|
||||
|
||||
@@ -25,6 +36,6 @@ def test_clear_memory_data_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
clear_memory_data(user_id="charlie")
|
||||
clear_memory_data(_TEST_MEMORY_CONFIG, user_id="charlie")
|
||||
# Verify save was called with user_id
|
||||
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""Tests for per-user data migration."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
@@ -25,7 +23,6 @@ class TestMigrateThreadDirs:
|
||||
(legacy / "file.txt").write_text("hello")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
|
||||
expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
|
||||
@@ -38,7 +35,6 @@ class TestMigrateThreadDirs:
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
|
||||
migrate_thread_dirs(paths, thread_owner_map={})
|
||||
|
||||
expected = base_dir / "users" / "default" / "threads" / "t2"
|
||||
@@ -49,7 +45,6 @@ class TestMigrateThreadDirs:
|
||||
new_dir.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
assert new_dir.exists()
|
||||
|
||||
@@ -63,7 +58,6 @@ class TestMigrateThreadDirs:
|
||||
(dest / "new.txt").write_text("new")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
|
||||
assert (dest / "new.txt").read_text() == "new"
|
||||
@@ -75,7 +69,6 @@ class TestMigrateThreadDirs:
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
|
||||
migrate_thread_dirs(paths, thread_owner_map={})
|
||||
|
||||
assert not (base_dir / "threads").exists()
|
||||
@@ -85,7 +78,6 @@ class TestMigrateThreadDirs:
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
|
||||
report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
|
||||
|
||||
assert len(report) == 1
|
||||
@@ -99,7 +91,6 @@ class TestMigrateMemory:
|
||||
legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
|
||||
migrate_memory(paths, user_id="default")
|
||||
|
||||
expected = base_dir / "users" / "default" / "memory.json"
|
||||
@@ -115,7 +106,6 @@ class TestMigrateMemory:
|
||||
dest.write_text(json.dumps({"version": "new"}))
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
|
||||
migrate_memory(paths, user_id="default")
|
||||
|
||||
assert json.loads(dest.read_text())["version"] == "new"
|
||||
@@ -123,5 +113,4 @@ class TestMigrateMemory:
|
||||
|
||||
def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
|
||||
migrate_memory(paths, user_id="default") # should not raise
|
||||
|
||||
@@ -72,8 +72,7 @@ class FakeChatModel(BaseChatModel):
|
||||
|
||||
|
||||
def _patch_factory(monkeypatch, app_config: AppConfig, model_class=FakeChatModel):
|
||||
"""Patch get_app_config, resolve_class, and tracing for isolated unit tests."""
|
||||
monkeypatch.setattr(factory_module, "get_app_config", lambda: app_config)
|
||||
"""Patch resolve_class and tracing for isolated unit tests."""
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: model_class)
|
||||
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
|
||||
|
||||
@@ -88,7 +87,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name=None)
|
||||
factory_module.create_chat_model(name=None, app_config=cfg)
|
||||
|
||||
# resolve_class is called — if we reach here without ValueError, the correct model was used
|
||||
assert FakeChatModel.captured_kwargs.get("model") == "alpha"
|
||||
@@ -96,11 +95,10 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
|
||||
|
||||
def test_raises_when_model_not_found(monkeypatch):
|
||||
cfg = _make_app_config([_make_model("only-model")])
|
||||
monkeypatch.setattr(factory_module, "get_app_config", lambda: cfg)
|
||||
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
|
||||
|
||||
with pytest.raises(ValueError, match="ghost-model"):
|
||||
factory_module.create_chat_model(name="ghost-model")
|
||||
factory_module.create_chat_model(name="ghost-model", app_config=cfg)
|
||||
|
||||
|
||||
def test_appends_all_tracing_callbacks(monkeypatch):
|
||||
@@ -109,7 +107,7 @@ def test_appends_all_tracing_callbacks(monkeypatch):
|
||||
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: ["smith-callback", "langfuse-callback"])
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
model = factory_module.create_chat_model(name="alpha")
|
||||
model = factory_module.create_chat_model(name="alpha", app_config=cfg)
|
||||
|
||||
assert model.callbacks == ["smith-callback", "langfuse-callback"]
|
||||
|
||||
@@ -127,7 +125,7 @@ def test_thinking_enabled_raises_when_not_supported_but_when_thinking_enabled_is
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
with pytest.raises(ValueError, match="does not support thinking"):
|
||||
factory_module.create_chat_model(name="no-think", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="no-think", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
|
||||
def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(monkeypatch):
|
||||
@@ -138,7 +136,7 @@ def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
with pytest.raises(ValueError, match="does not support thinking"):
|
||||
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
|
||||
def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
|
||||
@@ -147,7 +145,7 @@ def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="thinker", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="thinker", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("temperature") == 1.0
|
||||
assert FakeChatModel.captured_kwargs.get("max_tokens") == 16000
|
||||
@@ -183,7 +181,7 @@ def test_thinking_disabled_openai_gateway_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
assert captured.get("reasoning_effort") == "minimal"
|
||||
@@ -216,7 +214,7 @@ def test_thinking_disabled_langchain_anthropic_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
assert "extra_body" not in captured
|
||||
@@ -238,7 +236,7 @@ def test_thinking_disabled_no_when_thinking_enabled_does_nothing(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="plain", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="plain", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert "extra_body" not in captured
|
||||
assert "thinking" not in captured
|
||||
@@ -278,7 +276,7 @@ def test_when_thinking_disabled_takes_precedence_over_hardcoded_disable(monkeypa
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
# User overrode the hardcoded "minimal" with "low"
|
||||
@@ -310,7 +308,7 @@ def test_when_thinking_disabled_not_used_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
# when_thinking_enabled should apply, NOT when_thinking_disabled
|
||||
assert captured.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
@@ -339,7 +337,7 @@ def test_when_thinking_disabled_without_when_thinking_enabled_still_applies(monk
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# when_thinking_disabled is now gated independently of has_thinking_settings
|
||||
assert captured.get("reasoning_effort") == "low"
|
||||
@@ -370,7 +368,7 @@ def test_when_thinking_disabled_excluded_from_model_dump(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
# when_thinking_disabled value must NOT appear as a raw key
|
||||
assert "when_thinking_disabled" not in captured
|
||||
@@ -394,7 +392,7 @@ def test_reasoning_effort_cleared_when_not_supported(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-effort", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="no-effort", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("reasoning_effort") is None
|
||||
|
||||
@@ -422,7 +420,7 @@ def test_reasoning_effort_preserved_when_supported(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="effort-model", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="effort-model", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# When supports_reasoning_effort=True, it should NOT be cleared to None
|
||||
# The disable path sets it to "minimal"; supports_reasoning_effort=True keeps it
|
||||
@@ -458,7 +456,7 @@ def test_thinking_shortcut_enables_thinking_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert captured.get("thinking") == thinking_settings
|
||||
|
||||
@@ -488,7 +486,7 @@ def test_thinking_shortcut_disables_thinking_when_thinking_disabled(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
assert "extra_body" not in captured
|
||||
@@ -520,7 +518,7 @@ def test_thinking_shortcut_merges_with_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="merge-model", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="merge-model", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
# Both the thinking shortcut and when_thinking_enabled settings should be applied
|
||||
assert captured.get("thinking") == thinking_settings
|
||||
@@ -552,7 +550,7 @@ def test_thinking_shortcut_not_leaked_into_model_when_disabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-leak", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="no-leak", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# The disable path should have set thinking to disabled (not the raw enabled shortcut)
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
@@ -590,7 +588,7 @@ def test_openai_compatible_provider_passes_base_url(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="minimax-m2.5")
|
||||
factory_module.create_chat_model(name="minimax-m2.5", app_config=cfg)
|
||||
|
||||
assert captured.get("model") == "MiniMax-M2.5"
|
||||
assert captured.get("base_url") == "https://api.minimax.io/v1"
|
||||
@@ -731,11 +729,11 @@ def test_openai_compatible_provider_multiple_models(monkeypatch):
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
# Create first model
|
||||
factory_module.create_chat_model(name="minimax-m2.5")
|
||||
factory_module.create_chat_model(name="minimax-m2.5", app_config=cfg)
|
||||
assert captured.get("model") == "MiniMax-M2.5"
|
||||
|
||||
# Create second model
|
||||
factory_module.create_chat_model(name="minimax-m2.5-highspeed")
|
||||
factory_module.create_chat_model(name="minimax-m2.5-highspeed", app_config=cfg)
|
||||
assert captured.get("model") == "MiniMax-M2.5-highspeed"
|
||||
|
||||
|
||||
@@ -763,7 +761,7 @@ def test_codex_provider_disables_reasoning_when_thinking_disabled(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "none"
|
||||
|
||||
@@ -783,7 +781,7 @@ def test_codex_provider_preserves_explicit_reasoning_effort(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high")
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high", app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "high"
|
||||
|
||||
@@ -803,7 +801,7 @@ def test_codex_provider_defaults_reasoning_effort_to_medium(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "medium"
|
||||
|
||||
@@ -824,7 +822,7 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert "max_tokens" not in FakeChatModel.captured_kwargs
|
||||
|
||||
@@ -837,7 +835,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled=wte,
|
||||
)
|
||||
model.extra_body = {"top_k": 20}
|
||||
model = model.model_copy(update={"extra_body": {"top_k": 20}})
|
||||
cfg = _make_app_config([model])
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
@@ -850,7 +848,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
|
||||
assert captured.get("reasoning_effort") is None
|
||||
@@ -864,7 +862,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled=wte,
|
||||
)
|
||||
model.extra_body = {"top_k": 20}
|
||||
model = model.model_copy(update={"extra_body": {"top_k": 20}})
|
||||
cfg = _make_app_config([model])
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
@@ -877,7 +875,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {
|
||||
"top_k": 20,
|
||||
@@ -911,7 +909,7 @@ def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="deepseek")
|
||||
factory_module.create_chat_model(name="deepseek", app_config=cfg)
|
||||
|
||||
assert captured.get("stream_usage") is True
|
||||
|
||||
@@ -930,14 +928,25 @@ def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="claude")
|
||||
factory_module.create_chat_model(name="claude", app_config=cfg)
|
||||
|
||||
assert "stream_usage" not in captured
|
||||
|
||||
|
||||
def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
|
||||
"""If config dumps stream_usage=False, factory should respect it."""
|
||||
cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
|
||||
# Build a ModelConfig with stream_usage=False as an extra field (extra="allow").
|
||||
model_with_stream_usage = ModelConfig(
|
||||
name="deepseek",
|
||||
display_name="deepseek",
|
||||
description=None,
|
||||
use="langchain_deepseek:ChatDeepSeek",
|
||||
model="deepseek",
|
||||
supports_thinking=False,
|
||||
supports_vision=False,
|
||||
stream_usage=False,
|
||||
)
|
||||
cfg = _make_app_config([model_with_stream_usage])
|
||||
_patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
|
||||
|
||||
captured: dict = {}
|
||||
@@ -949,17 +958,7 @@ def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
# Simulate config having stream_usage explicitly set by patching model_dump
|
||||
original_get_model_config = cfg.get_model_config
|
||||
|
||||
def patched_get_model_config(name):
|
||||
mc = original_get_model_config(name)
|
||||
mc.stream_usage = False # type: ignore[attr-defined]
|
||||
return mc
|
||||
|
||||
monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)
|
||||
|
||||
factory_module.create_chat_model(name="deepseek")
|
||||
factory_module.create_chat_model(name="deepseek", app_config=cfg)
|
||||
|
||||
assert captured.get("stream_usage") is False
|
||||
|
||||
@@ -989,7 +988,7 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="gpt-5-responses")
|
||||
factory_module.create_chat_model(name="gpt-5-responses", app_config=cfg)
|
||||
|
||||
assert captured.get("use_responses_api") is True
|
||||
assert captured.get("output_version") == "responses/v1"
|
||||
@@ -1030,7 +1029,7 @@ def test_no_duplicate_kwarg_when_reasoning_effort_in_config_and_thinking_disable
|
||||
_patch_factory(monkeypatch, cfg, model_class=CapturingModel)
|
||||
|
||||
# Must not raise TypeError
|
||||
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# kwargs (runtime) takes precedence: thinking-disabled path sets reasoning_effort=minimal
|
||||
assert captured.get("reasoning_effort") == "minimal"
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Tests for user-scoped path resolution in Paths."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@@ -3,14 +3,24 @@
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
present_file_tool_module = importlib.import_module("deerflow.tools.builtins.present_file_tool")
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime(outputs_path: str) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
state={"thread_data": {"outputs_path": outputs_path}},
|
||||
context={"thread_id": "thread-1"},
|
||||
config={},
|
||||
context=_make_context("thread-1"),
|
||||
)
|
||||
|
||||
|
||||
@@ -51,34 +61,6 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||
|
||||
|
||||
def test_present_files_uses_config_thread_id_when_context_missing(tmp_path, monkeypatch):
|
||||
outputs_dir = tmp_path / "threads" / "thread-from-config" / "user-data" / "outputs"
|
||||
outputs_dir.mkdir(parents=True)
|
||||
artifact_path = outputs_dir / "summary.json"
|
||||
artifact_path.write_text("{}")
|
||||
|
||||
monkeypatch.setattr(
|
||||
present_file_tool_module,
|
||||
"get_paths",
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||
)
|
||||
|
||||
runtime = SimpleNamespace(
|
||||
state={"thread_data": {"outputs_path": str(outputs_dir)}},
|
||||
context={},
|
||||
config={"configurable": {"thread_id": "thread-from-config"}},
|
||||
)
|
||||
|
||||
result = present_file_tool_module.present_file_tool.func(
|
||||
runtime=runtime,
|
||||
filepaths=["/mnt/user-data/outputs/summary.json"],
|
||||
tool_call_id="tc-config",
|
||||
)
|
||||
|
||||
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||
assert result.update["messages"][0].content == "Successfully presented files"
|
||||
|
||||
|
||||
def test_present_files_rejects_paths_outside_outputs(tmp_path):
|
||||
outputs_dir = tmp_path / "threads" / "thread-1" / "user-data" / "outputs"
|
||||
workspace_dir = tmp_path / "threads" / "thread-1" / "user-data" / "workspace"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
@@ -15,19 +14,14 @@ async def test_list_messages_by_run_default_returns_all(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-a",
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message",
|
||||
content=f"msg-a-{i}",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-b",
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=f"msg-b-{i}",
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
|
||||
|
||||
@@ -42,11 +36,9 @@ async def test_list_messages_by_run_with_limit(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-a",
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message",
|
||||
content=f"msg-a-{i}",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
|
||||
@@ -60,11 +52,9 @@ async def test_list_messages_by_run_after_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-a",
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message",
|
||||
content=f"msg-a-{i}",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
@@ -79,11 +69,9 @@ async def test_list_messages_by_run_before_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-a",
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message",
|
||||
content=f"msg-a-{i}",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
@@ -98,19 +86,13 @@ async def test_list_messages_by_run_does_not_include_other_run(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-a",
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=f"msg-a-{i}",
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message", category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="run-b",
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=f"msg-b-{i}",
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-b")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,15 @@
|
||||
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -112,8 +113,7 @@ def test_run_messages_passes_after_seq_to_event_store():
|
||||
response = client.get("/api/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3",
|
||||
"run-3",
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
@@ -133,8 +133,7 @@ def test_run_messages_respects_custom_limit():
|
||||
response = client.get("/api/runs/run-4/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4",
|
||||
"run-4",
|
||||
"thread-4", "run-4",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
@@ -154,8 +153,7 @@ def test_run_messages_passes_before_seq_to_event_store():
|
||||
response = client.get("/api/runs/run-5/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5",
|
||||
"run-5",
|
||||
"thread-5", "run-5",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
|
||||
@@ -14,6 +14,10 @@ def _make_runtime(tmp_path):
|
||||
workspace.mkdir()
|
||||
uploads.mkdir()
|
||||
outputs.mkdir()
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
return SimpleNamespace(
|
||||
state={
|
||||
"sandbox": {"sandbox_id": "local"},
|
||||
@@ -23,7 +27,10 @@ def _make_runtime(tmp_path):
|
||||
"outputs_path": str(outputs),
|
||||
},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
context=DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id="thread-1",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -103,8 +110,6 @@ def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
|
||||
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
|
||||
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
@@ -324,10 +329,6 @@ def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -
|
||||
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.get_app_config",
|
||||
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
|
||||
)
|
||||
|
||||
result = glob_tool.func(
|
||||
runtime=runtime,
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.sandbox.tools import (
|
||||
VIRTUAL_PATH_PREFIX,
|
||||
_apply_cwd_prefix,
|
||||
@@ -34,6 +35,53 @@ _THREAD_DATA = {
|
||||
}
|
||||
|
||||
|
||||
def _make_app_config(
|
||||
*,
|
||||
skills_container_path: str = "/mnt/skills",
|
||||
skills_host_path: str | None = None,
|
||||
mounts=None,
|
||||
mcp_servers=None,
|
||||
tool_config_map=None,
|
||||
) -> SimpleNamespace:
|
||||
"""Build a lightweight AppConfig stand-in used by tests.
|
||||
|
||||
Only the attributes accessed by the helpers under test are populated;
|
||||
everything else is omitted to keep the fake minimal and explicit.
|
||||
"""
|
||||
skills_path = Path(skills_host_path) if skills_host_path is not None else None
|
||||
skills_cfg = SimpleNamespace(
|
||||
container_path=skills_container_path,
|
||||
get_skills_path=lambda: skills_path if skills_path is not None else Path("/nonexistent-skills-root-12345"),
|
||||
)
|
||||
sandbox_cfg = SimpleNamespace(mounts=list(mounts) if mounts else [], bash_output_max_chars=20000)
|
||||
extensions_cfg = SimpleNamespace(mcp_servers=dict(mcp_servers) if mcp_servers else {})
|
||||
tool_config_map = dict(tool_config_map or {})
|
||||
return SimpleNamespace(
|
||||
skills=skills_cfg,
|
||||
sandbox=sandbox_cfg,
|
||||
extensions=extensions_cfg,
|
||||
get_tool_config=lambda name: tool_config_map.get(name),
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_APP_CONFIG = _make_app_config()
|
||||
|
||||
|
||||
def _make_ctx(thread_id: str = "thread-1", *, app_config=_DEFAULT_APP_CONFIG, sandbox_key: str | None = None):
|
||||
"""Build a DeerFlowContext-like object with extra attributes allowed.
|
||||
|
||||
``resolve_context`` only checks ``isinstance(ctx, DeerFlowContext)``; for
|
||||
tests that need additional attributes (``sandbox_key``) we use a subclass
|
||||
created at runtime.
|
||||
"""
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _DFC
|
||||
|
||||
ctx = _DFC(app_config=app_config, thread_id=thread_id)
|
||||
if sandbox_key is not None:
|
||||
object.__setattr__(ctx, "sandbox_key", sandbox_key)
|
||||
return ctx
|
||||
|
||||
|
||||
# ---------- replace_virtual_path ----------
|
||||
|
||||
|
||||
@@ -85,7 +133,7 @@ def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing
|
||||
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
|
||||
"""Trailing slash on a virtual path inside a command must be preserved."""
|
||||
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
|
||||
|
||||
|
||||
@@ -94,7 +142,7 @@ def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
|
||||
|
||||
def test_mask_local_paths_in_output_hides_host_paths() -> None:
|
||||
output = "Created: /tmp/deer-flow/threads/t1/user-data/workspace/result.txt"
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA)
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
assert "/tmp/deer-flow/threads/t1/user-data" not in masked
|
||||
assert "/mnt/user-data/workspace/result.txt" in masked
|
||||
@@ -107,7 +155,7 @@ def test_mask_local_paths_in_output_hides_skills_host_paths() -> None:
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
output = "Reading: /home/user/deer-flow/skills/public/bootstrap/SKILL.md"
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA)
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
assert "/home/user/deer-flow/skills" not in masked
|
||||
assert "/mnt/skills/public/bootstrap/SKILL.md" in masked
|
||||
@@ -143,12 +191,12 @@ def test_reject_path_traversal_allows_normal_paths() -> None:
|
||||
|
||||
def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
|
||||
with pytest.raises(PermissionError, match="configured mount paths"):
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
|
||||
@@ -158,42 +206,41 @@ def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -
|
||||
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
|
||||
]
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
|
||||
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA)
|
||||
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_user_data_paths() -> None:
|
||||
# Should not raise — user-data paths are always allowed
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_user_data_write() -> None:
|
||||
# read_only=False (default) should still work for user-data paths
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=False)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_traversal_in_user_data() -> None:
|
||||
"""Path traversal via .. in user-data paths must be rejected."""
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_traversal_in_skills() -> None:
|
||||
"""Path traversal via .. in skills paths must be rejected."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_none_thread_data() -> None:
|
||||
@@ -201,7 +248,7 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
|
||||
from deerflow.sandbox.exceptions import SandboxRuntimeError
|
||||
|
||||
with pytest.raises(SandboxRuntimeError):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
# ---------- _resolve_skills_path ----------
|
||||
@@ -209,32 +256,26 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
|
||||
|
||||
def test_resolve_skills_path_resolves_correctly() -> None:
|
||||
"""Skills virtual path should resolve to host path."""
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
|
||||
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
|
||||
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
|
||||
# Force get_skills_path().exists() to be True without touching the FS
|
||||
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
|
||||
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", cfg)
|
||||
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
|
||||
|
||||
|
||||
def test_resolve_skills_path_resolves_root() -> None:
|
||||
"""Skills container root should resolve to host skills directory."""
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
resolved = _resolve_skills_path("/mnt/skills")
|
||||
assert resolved == "/home/user/deer-flow/skills"
|
||||
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
|
||||
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
|
||||
resolved = _resolve_skills_path("/mnt/skills", cfg)
|
||||
assert resolved == "/home/user/deer-flow/skills"
|
||||
|
||||
|
||||
def test_resolve_skills_path_raises_when_not_configured() -> None:
|
||||
"""Should raise FileNotFoundError when skills directory is not available."""
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=None),
|
||||
):
|
||||
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
|
||||
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
|
||||
# Default app config has no host path configured → _get_skills_host_path returns None
|
||||
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
|
||||
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
# ---------- _resolve_and_validate_user_data_path ----------
|
||||
@@ -249,7 +290,7 @@ def test_resolve_and_validate_user_data_path_resolves_correctly(tmp_path: Path)
|
||||
"uploads_path": str(tmp_path / "uploads"),
|
||||
"outputs_path": str(tmp_path / "outputs"),
|
||||
}
|
||||
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data)
|
||||
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data, _DEFAULT_APP_CONFIG)
|
||||
assert resolved == str(workspace / "hello.txt")
|
||||
|
||||
|
||||
@@ -264,7 +305,7 @@ def test_resolve_and_validate_user_data_path_blocks_traversal(tmp_path: Path) ->
|
||||
}
|
||||
# This path resolves outside the allowed roots
|
||||
with pytest.raises(PermissionError):
|
||||
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data)
|
||||
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
# ---------- replace_virtual_paths_in_command ----------
|
||||
@@ -277,7 +318,7 @@ def test_replace_virtual_paths_in_command_replaces_skills_paths() -> None:
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
cmd = "cat /mnt/skills/public/bootstrap/SKILL.md"
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/mnt/skills" not in result
|
||||
assert "/home/user/deer-flow/skills/public/bootstrap/SKILL.md" in result
|
||||
|
||||
@@ -289,7 +330,7 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/skills"),
|
||||
):
|
||||
cmd = "cat /mnt/skills/public/SKILL.md > /mnt/user-data/workspace/out.txt"
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/mnt/skills" not in result
|
||||
assert "/mnt/user-data" not in result
|
||||
assert "/home/user/skills/public/SKILL.md" in result
|
||||
@@ -301,30 +342,27 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
|
||||
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
|
||||
"""HTTP URLs must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
|
||||
validate_local_bash_command_paths(
|
||||
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> None:
|
||||
@@ -332,8 +370,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> No
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/user-data/workspace/../../etc/passwd",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
|
||||
@@ -342,21 +379,20 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/skills/../../etc/passwd",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) -> None:
|
||||
runtime = SimpleNamespace(
|
||||
state={"sandbox": {"sandbox_id": "local"}, "thread_data": _THREAD_DATA.copy()},
|
||||
context={"thread_id": "thread-1"},
|
||||
context=_make_ctx("thread-1"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.ensure_sandbox_initialized",
|
||||
lambda runtime: SimpleNamespace(execute_command=lambda command: pytest.fail("host bash should not execute")),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda: False)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
result = bash_tool.func(
|
||||
runtime=runtime,
|
||||
@@ -371,33 +407,32 @@ def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) ->
|
||||
|
||||
|
||||
def test_is_skills_path_recognises_default_prefix() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
assert _is_skills_path("/mnt/skills") is True
|
||||
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md") is True
|
||||
assert _is_skills_path("/mnt/skills-extra/foo") is False
|
||||
assert _is_skills_path("/mnt/user-data/workspace") is False
|
||||
assert _is_skills_path("/mnt/skills", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_skills_path("/mnt/skills-extra/foo", _DEFAULT_APP_CONFIG) is False
|
||||
assert _is_skills_path("/mnt/user-data/workspace", _DEFAULT_APP_CONFIG) is False
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_skills_read_only() -> None:
|
||||
"""read_file / ls should be able to access /mnt/skills paths."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
# Should not raise
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
read_only=True,
|
||||
)
|
||||
# Should not raise — default app config uses /mnt/skills as container path
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_skills_write() -> None:
|
||||
"""write_file / str_replace must NOT write to skills paths."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
read_only=False,
|
||||
)
|
||||
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=False,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_skills_path() -> None:
|
||||
@@ -405,8 +440,7 @@ def test_validate_local_bash_command_paths_allows_skills_path() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_urls() -> None:
|
||||
@@ -414,40 +448,35 @@ def test_validate_local_bash_command_paths_allows_urls() -> None:
|
||||
# HTTPS URLs
|
||||
validate_local_bash_command_paths(
|
||||
"curl -X POST https://example.com/api/v1/risk/check",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# HTTP URLs
|
||||
validate_local_bash_command_paths(
|
||||
"curl http://localhost:8080/health",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# URLs with query strings
|
||||
validate_local_bash_command_paths(
|
||||
"curl https://api.example.com/v2/search?q=test",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# FTP URLs
|
||||
validate_local_bash_command_paths(
|
||||
"curl ftp://ftp.example.com/pub/file.tar.gz",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# URL mixed with valid virtual path
|
||||
validate_local_bash_command_paths(
|
||||
"curl https://example.com/data -o /mnt/user-data/workspace/data.json",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_file_urls() -> None:
|
||||
"""file:// URLs should be treated as unsafe and blocked."""
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_file_urls_case_insensitive() -> None:
|
||||
"""file:// URL detection should be case-insensitive."""
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -> None:
|
||||
@@ -455,35 +484,36 @@ def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths(
|
||||
"curl file:///etc/passwd -o /mnt/user-data/workspace/out.txt",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_still_blocks_other_paths() -> None:
|
||||
"""Paths outside virtual and system prefixes must still be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_skills_custom_container_path() -> None:
|
||||
"""Skills with a custom container_path in config should also work."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/custom/skills"):
|
||||
# Should not raise
|
||||
custom_config = _make_app_config(skills_container_path="/custom/skills")
|
||||
# Should not raise
|
||||
validate_local_tool_path(
|
||||
"/custom/skills/public/my-skill/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
custom_config,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
# The default /mnt/skills should not match since container path is /custom/skills
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path(
|
||||
"/custom/skills/public/my-skill/SKILL.md",
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
custom_config,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
# The default /mnt/skills should not match since container path is /custom/skills
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------- ACP workspace path tests ----------
|
||||
|
||||
@@ -500,6 +530,7 @@ def test_validate_local_tool_path_allows_acp_workspace_read_only() -> None:
|
||||
validate_local_tool_path(
|
||||
"/mnt/acp-workspace/hello_world.py",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
@@ -510,6 +541,7 @@ def test_validate_local_tool_path_blocks_acp_workspace_write() -> None:
|
||||
validate_local_tool_path(
|
||||
"/mnt/acp-workspace/hello_world.py",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=False,
|
||||
)
|
||||
|
||||
@@ -518,8 +550,7 @@ def test_validate_local_bash_command_paths_allows_acp_workspace() -> None:
|
||||
"""bash commands referencing /mnt/acp-workspace should be allowed."""
|
||||
validate_local_bash_command_paths(
|
||||
"cp /mnt/acp-workspace/hello_world.py /mnt/user-data/outputs/hello_world.py",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -> None:
|
||||
@@ -527,8 +558,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/acp-workspace/../../etc/passwd",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_resolve_acp_workspace_path_resolves_correctly(tmp_path: Path) -> None:
|
||||
@@ -570,7 +600,7 @@ def test_replace_virtual_paths_in_command_replaces_acp_workspace() -> None:
|
||||
acp_host = "/home/user/.deer-flow/acp-workspace"
|
||||
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
|
||||
cmd = "cp /mnt/acp-workspace/hello.py /mnt/user-data/outputs/hello.py"
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/mnt/acp-workspace" not in result
|
||||
assert f"{acp_host}/hello.py" in result
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/outputs/hello.py" in result
|
||||
@@ -581,7 +611,7 @@ def test_mask_local_paths_in_output_hides_acp_workspace_host_paths() -> None:
|
||||
acp_host = "/home/user/.deer-flow/acp-workspace"
|
||||
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
|
||||
output = f"Copied: {acp_host}/hello.py"
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA)
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
assert acp_host not in masked
|
||||
assert "/mnt/acp-workspace/hello.py" in masked
|
||||
@@ -617,39 +647,37 @@ def test_apply_cwd_prefix_quotes_path_with_spaces() -> None:
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None:
|
||||
"""Bash commands referencing MCP filesystem server paths should be allowed."""
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
mock_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"filesystem": McpServerConfig(
|
||||
enabled=True,
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
|
||||
)
|
||||
}
|
||||
)
|
||||
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=mock_config):
|
||||
# Should not raise - MCP filesystem paths are allowed
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
|
||||
|
||||
# Path traversal should still be blocked
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
|
||||
|
||||
# Disabled servers should not expose paths
|
||||
disabled_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"filesystem": McpServerConfig(
|
||||
enabled=False,
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
|
||||
)
|
||||
}
|
||||
def _mcp_app_config(enabled: bool) -> AppConfig:
|
||||
return AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
extensions=ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"filesystem": McpServerConfig(
|
||||
enabled=enabled,
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
|
||||
|
||||
enabled_cfg = _mcp_app_config(True)
|
||||
# Should not raise - MCP filesystem paths are allowed
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, enabled_cfg)
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA, enabled_cfg)
|
||||
|
||||
# Path traversal should still be blocked
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA, enabled_cfg)
|
||||
|
||||
# Disabled servers should not expose paths
|
||||
disabled_cfg = _mcp_app_config(False)
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, disabled_cfg)
|
||||
|
||||
|
||||
# ---------- Custom mount path tests ----------
|
||||
@@ -667,12 +695,12 @@ def _mock_custom_mounts():
|
||||
|
||||
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
assert _is_custom_mount_path("/mnt/code-read") is True
|
||||
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
|
||||
assert _is_custom_mount_path("/mnt/data") is True
|
||||
assert _is_custom_mount_path("/mnt/data/file.txt") is True
|
||||
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
|
||||
assert _is_custom_mount_path("/mnt/other") is False
|
||||
assert _is_custom_mount_path("/mnt/code-read", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/code-read/src/main.py", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/data", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/data/file.txt", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG) is False
|
||||
assert _is_custom_mount_path("/mnt/other", _DEFAULT_APP_CONFIG) is False
|
||||
|
||||
|
||||
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
@@ -683,7 +711,7 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
|
||||
]
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
mount = _get_custom_mount_for_path("/mnt/code/file.py")
|
||||
mount = _get_custom_mount_for_path("/mnt/code/file.py", _DEFAULT_APP_CONFIG)
|
||||
assert mount is not None
|
||||
assert mount.container_path == "/mnt/code"
|
||||
|
||||
@@ -691,90 +719,72 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
|
||||
"""read_file / ls should be able to access custom mount paths."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
|
||||
"""write_file / str_replace must NOT write to read-only custom mounts."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
|
||||
"""write_file / str_replace should succeed on writable custom mounts."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
|
||||
"""Path traversal via .. in custom mount paths must be rejected."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
|
||||
"""bash commands referencing custom mount paths should be allowed."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
|
||||
"""Bash commands with traversal in custom mount paths should be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
|
||||
"""Paths not matching any custom mount should still be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
|
||||
"""_get_custom_mounts should cache after first successful load."""
|
||||
# Clear any existing cache
|
||||
if hasattr(_get_custom_mounts, "_cached"):
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
# Use real directories so host_path.exists() filtering passes
|
||||
def test_get_custom_mounts_reads_from_app_config(tmp_path) -> None:
|
||||
"""_get_custom_mounts should read directly from the supplied AppConfig."""
|
||||
dir_a = tmp_path / "code-read"
|
||||
dir_a.mkdir()
|
||||
dir_b = tmp_path / "data"
|
||||
dir_b.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
mounts = [
|
||||
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
|
||||
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
|
||||
]
|
||||
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
|
||||
mock_config = SimpleNamespace(sandbox=mock_sandbox)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=mock_config):
|
||||
result = _get_custom_mounts()
|
||||
assert len(result) == 2
|
||||
|
||||
# After caching, should return cached value even without mock
|
||||
assert hasattr(_get_custom_mounts, "_cached")
|
||||
assert len(_get_custom_mounts()) == 2
|
||||
|
||||
# Cleanup
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
cfg = _make_app_config(mounts=mounts)
|
||||
result = _get_custom_mounts(cfg)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
|
||||
def test_get_custom_mounts_filters_nonexistent_host_path(tmp_path) -> None:
|
||||
"""_get_custom_mounts should only return mounts whose host_path exists."""
|
||||
if hasattr(_get_custom_mounts, "_cached"):
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
existing_dir = tmp_path / "existing"
|
||||
existing_dir.mkdir()
|
||||
@@ -783,22 +793,16 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
|
||||
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
|
||||
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
|
||||
]
|
||||
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
|
||||
mock_config = SimpleNamespace(sandbox=mock_sandbox)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=mock_config):
|
||||
result = _get_custom_mounts()
|
||||
assert len(result) == 1
|
||||
assert result[0].container_path == "/mnt/existing"
|
||||
|
||||
# Cleanup
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
cfg = _make_app_config(mounts=mounts)
|
||||
result = _get_custom_mounts(cfg)
|
||||
assert len(result) == 1
|
||||
assert result[0].container_path == "/mnt/existing"
|
||||
|
||||
|
||||
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
|
||||
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
|
||||
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG)
|
||||
assert mount is None
|
||||
|
||||
|
||||
@@ -829,8 +833,8 @@ def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) ->
|
||||
|
||||
sandbox = SharedSandbox()
|
||||
runtimes = [
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
]
|
||||
failures: list[BaseException] = []
|
||||
|
||||
@@ -905,14 +909,14 @@ def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_pat
|
||||
"sandbox-b": IsolatedSandbox("sandbox-b", shared_state),
|
||||
}
|
||||
runtimes = [
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1", sandbox_key="sandbox-a"), config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-2", sandbox_key="sandbox-b"), config={}),
|
||||
]
|
||||
failures: list[BaseException] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.ensure_sandbox_initialized",
|
||||
lambda runtime: sandboxes[runtime.context["sandbox_key"]],
|
||||
lambda runtime: sandboxes[runtime.context.sandbox_key],
|
||||
)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
@@ -972,8 +976,8 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
|
||||
|
||||
sandbox = SharedSandbox()
|
||||
runtimes = [
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
]
|
||||
failures: list[BaseException] = []
|
||||
|
||||
|
||||
@@ -29,10 +29,9 @@ async def test_scan_skill_content_passes_run_name_to_model(monkeypatch):
|
||||
@pytest.mark.anyio
|
||||
async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
|
||||
config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None))
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
|
||||
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
|
||||
result = await scan_skill_content(config, "---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
|
||||
|
||||
assert result.decision == "block"
|
||||
assert "manual review required" in result.reason
|
||||
|
||||
@@ -4,9 +4,20 @@ from types import SimpleNamespace
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
|
||||
|
||||
|
||||
def _make_context(thread_id: str, app_config: object | None = None) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=app_config if app_config is not None else AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _skill_content(name: str, description: str = "Demo skill") -> str:
|
||||
return f"---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"
|
||||
|
||||
@@ -23,18 +34,15 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
|
||||
|
||||
result = anyio.run(
|
||||
skill_manage_module.skill_manage_tool.coroutine,
|
||||
@@ -67,17 +75,14 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
|
||||
content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
|
||||
|
||||
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
|
||||
@@ -107,10 +112,8 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
|
||||
runtime = SimpleNamespace(context={}, config={"configurable": {}})
|
||||
runtime = SimpleNamespace(context=_make_context("", config), config={"configurable": {}})
|
||||
|
||||
with pytest.raises(ValueError, match="built-in skill"):
|
||||
anyio.run(
|
||||
@@ -131,17 +134,15 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-sync"}, config={"configurable": {"thread_id": "thread-sync"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-sync", config), config={"configurable": {"thread_id": "thread-sync"}})
|
||||
result = skill_manage_module.skill_manage_tool.func(
|
||||
runtime=runtime,
|
||||
action="create",
|
||||
@@ -159,17 +160,14 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
|
||||
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
|
||||
|
||||
with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):
|
||||
|
||||
@@ -7,6 +7,9 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import skills as skills_router
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.skills.manager import get_skill_history_file
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
@@ -44,17 +47,16 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -94,14 +96,12 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
get_skill_history_file("demo-skill").write_text(
|
||||
get_skill_history_file("demo-skill", config).write_text(
|
||||
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
@@ -114,6 +114,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -136,17 +137,16 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -238,23 +238,25 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
|
||||
enabled_state = {"value": True}
|
||||
refresh_calls = []
|
||||
|
||||
def _load_skills(*, enabled_only: bool):
|
||||
def _load_skills(*a, enabled_only: bool = False, **k):
|
||||
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
|
||||
if enabled_only and not skill.enabled:
|
||||
return []
|
||||
return [skill]
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
enabled_state["value"] = False
|
||||
|
||||
_app_cfg = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
|
||||
monkeypatch.setattr("app.gateway.routers.skills.get_extensions_config", lambda: SimpleNamespace(mcp_servers={}, skills={}))
|
||||
monkeypatch.setattr("app.gateway.routers.skills.reload_extensions_config", lambda: None)
|
||||
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
|
||||
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = _app_cfg
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path:
|
||||
_write_skill(skills_root / "public" / "parent" / "child-skill", "child-skill", "Child skill")
|
||||
_write_skill(skills_root / "custom" / "team" / "helper", "team-helper", "Team helper")
|
||||
|
||||
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
|
||||
skills = load_skills(skills_path=skills_root, enabled_only=False)
|
||||
by_name = {skill.name: skill for skill in skills}
|
||||
|
||||
assert {"root-skill", "child-skill", "team-helper"} <= set(by_name)
|
||||
@@ -57,7 +57,7 @@ def test_load_skills_skips_hidden_directories(tmp_path: Path):
|
||||
"Hidden skill",
|
||||
)
|
||||
|
||||
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
|
||||
skills = load_skills(skills_path=skills_root, enabled_only=False)
|
||||
names = {skill.name for skill in skills}
|
||||
|
||||
assert "ok-skill" in names
|
||||
@@ -69,7 +69,7 @@ def test_load_skills_prefers_custom_over_public_with_same_name(tmp_path: Path):
|
||||
_write_skill(skills_root / "public" / "shared-skill", "shared-skill", "Public version")
|
||||
_write_skill(skills_root / "custom" / "shared-skill", "shared-skill", "Custom version")
|
||||
|
||||
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
|
||||
skills = load_skills(skills_path=skills_root, enabled_only=False)
|
||||
shared = next(skill for skill in skills if skill.name == "shared-skill")
|
||||
|
||||
assert shared.category == "custom"
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -331,6 +332,9 @@ async def test_concurrent_tasks_end_sentinel():
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_make_stream_bridge_defaults():
|
||||
"""make_stream_bridge() with no config yields a MemoryStreamBridge."""
|
||||
async with make_stream_bridge() as bridge:
|
||||
"""make_stream_bridge with a config lacking stream_bridge yields a MemoryStreamBridge."""
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
async with make_stream_bridge(config) as bridge:
|
||||
assert isinstance(bridge, MemoryStreamBridge)
|
||||
|
||||
@@ -21,6 +21,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
_TEST_APP_CONFIG = MagicMock(name="TestAppConfig")
|
||||
|
||||
# Module names that need to be mocked to break circular imports
|
||||
_MOCKED_MODULE_NAMES = [
|
||||
"deerflow.agents",
|
||||
@@ -203,7 +205,7 @@ class TestAsyncExecutionPath:
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
trace_id="test-trace",
|
||||
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -232,7 +234,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -259,7 +261,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -285,7 +287,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -306,7 +308,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -327,7 +329,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -348,7 +350,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -384,7 +386,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -419,7 +421,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -456,7 +458,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -477,7 +479,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_aexecute") as mock_aexecute:
|
||||
@@ -511,7 +513,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -565,7 +567,7 @@ class TestAsyncToolSupport:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -602,7 +604,7 @@ class TestAsyncToolSupport:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -648,7 +650,7 @@ class TestThreadSafety:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id=f"thread-{task_id}",
|
||||
thread_id=f"thread-{task_id}", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -858,7 +860,7 @@ class TestCooperativeCancellation:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -898,7 +900,7 @@ class TestCooperativeCancellation:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -977,7 +979,7 @@ class TestCooperativeCancellation:
|
||||
config=short_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
trace_id="test-trace",
|
||||
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
# Wrap _scheduler_pool.submit so we know when run_task finishes
|
||||
|
||||
@@ -1,29 +1,35 @@
|
||||
"""Tests for subagent availability and prompt exposure under local bash hardening."""
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.subagents import registry as registry_module
|
||||
|
||||
|
||||
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: False)
|
||||
def _config() -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
|
||||
names = registry_module.get_available_subagent_names()
|
||||
|
||||
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
names = registry_module.get_available_subagent_names(_config())
|
||||
|
||||
assert names == ["general-purpose"]
|
||||
|
||||
|
||||
def test_get_available_subagent_names_keeps_bash_when_allowed(monkeypatch) -> None:
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: True)
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: True)
|
||||
|
||||
names = registry_module.get_available_subagent_names()
|
||||
names = registry_module.get_available_subagent_names(_config())
|
||||
|
||||
assert names == ["general-purpose", "bash"]
|
||||
|
||||
|
||||
def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch) -> None:
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose"])
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
|
||||
|
||||
section = prompt_module._build_subagent_section(3)
|
||||
section = prompt_module._build_subagent_section(3, _config())
|
||||
|
||||
# When bash is not available, it should not appear at all (aligned with Codex:
|
||||
# unavailable roles are omitted, not listed as disabled)
|
||||
@@ -34,9 +40,9 @@ def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch
|
||||
|
||||
|
||||
def test_build_subagent_section_includes_bash_when_available(monkeypatch) -> None:
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose", "bash"])
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose", "bash"])
|
||||
|
||||
section = prompt_module._build_subagent_section(3)
|
||||
section = prompt_module._build_subagent_section(3, _config())
|
||||
|
||||
assert "For command execution (git, build, test, deploy operations)" in section
|
||||
assert 'bash("npm test")' in section
|
||||
|
||||
@@ -1,596 +0,0 @@
|
||||
"""Tests for subagent per-agent skill configuration and custom subagent types.
|
||||
|
||||
Covers:
|
||||
- SubagentConfig.skills field
|
||||
- SubagentOverrideConfig.skills field
|
||||
- CustomSubagentConfig model validation
|
||||
- SubagentsAppConfig.custom_agents and get_skills_for()
|
||||
- Registry: custom agent lookup, skills override, merged available names
|
||||
- Skills filter passthrough in task_tool config assembly
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.subagents_config import (
|
||||
CustomSubagentConfig,
|
||||
SubagentOverrideConfig,
|
||||
SubagentsAppConfig,
|
||||
get_subagents_app_config,
|
||||
load_subagents_config_from_dict,
|
||||
)
|
||||
from deerflow.subagents.config import SubagentConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reset_subagents_config(**kwargs) -> None:
|
||||
"""Reset global subagents config to a known state."""
|
||||
load_subagents_config_from_dict(kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentConfig.skills field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentConfigSkills:
|
||||
def test_default_skills_is_none(self):
|
||||
config = SubagentConfig(name="test", description="test", system_prompt="test")
|
||||
assert config.skills is None
|
||||
|
||||
def test_skills_whitelist(self):
|
||||
config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
skills=["data-analysis", "visualization"],
|
||||
)
|
||||
assert config.skills == ["data-analysis", "visualization"]
|
||||
|
||||
def test_skills_empty_list_means_no_skills(self):
|
||||
config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
skills=[],
|
||||
)
|
||||
assert config.skills == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentOverrideConfig.skills field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentOverrideConfigSkills:
|
||||
def test_default_skills_is_none(self):
|
||||
override = SubagentOverrideConfig()
|
||||
assert override.skills is None
|
||||
|
||||
def test_skills_whitelist(self):
|
||||
override = SubagentOverrideConfig(skills=["web-search", "data-analysis"])
|
||||
assert override.skills == ["web-search", "data-analysis"]
|
||||
|
||||
def test_skills_empty_list(self):
|
||||
override = SubagentOverrideConfig(skills=[])
|
||||
assert override.skills == []
|
||||
|
||||
def test_skills_coexists_with_other_fields(self):
|
||||
override = SubagentOverrideConfig(
|
||||
timeout_seconds=300,
|
||||
model="gpt-5",
|
||||
skills=["my-skill"],
|
||||
)
|
||||
assert override.timeout_seconds == 300
|
||||
assert override.model == "gpt-5"
|
||||
assert override.skills == ["my-skill"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CustomSubagentConfig model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCustomSubagentConfig:
|
||||
def test_minimal_valid(self):
|
||||
config = CustomSubagentConfig(
|
||||
description="A test agent",
|
||||
system_prompt="You are a test agent.",
|
||||
)
|
||||
assert config.description == "A test agent"
|
||||
assert config.system_prompt == "You are a test agent."
|
||||
assert config.tools is None
|
||||
assert config.disallowed_tools == ["task", "ask_clarification", "present_files"]
|
||||
assert config.skills is None
|
||||
assert config.model == "inherit"
|
||||
assert config.max_turns == 50
|
||||
assert config.timeout_seconds == 900
|
||||
|
||||
def test_full_configuration(self):
|
||||
config = CustomSubagentConfig(
|
||||
description="Data analysis specialist",
|
||||
system_prompt="You are a data analysis subagent.",
|
||||
tools=["bash", "read_file", "write_file"],
|
||||
disallowed_tools=["task"],
|
||||
skills=["data-analysis", "visualization"],
|
||||
model="qwen3:32b",
|
||||
max_turns=80,
|
||||
timeout_seconds=600,
|
||||
)
|
||||
assert config.tools == ["bash", "read_file", "write_file"]
|
||||
assert config.skills == ["data-analysis", "visualization"]
|
||||
assert config.model == "qwen3:32b"
|
||||
assert config.max_turns == 80
|
||||
assert config.timeout_seconds == 600
|
||||
|
||||
def test_skills_empty_list_no_skills(self):
|
||||
config = CustomSubagentConfig(
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
skills=[],
|
||||
)
|
||||
assert config.skills == []
|
||||
|
||||
def test_rejects_zero_max_turns(self):
|
||||
with pytest.raises(ValueError):
|
||||
CustomSubagentConfig(
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
max_turns=0,
|
||||
)
|
||||
|
||||
def test_rejects_zero_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
CustomSubagentConfig(
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
timeout_seconds=0,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig.custom_agents and get_skills_for()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentsAppConfigCustomAgents:
|
||||
def test_default_custom_agents_empty(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.custom_agents == {}
|
||||
|
||||
def test_custom_agents_loaded(self):
|
||||
config = SubagentsAppConfig(
|
||||
custom_agents={
|
||||
"analysis": CustomSubagentConfig(
|
||||
description="Analysis agent",
|
||||
system_prompt="You analyze data.",
|
||||
skills=["data-analysis"],
|
||||
),
|
||||
}
|
||||
)
|
||||
assert "analysis" in config.custom_agents
|
||||
assert config.custom_agents["analysis"].skills == ["data-analysis"]
|
||||
|
||||
def test_multiple_custom_agents(self):
|
||||
config = SubagentsAppConfig(
|
||||
custom_agents={
|
||||
"analysis": CustomSubagentConfig(
|
||||
description="Analysis",
|
||||
system_prompt="analyze",
|
||||
skills=["data-analysis"],
|
||||
),
|
||||
"researcher": CustomSubagentConfig(
|
||||
description="Research",
|
||||
system_prompt="research",
|
||||
skills=["web-search"],
|
||||
),
|
||||
}
|
||||
)
|
||||
assert len(config.custom_agents) == 2
|
||||
|
||||
|
||||
class TestGetSkillsFor:
|
||||
def test_returns_none_when_no_override(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.get_skills_for("general-purpose") is None
|
||||
assert config.get_skills_for("unknown") is None
|
||||
|
||||
def test_returns_skills_whitelist(self):
|
||||
config = SubagentsAppConfig(
|
||||
agents={
|
||||
"general-purpose": SubagentOverrideConfig(skills=["web-search", "coding"]),
|
||||
}
|
||||
)
|
||||
assert config.get_skills_for("general-purpose") == ["web-search", "coding"]
|
||||
|
||||
def test_returns_empty_list_for_no_skills(self):
|
||||
config = SubagentsAppConfig(
|
||||
agents={
|
||||
"bash": SubagentOverrideConfig(skills=[]),
|
||||
}
|
||||
)
|
||||
assert config.get_skills_for("bash") == []
|
||||
|
||||
def test_returns_none_for_unrelated_agent(self):
|
||||
config = SubagentsAppConfig(
|
||||
agents={
|
||||
"bash": SubagentOverrideConfig(skills=["web-search"]),
|
||||
}
|
||||
)
|
||||
assert config.get_skills_for("general-purpose") is None
|
||||
|
||||
def test_returns_none_when_skills_not_set(self):
|
||||
config = SubagentsAppConfig(
|
||||
agents={
|
||||
"bash": SubagentOverrideConfig(timeout_seconds=300),
|
||||
}
|
||||
)
|
||||
assert config.get_skills_for("bash") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_subagents_config_from_dict with skills and custom_agents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadSubagentsConfigWithSkills:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_load_with_skills_override(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {
|
||||
"general-purpose": {"skills": ["web-search", "data-analysis"]},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_skills_for("general-purpose") == ["web-search", "data-analysis"]
|
||||
|
||||
def test_load_with_empty_skills(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {
|
||||
"bash": {"skills": []},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_skills_for("bash") == []
|
||||
|
||||
def test_load_with_custom_agents(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Data analysis specialist",
|
||||
"system_prompt": "You are a data analysis subagent.",
|
||||
"skills": ["data-analysis", "visualization"],
|
||||
"tools": ["bash", "read_file"],
|
||||
"max_turns": 80,
|
||||
"timeout_seconds": 600,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert "analysis" in cfg.custom_agents
|
||||
custom = cfg.custom_agents["analysis"]
|
||||
assert custom.skills == ["data-analysis", "visualization"]
|
||||
assert custom.tools == ["bash", "read_file"]
|
||||
assert custom.max_turns == 80
|
||||
assert custom.timeout_seconds == 600
|
||||
|
||||
def test_load_with_both_overrides_and_custom(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {
|
||||
"general-purpose": {"skills": ["web-search"]},
|
||||
},
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Analysis",
|
||||
"system_prompt": "Analyze.",
|
||||
"skills": ["data-analysis"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_skills_for("general-purpose") == ["web-search"]
|
||||
assert cfg.custom_agents["analysis"].skills == ["data-analysis"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: custom agent lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryCustomAgentLookup:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_custom_agent_found(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Data analysis specialist",
|
||||
"system_prompt": "You are a data analysis subagent.",
|
||||
"skills": ["data-analysis"],
|
||||
"tools": ["bash", "read_file"],
|
||||
"max_turns": 80,
|
||||
"timeout_seconds": 600,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
config = get_subagent_config("analysis")
|
||||
assert config is not None
|
||||
assert config.name == "analysis"
|
||||
assert config.skills == ["data-analysis"]
|
||||
assert config.tools == ["bash", "read_file"]
|
||||
assert config.max_turns == 80
|
||||
assert config.timeout_seconds == 600
|
||||
assert config.model == "inherit"
|
||||
|
||||
def test_custom_agent_not_found(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config()
|
||||
assert get_subagent_config("nonexistent") is None
|
||||
|
||||
def test_builtin_takes_priority_over_custom(self):
|
||||
"""If a custom agent has the same name as a builtin, builtin wins."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"general-purpose": {
|
||||
"description": "Custom override attempt",
|
||||
"system_prompt": "Should not be used",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
config = get_subagent_config("general-purpose")
|
||||
# Should get the builtin description, not the custom one
|
||||
assert config.description == BUILTIN_SUBAGENTS["general-purpose"].description
|
||||
|
||||
def test_custom_agent_with_override(self):
|
||||
"""Per-agent overrides also apply to custom agents."""
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Analysis",
|
||||
"system_prompt": "Analyze.",
|
||||
"timeout_seconds": 600,
|
||||
},
|
||||
},
|
||||
"agents": {
|
||||
"analysis": {"timeout_seconds": 300, "skills": ["overridden-skill"]},
|
||||
},
|
||||
}
|
||||
)
|
||||
config = get_subagent_config("analysis")
|
||||
assert config is not None
|
||||
assert config.timeout_seconds == 300 # Override applied
|
||||
assert config.skills == ["overridden-skill"] # Override applied
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: skills override on builtin agents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistrySkillsOverride:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_skills_override_applied_to_builtin(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"agents": {
|
||||
"general-purpose": {"skills": ["web-search", "data-analysis"]},
|
||||
},
|
||||
}
|
||||
)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.skills == ["web-search", "data-analysis"]
|
||||
|
||||
def test_empty_skills_override(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"agents": {
|
||||
"bash": {"skills": []},
|
||||
},
|
||||
}
|
||||
)
|
||||
config = get_subagent_config("bash")
|
||||
assert config.skills == []
|
||||
|
||||
def test_no_skills_override_keeps_default(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config()
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.skills is None # Default: inherit all
|
||||
|
||||
def test_skills_override_does_not_mutate_builtin(self):
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"agents": {
|
||||
"general-purpose": {"skills": ["web-search"]},
|
||||
},
|
||||
}
|
||||
)
|
||||
_ = get_subagent_config("general-purpose")
|
||||
assert BUILTIN_SUBAGENTS["general-purpose"].skills is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: get_available_subagent_names merges custom types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryAvailableNames:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_includes_builtin_names(self):
|
||||
from deerflow.subagents.registry import get_subagent_names
|
||||
|
||||
_reset_subagents_config()
|
||||
names = get_subagent_names()
|
||||
assert "general-purpose" in names
|
||||
assert "bash" in names
|
||||
|
||||
def test_includes_custom_names(self):
|
||||
from deerflow.subagents.registry import get_subagent_names
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Analysis",
|
||||
"system_prompt": "Analyze.",
|
||||
},
|
||||
"researcher": {
|
||||
"description": "Research",
|
||||
"system_prompt": "Research.",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
names = get_subagent_names()
|
||||
assert "general-purpose" in names
|
||||
assert "bash" in names
|
||||
assert "analysis" in names
|
||||
assert "researcher" in names
|
||||
|
||||
def test_no_duplicates_when_custom_name_matches_builtin(self):
|
||||
from deerflow.subagents.registry import get_subagent_names
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"general-purpose": {
|
||||
"description": "Duplicate name",
|
||||
"system_prompt": "test",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
names = get_subagent_names()
|
||||
assert names.count("general-purpose") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: list_subagents includes custom agents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryListSubagentsWithCustom:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_list_includes_custom_agents(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Analysis",
|
||||
"system_prompt": "Analyze.",
|
||||
"skills": ["data-analysis"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
configs = list_subagents()
|
||||
names = {c.name for c in configs}
|
||||
assert "general-purpose" in names
|
||||
assert "bash" in names
|
||||
assert "analysis" in names
|
||||
|
||||
def test_list_custom_agent_has_correct_skills(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"custom_agents": {
|
||||
"analysis": {
|
||||
"description": "Analysis",
|
||||
"system_prompt": "Analyze.",
|
||||
"skills": ["data-analysis", "visualization"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
by_name = {c.name: c for c in list_subagents()}
|
||||
assert by_name["analysis"].skills == ["data-analysis", "visualization"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skills filter passthrough: verify config.skills is used in task_tool assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillsFilterPassthrough:
|
||||
"""Test that SubagentConfig.skills is correctly passed to get_skills_prompt_section."""
|
||||
|
||||
def test_none_skills_passes_none_to_prompt(self):
|
||||
"""When config.skills is None, available_skills=None should be passed (inherit all)."""
|
||||
config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
skills=None,
|
||||
)
|
||||
# Verify: set(None) would raise, so the code must check for None first
|
||||
available = set(config.skills) if config.skills is not None else None
|
||||
assert available is None
|
||||
|
||||
def test_empty_skills_passes_empty_set(self):
|
||||
"""When config.skills is [], available_skills=set() should be passed (no skills)."""
|
||||
config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
skills=[],
|
||||
)
|
||||
available = set(config.skills) if config.skills is not None else None
|
||||
assert available == set()
|
||||
|
||||
def test_skills_whitelist_passes_correct_set(self):
|
||||
"""When config.skills has values, those should be passed as available_skills."""
|
||||
config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
skills=["data-analysis", "web-search"],
|
||||
)
|
||||
available = set(config.skills) if config.skills is not None else None
|
||||
assert available == {"data-analysis", "web-search"}
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
|
||||
- get_timeout_for() / get_max_turns_for() resolution logic
|
||||
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
|
||||
- AppConfig.subagents field access
|
||||
- registry.get_subagent_config() applies config overrides
|
||||
- registry.list_subagents() applies overrides for all agents
|
||||
- Polling timeout calculation in task_tool is consistent with config
|
||||
@@ -11,32 +11,28 @@ Covers:
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.subagents_config import (
|
||||
SubagentOverrideConfig,
|
||||
SubagentsAppConfig,
|
||||
get_subagents_app_config,
|
||||
load_subagents_config_from_dict,
|
||||
)
|
||||
from deerflow.subagents.config import SubagentConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reset_subagents_config(
|
||||
def _make_config(
|
||||
timeout_seconds: int = 900,
|
||||
*,
|
||||
max_turns: int | None = None,
|
||||
agents: dict | None = None,
|
||||
) -> None:
|
||||
"""Reset global subagents config to a known state."""
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"max_turns": max_turns,
|
||||
"agents": agents or {},
|
||||
}
|
||||
) -> AppConfig:
|
||||
"""Build an AppConfig with the given subagents settings."""
|
||||
return AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
subagents=SubagentsAppConfig(
|
||||
timeout_seconds=timeout_seconds,
|
||||
max_turns=max_turns,
|
||||
agents={k: SubagentOverrideConfig(**v) for k, v in (agents or {}).items()},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -50,523 +46,131 @@ class TestSubagentOverrideConfig:
|
||||
override = SubagentOverrideConfig()
|
||||
assert override.timeout_seconds is None
|
||||
assert override.max_turns is None
|
||||
assert override.model is None
|
||||
|
||||
def test_explicit_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42, model="gpt-5.4")
|
||||
assert override.timeout_seconds == 300
|
||||
assert override.max_turns == 42
|
||||
assert override.model == "gpt-5.4"
|
||||
|
||||
def test_model_accepts_any_non_empty_string(self):
|
||||
"""Model name is a free-form non-empty string; cross-reference validation
|
||||
against the `models:` section happens at registry lookup time."""
|
||||
override = SubagentOverrideConfig(model="any-arbitrary-model-name")
|
||||
assert override.model == "any-arbitrary-model-name"
|
||||
|
||||
def test_rejects_zero(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(timeout_seconds=0)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=0)
|
||||
|
||||
def test_rejects_negative(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(timeout_seconds=-1)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=-1)
|
||||
|
||||
def test_rejects_empty_model(self):
|
||||
"""Empty-string model would silently bypass the `is not None` check and
|
||||
reach `create_chat_model(name="")` as a runtime error. Reject at load time
|
||||
instead, symmetric with the `ge=1` guard on timeout_seconds / max_turns."""
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(model="")
|
||||
|
||||
def test_minimum_valid_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
|
||||
assert override.timeout_seconds == 1
|
||||
assert override.max_turns == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig – defaults and validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentsAppConfigDefaults:
|
||||
def test_default_timeout(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.timeout_seconds == 900
|
||||
|
||||
def test_default_max_turns_override_is_none(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.max_turns is None
|
||||
|
||||
def test_default_agents_empty(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.agents == {}
|
||||
|
||||
def test_custom_global_runtime_overrides(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
|
||||
assert config.timeout_seconds == 1800
|
||||
assert config.max_turns == 120
|
||||
|
||||
def test_rejects_zero_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(timeout_seconds=0)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(max_turns=0)
|
||||
def test_explicit_values(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=120, max_turns=50)
|
||||
assert override.timeout_seconds == 120
|
||||
assert override.max_turns == 50
|
||||
|
||||
def test_rejects_negative_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(timeout_seconds=-60)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(max_turns=-60)
|
||||
with pytest.raises(Exception):
|
||||
SubagentOverrideConfig(timeout_seconds=-1)
|
||||
|
||||
def test_rejects_zero_timeout(self):
|
||||
with pytest.raises(Exception):
|
||||
SubagentOverrideConfig(timeout_seconds=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig resolution helpers
|
||||
# SubagentsAppConfig model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRuntimeResolution:
|
||||
def test_returns_global_default_when_no_override(self):
|
||||
class TestSubagentsAppConfig:
|
||||
def test_default_timeout_is_900(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.timeout_seconds == 900
|
||||
assert config.max_turns is None
|
||||
assert config.agents == {}
|
||||
|
||||
def test_custom_defaults(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=300, max_turns=50)
|
||||
assert config.timeout_seconds == 300
|
||||
assert config.max_turns == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_timeout_for / get_max_turns_for
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTimeoutResolution:
|
||||
def test_global_timeout_for_unknown_agent(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=600)
|
||||
assert config.get_timeout_for("unknown") == 600
|
||||
|
||||
def test_per_agent_timeout_overrides_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=600,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=120)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 120
|
||||
assert config.get_timeout_for("general-purpose") == 600
|
||||
|
||||
def test_per_agent_override_none_falls_back_to_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=600,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=None)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 600
|
||||
assert config.get_timeout_for("unknown-agent") == 600
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 100
|
||||
|
||||
|
||||
class TestMaxTurnsResolution:
|
||||
def test_builtin_default_when_no_override(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.get_max_turns_for("bash", 60) == 60
|
||||
|
||||
def test_returns_per_agent_override_when_set(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 300
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
def test_global_max_turns_overrides_builtin(self):
|
||||
config = SubagentsAppConfig(max_turns=100)
|
||||
assert config.get_max_turns_for("bash", 60) == 100
|
||||
|
||||
def test_other_agents_still_use_global_default(self):
|
||||
def test_per_agent_max_turns_overrides_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=140,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
|
||||
max_turns=100,
|
||||
agents={"bash": SubagentOverrideConfig(max_turns=30)},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 900
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 140
|
||||
assert config.get_max_turns_for("bash", 60) == 30
|
||||
assert config.get_max_turns_for("general-purpose", 60) == 100
|
||||
|
||||
def test_agent_with_none_override_falls_back_to_global(self):
|
||||
def test_per_agent_override_none_falls_back(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=150,
|
||||
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
|
||||
max_turns=100,
|
||||
agents={"bash": SubagentOverrideConfig(max_turns=None)},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 900
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 150
|
||||
|
||||
def test_multiple_per_agent_overrides(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={
|
||||
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
|
||||
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
|
||||
},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 1800
|
||||
assert config.get_timeout_for("bash") == 120
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_get_model_for_returns_none_when_no_override(self):
|
||||
"""No per-agent model override -> returns None so callers fall back to builtin/parent."""
|
||||
config = SubagentsAppConfig(timeout_seconds=900)
|
||||
assert config.get_model_for("general-purpose") is None
|
||||
assert config.get_model_for("bash") is None
|
||||
assert config.get_model_for("unknown-agent") is None
|
||||
|
||||
def test_get_model_for_returns_override_when_set(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={
|
||||
"general-purpose": SubagentOverrideConfig(model="qwen3.5-35b-a3b"),
|
||||
"bash": SubagentOverrideConfig(model="gpt-5.4"),
|
||||
},
|
||||
)
|
||||
assert config.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
|
||||
assert config.get_model_for("bash") == "gpt-5.4"
|
||||
|
||||
def test_get_model_for_returns_none_for_omitted_agent(self):
|
||||
"""An agent not listed in overrides returns None even when other agents have model overrides."""
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"bash": SubagentOverrideConfig(model="gpt-5.4")},
|
||||
)
|
||||
assert config.get_model_for("general-purpose") is None
|
||||
|
||||
def test_get_model_for_handles_explicit_none(self):
|
||||
"""Explicit model=None in the override is equivalent to no override."""
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, model=None)},
|
||||
)
|
||||
assert config.get_model_for("bash") is None
|
||||
# Timeout override is still applied even when model is None.
|
||||
assert config.get_timeout_for("bash") == 300
|
||||
assert config.get_max_turns_for("bash", 60) == 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_subagents_config_from_dict / get_subagents_app_config singleton
|
||||
# AppConfig.subagents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadSubagentsConfig:
|
||||
def teardown_method(self):
|
||||
"""Restore defaults after each test."""
|
||||
_reset_subagents_config()
|
||||
|
||||
class TestAppConfigSubagents:
|
||||
def test_load_global_timeout(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
|
||||
assert get_subagents_app_config().timeout_seconds == 300
|
||||
assert get_subagents_app_config().max_turns == 120
|
||||
cfg = _make_config(timeout_seconds=300, max_turns=120)
|
||||
sub = cfg.subagents
|
||||
assert sub.timeout_seconds == 300
|
||||
assert sub.max_turns == 120
|
||||
|
||||
def test_load_with_per_agent_overrides(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
}
|
||||
cfg = _make_config(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_timeout_for("general-purpose") == 1800
|
||||
assert cfg.get_timeout_for("bash") == 60
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert cfg.get_max_turns_for("bash", 60) == 80
|
||||
sub = cfg.subagents
|
||||
assert sub.get_timeout_for("general-purpose") == 1800
|
||||
assert sub.get_timeout_for("bash") == 60
|
||||
assert sub.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert sub.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_load_partial_override(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 600,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
|
||||
}
|
||||
cfg = _make_config(
|
||||
timeout_seconds=600,
|
||||
agents={"bash": {"timeout_seconds": 120, "max_turns": 70}},
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_timeout_for("general-purpose") == 600
|
||||
assert cfg.get_timeout_for("bash") == 120
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert cfg.get_max_turns_for("bash", 60) == 70
|
||||
sub = cfg.subagents
|
||||
assert sub.get_timeout_for("general-purpose") == 600
|
||||
assert sub.get_timeout_for("bash") == 120
|
||||
assert sub.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert sub.get_max_turns_for("bash", 60) == 70
|
||||
|
||||
def test_load_with_model_overrides(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {
|
||||
"general-purpose": {"model": "qwen3.5-35b-a3b"},
|
||||
"bash": {"model": "gpt-5.4", "timeout_seconds": 300},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
|
||||
assert cfg.get_model_for("bash") == "gpt-5.4"
|
||||
# Other override fields on the same agent must still load correctly.
|
||||
assert cfg.get_timeout_for("bash") == 300
|
||||
|
||||
def test_load_empty_dict_uses_defaults(self):
|
||||
load_subagents_config_from_dict({})
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.timeout_seconds == 900
|
||||
assert cfg.max_turns is None
|
||||
assert cfg.agents == {}
|
||||
|
||||
def test_load_replaces_previous_config(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
|
||||
assert get_subagents_app_config().timeout_seconds == 100
|
||||
assert get_subagents_app_config().max_turns == 90
|
||||
|
||||
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
|
||||
assert get_subagents_app_config().timeout_seconds == 200
|
||||
assert get_subagents_app_config().max_turns == 110
|
||||
|
||||
def test_singleton_returns_same_instance_between_calls(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
|
||||
assert get_subagents_app_config() is get_subagents_app_config()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# registry.get_subagent_config – runtime overrides applied
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryGetSubagentConfig:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_returns_none_for_unknown_agent(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
assert get_subagent_config("nonexistent") is None
|
||||
|
||||
def test_returns_config_for_builtin_agents(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
assert get_subagent_config("general-purpose") is not None
|
||||
assert get_subagent_config("bash") is not None
|
||||
|
||||
def test_default_timeout_preserved_when_no_config(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=900)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.timeout_seconds == 900
|
||||
assert config.max_turns == 100
|
||||
|
||||
def test_global_timeout_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.timeout_seconds == 1800
|
||||
assert config.max_turns == 140
|
||||
|
||||
def test_per_agent_runtime_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.timeout_seconds == 120
|
||||
assert bash_config.max_turns == 80
|
||||
|
||||
def test_per_agent_override_does_not_affect_other_agents(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
|
||||
}
|
||||
)
|
||||
gp_config = get_subagent_config("general-purpose")
|
||||
assert gp_config.timeout_seconds == 900
|
||||
assert gp_config.max_turns == 120
|
||||
|
||||
def test_per_agent_model_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4-mini"}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.model == "gpt-5.4-mini"
|
||||
|
||||
def test_omitted_model_keeps_builtin_value(self):
|
||||
"""When config.yaml has no `model` field for an agent, the builtin default must be preserved."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"timeout_seconds": 300}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.model == builtin_bash_model
|
||||
|
||||
def test_explicit_null_model_keeps_builtin_value(self):
|
||||
"""An explicit `model: null` in config.yaml is equivalent to omission — builtin wins."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": None}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.model == builtin_bash_model
|
||||
|
||||
def test_model_override_does_not_affect_other_agents(self):
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
builtin_gp_model = BUILTIN_SUBAGENTS["general-purpose"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4"}},
|
||||
}
|
||||
)
|
||||
gp_config = get_subagent_config("general-purpose")
|
||||
assert gp_config.model == builtin_gp_model
|
||||
|
||||
def test_model_override_preserves_other_fields(self):
|
||||
"""Applying a model override must leave timeout_seconds / max_turns / name intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original = BUILTIN_SUBAGENTS["bash"]
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4-mini"}},
|
||||
}
|
||||
)
|
||||
overridden = get_subagent_config("bash")
|
||||
assert overridden.model == "gpt-5.4-mini"
|
||||
assert overridden.name == original.name
|
||||
assert overridden.description == original.description
|
||||
# No timeout / max_turns override was set, so they use global default / builtin.
|
||||
assert overridden.timeout_seconds == 900
|
||||
assert overridden.max_turns == original.max_turns
|
||||
|
||||
def test_model_override_does_not_mutate_builtin(self):
|
||||
"""Registry must return a new object, leaving the builtin default intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original_bash_model = BUILTIN_SUBAGENTS["bash"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4-mini"}},
|
||||
}
|
||||
)
|
||||
_ = get_subagent_config("bash")
|
||||
assert BUILTIN_SUBAGENTS["bash"].model == original_bash_model
|
||||
|
||||
def test_builtin_config_object_is_not_mutated(self):
|
||||
"""Registry must return a new object, leaving the builtin default intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
|
||||
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
|
||||
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
|
||||
|
||||
returned = get_subagent_config("bash")
|
||||
assert returned.timeout_seconds == 42
|
||||
assert returned.max_turns == 88
|
||||
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
|
||||
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
|
||||
|
||||
def test_config_preserves_other_fields(self):
|
||||
"""Applying runtime overrides must not change other SubagentConfig fields."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=300, max_turns=140)
|
||||
original = BUILTIN_SUBAGENTS["general-purpose"]
|
||||
overridden = get_subagent_config("general-purpose")
|
||||
|
||||
assert overridden.name == original.name
|
||||
assert overridden.description == original.description
|
||||
assert overridden.max_turns == 140
|
||||
assert overridden.model == original.model
|
||||
assert overridden.tools == original.tools
|
||||
assert overridden.disallowed_tools == original.disallowed_tools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# registry.list_subagents – all agents get overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryListSubagents:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_lists_both_builtin_agents(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
names = {cfg.name for cfg in list_subagents()}
|
||||
assert "general-purpose" in names
|
||||
assert "bash" in names
|
||||
|
||||
def test_all_returned_configs_get_global_override(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
_reset_subagents_config(timeout_seconds=123, max_turns=77)
|
||||
for cfg in list_subagents():
|
||||
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
|
||||
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
|
||||
|
||||
def test_per_agent_overrides_reflected_in_list(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
}
|
||||
)
|
||||
by_name = {cfg.name: cfg for cfg in list_subagents()}
|
||||
assert by_name["general-purpose"].timeout_seconds == 1800
|
||||
assert by_name["bash"].timeout_seconds == 60
|
||||
assert by_name["general-purpose"].max_turns == 200
|
||||
assert by_name["bash"].max_turns == 80
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Polling timeout calculation (logic extracted from task_tool)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPollingTimeoutCalculation:
|
||||
"""Verify the formula (timeout_seconds + 60) // 5 is correct for various inputs."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"timeout_seconds, expected_max_polls",
|
||||
[
|
||||
(900, 192), # default 15 min → (900+60)//5 = 192
|
||||
(300, 72), # 5 min → (300+60)//5 = 72
|
||||
(1800, 372), # 30 min → (1800+60)//5 = 372
|
||||
(60, 24), # 1 min → (60+60)//5 = 24
|
||||
(1, 12), # minimum → (1+60)//5 = 12
|
||||
],
|
||||
)
|
||||
def test_polling_timeout_formula(self, timeout_seconds: int, expected_max_polls: int):
|
||||
dummy_config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
|
||||
assert max_poll_count == expected_max_polls
|
||||
|
||||
def test_polling_timeout_exceeds_execution_timeout(self):
|
||||
"""Safety-net polling window must always be longer than the execution timeout."""
|
||||
for timeout_seconds in [60, 300, 900, 1800]:
|
||||
dummy_config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
|
||||
polling_window_seconds = max_poll_count * 5
|
||||
assert polling_window_seconds > timeout_seconds
|
||||
def test_load_empty_uses_defaults(self):
|
||||
cfg = _make_config()
|
||||
sub = cfg.subagents
|
||||
assert sub.timeout_seconds == 900
|
||||
assert sub.max_turns is None
|
||||
assert sub.agents == {}
|
||||
|
||||
@@ -8,6 +8,9 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.subagents.config import SubagentConfig
|
||||
|
||||
# Use module import so tests can patch the exact symbols referenced inside task_tool().
|
||||
@@ -24,6 +27,13 @@ class FakeSubagentStatus(Enum):
|
||||
TIMED_OUT = "timed_out"
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime() -> SimpleNamespace:
|
||||
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
|
||||
return SimpleNamespace(
|
||||
@@ -35,7 +45,7 @@ def _make_runtime() -> SimpleNamespace:
|
||||
"outputs_path": "/tmp/outputs",
|
||||
},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
context=_make_context("thread-1"),
|
||||
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
|
||||
)
|
||||
|
||||
@@ -83,11 +93,11 @@ class _DummyScheduledTask:
|
||||
|
||||
|
||||
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
|
||||
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda: ["general-purpose"])
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: None)
|
||||
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
|
||||
|
||||
result = _run_task_tool(
|
||||
runtime=None,
|
||||
runtime=_make_runtime(),
|
||||
description="执行任务",
|
||||
prompt="do work",
|
||||
subagent_type="general-purpose",
|
||||
@@ -98,8 +108,8 @@ def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||
|
||||
|
||||
def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: _make_subagent_config())
|
||||
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda: False)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: _make_subagent_config())
|
||||
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
result = _run_task_tool(
|
||||
runtime=_make_runtime(),
|
||||
@@ -142,9 +152,8 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
|
||||
@@ -165,225 +174,20 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
|
||||
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
|
||||
assert captured["executor_kwargs"]["config"].max_turns == 7
|
||||
# Skills are no longer appended to system_prompt; they are loaded per-session
|
||||
# by SubagentExecutor and injected as conversation items (Codex pattern).
|
||||
assert captured["executor_kwargs"]["config"].system_prompt == "Base system prompt"
|
||||
# Skills are now loaded per-session by SubagentExecutor (mirroring Codex's pattern);
|
||||
# task_tool no longer appends them to ``system_prompt`` here.
|
||||
assert "Skills Appendix" not in captured["executor_kwargs"]["config"].system_prompt
|
||||
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=None, subagent_enabled=False)
|
||||
from unittest.mock import ANY
|
||||
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=ANY, subagent_enabled=False, app_config=ANY)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
|
||||
assert events[-1]["result"] == "all done"
|
||||
|
||||
|
||||
def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch):
|
||||
"""Verify tool_groups from parent metadata are passed to get_available_tools(groups=...)."""
|
||||
config = _make_subagent_config()
|
||||
parent_tool_groups = ["file:read", "file:write", "bash"]
|
||||
runtime = SimpleNamespace(
|
||||
state={
|
||||
"sandbox": {"sandbox_id": "local"},
|
||||
"thread_data": {"workspace_path": "/tmp/workspace"},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1", "tool_groups": parent_tool_groups}},
|
||||
)
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=["tool-a"])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="file work only",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-groups",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
# The key assertion: groups should be propagated from parent metadata
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)
|
||||
|
||||
|
||||
def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
runtime = _make_runtime()
|
||||
runtime.config["metadata"]["available_skills"] = ["safe-skill"]
|
||||
events = []
|
||||
captured = {}
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
captured["config"] = kwargs["config"]
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="use skills",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-skills",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
assert captured["config"].skills == ["safe-skill"]
|
||||
|
||||
|
||||
def test_task_tool_intersects_parent_and_subagent_skill_allowlists(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
config = SubagentConfig(
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
system_prompt=config.system_prompt,
|
||||
max_turns=config.max_turns,
|
||||
timeout_seconds=config.timeout_seconds,
|
||||
skills=["safe-skill", "other-skill"],
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
runtime.config["metadata"]["available_skills"] = ["safe-skill"]
|
||||
events = []
|
||||
captured = {}
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
captured["config"] = kwargs["config"]
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="use skills",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-skills-intersection",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
assert captured["config"].skills == ["safe-skill"]
|
||||
|
||||
|
||||
def test_task_tool_no_tool_groups_passes_none(monkeypatch):
|
||||
"""Verify that when metadata has no tool_groups, groups=None is passed (backward compat)."""
|
||||
config = _make_subagent_config()
|
||||
# Default _make_runtime() has no tool_groups in metadata
|
||||
runtime = _make_runtime()
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="ok"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="normal work",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-no-groups",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: ok"
|
||||
# No tool_groups in metadata → groups=None (default behavior preserved)
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=None, subagent_enabled=False)
|
||||
|
||||
|
||||
def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
"""Verify that when runtime is None, groups=None is passed (e.g., unknown subagent path exits early, but tools still load correctly)."""
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="ok"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=None,
|
||||
description="执行任务",
|
||||
prompt="no runtime",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-no-runtime",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: ok"
|
||||
# runtime is None → metadata is empty dict → groups=None
|
||||
get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False)
|
||||
|
||||
def test_task_tool_returns_failed_message(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
|
||||
@@ -393,12 +197,11 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -427,12 +230,11 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -463,12 +265,11 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -499,12 +300,11 @@ def test_cleanup_called_on_completed(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -539,12 +339,11 @@ def test_cleanup_called_on_failed(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="error"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -579,12 +378,11 @@ def test_cleanup_called_on_timed_out(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -626,12 +424,11 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -679,8 +476,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||
@@ -730,12 +526,11 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||
@@ -785,12 +580,11 @@ def test_cancellation_calls_request_cancel(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||
@@ -843,9 +637,8 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
|
||||
@@ -1,58 +1,55 @@
|
||||
import pytest
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _as_posix(path: str) -> str:
|
||||
return path.replace("\\", "/")
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadDataMiddleware:
|
||||
def test_before_agent_returns_paths_when_thread_id_present_in_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context={"thread_id": "thread-123"}))
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-123")))
|
||||
|
||||
assert result is not None
|
||||
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-123/user-data/workspace")
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-123/user-data/uploads")
|
||||
assert _as_posix(result["thread_data"]["outputs_path"]).endswith("threads/thread-123/user-data/outputs")
|
||||
|
||||
def test_before_agent_uses_thread_id_from_configurable_when_context_is_none(self, tmp_path, monkeypatch):
|
||||
def test_before_agent_uses_thread_id_from_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
runtime = Runtime(context=None)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
||||
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
result = middleware.before_agent(state={}, runtime=runtime)
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-config")))
|
||||
|
||||
assert result is not None
|
||||
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-from-config/user-data/workspace")
|
||||
assert runtime.context is None
|
||||
|
||||
def test_before_agent_uses_thread_id_from_configurable_when_context_missing_thread_id(self, tmp_path, monkeypatch):
|
||||
def test_before_agent_uses_thread_id_from_typed_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
runtime = Runtime(context={})
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
||||
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
result = middleware.before_agent(state={}, runtime=runtime)
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-dict")))
|
||||
|
||||
assert result is not None
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-config/user-data/uploads")
|
||||
assert runtime.context == {}
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-dict/user-data/uploads")
|
||||
|
||||
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
|
||||
def test_before_agent_raises_clear_error_when_thread_id_missing(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
||||
lambda: {"configurable": {}},
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
with pytest.raises(ValueError, match="Thread ID is required in runtime context or config.configurable"):
|
||||
middleware.before_agent(state={}, runtime=Runtime(context=None))
|
||||
with pytest.raises(ValueError, match="Thread ID is required"):
|
||||
middleware.before_agent(state={}, runtime=Runtime(context=_make_context("")))
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
"""Tests for paginated GET /api/threads/{thread_id}/runs/{run_id}/messages endpoint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import thread_runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -77,8 +78,7 @@ def test_after_seq_forwarded_to_event_store():
|
||||
response = client.get("/api/threads/thread-3/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3",
|
||||
"run-3",
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
@@ -94,8 +94,7 @@ def test_before_seq_forwarded_to_event_store():
|
||||
response = client.get("/api/threads/thread-4/runs/run-4/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4",
|
||||
"run-4",
|
||||
"thread-4", "run-4",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
@@ -111,8 +110,7 @@ def test_custom_limit_forwarded_to_event_store():
|
||||
response = client.get("/api/threads/thread-5/runs/run-5/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5",
|
||||
"run-5",
|
||||
"thread-5", "run-5",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
|
||||
class TestTitleConfig:
|
||||
@@ -44,21 +44,6 @@ class TestTitleConfig:
|
||||
with pytest.raises(ValueError):
|
||||
TitleConfig(max_chars=201)
|
||||
|
||||
def test_get_set_config(self):
|
||||
"""Test global config getter and setter."""
|
||||
original_config = get_title_config()
|
||||
|
||||
# Set new config
|
||||
new_config = TitleConfig(enabled=False, max_words=10)
|
||||
set_title_config(new_config)
|
||||
|
||||
# Verify it was set
|
||||
assert get_title_config().enabled is False
|
||||
assert get_title_config().max_words == 10
|
||||
|
||||
# Restore original config
|
||||
set_title_config(original_config)
|
||||
|
||||
|
||||
class TestTitleMiddleware:
|
||||
"""Tests for TitleMiddleware."""
|
||||
@@ -68,23 +53,3 @@ class TestTitleMiddleware:
|
||||
middleware = TitleMiddleware()
|
||||
assert middleware is not None
|
||||
assert middleware.state_schema is not None
|
||||
|
||||
# TODO: Add integration tests with mock Runtime
|
||||
# def test_should_generate_title(self):
|
||||
# """Test title generation trigger logic."""
|
||||
# pass
|
||||
|
||||
# def test_generate_title(self):
|
||||
# """Test title generation."""
|
||||
# pass
|
||||
|
||||
# def test_after_agent_hook(self):
|
||||
# """Test after_agent hook."""
|
||||
# pass
|
||||
|
||||
|
||||
# TODO: Add integration tests
|
||||
# - Test with real LangGraph runtime
|
||||
# - Test title persistence with checkpointer
|
||||
# - Test fallback behavior when LLM fails
|
||||
# - Test concurrent title generation
|
||||
|
||||
@@ -1,38 +1,32 @@
|
||||
"""Core behavior tests for TitleMiddleware."""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
|
||||
def _clone_title_config(config: TitleConfig) -> TitleConfig:
|
||||
# Avoid mutating shared global config objects across tests.
|
||||
return TitleConfig(**config.model_dump())
|
||||
def _make_title_config(**overrides) -> TitleConfig:
|
||||
return TitleConfig(**overrides)
|
||||
|
||||
|
||||
def _set_test_title_config(**overrides) -> TitleConfig:
|
||||
config = _clone_title_config(get_title_config())
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
set_title_config(config)
|
||||
return config
|
||||
def _make_runtime(**title_overrides) -> SimpleNamespace:
|
||||
"""Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
|
||||
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
|
||||
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
|
||||
return SimpleNamespace(context=ctx)
|
||||
|
||||
|
||||
class TestTitleMiddlewareCoreLogic:
|
||||
def setup_method(self):
|
||||
# Title config is a global singleton; snapshot and restore for test isolation.
|
||||
self._original = _clone_title_config(get_title_config())
|
||||
|
||||
def teardown_method(self):
|
||||
set_title_config(self._original)
|
||||
|
||||
def test_should_generate_title_for_first_complete_exchange(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -41,27 +35,24 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is True
|
||||
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
|
||||
|
||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
_set_test_title_config(enabled=False)
|
||||
disabled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": None,
|
||||
}
|
||||
assert middleware._should_generate_title(disabled_state) is False
|
||||
assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
|
||||
|
||||
_set_test_title_config(enabled=True)
|
||||
titled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": "Existing Title",
|
||||
}
|
||||
assert middleware._should_generate_title(titled_state) is False
|
||||
assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
|
||||
|
||||
def test_should_not_generate_title_after_second_user_turn(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -72,10 +63,9 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is False
|
||||
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
|
||||
|
||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=12)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
@@ -87,11 +77,13 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="好的,先确认需求"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=12))))
|
||||
title = result["title"]
|
||||
|
||||
assert title == "短标题"
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||
from unittest.mock import ANY
|
||||
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, app_config=ANY)
|
||||
model.ainvoke.assert_awaited_once()
|
||||
assert model.ainvoke.await_args.kwargs["config"] == {
|
||||
"run_name": "title_agent",
|
||||
@@ -99,7 +91,6 @@ class TestTitleMiddlewareCoreLogic:
|
||||
}
|
||||
|
||||
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
||||
@@ -112,13 +103,12 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
|
||||
title = result["title"]
|
||||
|
||||
assert title == "请帮我总结这段代码"
|
||||
|
||||
def test_generate_title_fallback_for_long_message(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
||||
@@ -130,7 +120,7 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="收到"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
|
||||
title = result["title"]
|
||||
|
||||
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||
@@ -141,25 +131,24 @@ class TestTitleMiddlewareCoreLogic:
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
|
||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
|
||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
|
||||
assert result == {"title": "异步标题"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
|
||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
|
||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
|
||||
|
||||
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
|
||||
result = middleware.after_model({"messages": []}, runtime=MagicMock())
|
||||
result = middleware.after_model({"messages": []}, runtime=_make_runtime())
|
||||
assert result == {"title": "同步标题"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
|
||||
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
|
||||
assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
|
||||
|
||||
def test_sync_generate_title_uses_fallback_without_model(self):
|
||||
"""Sync path avoids LLM calls and derives a local fallback title."""
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
state = {
|
||||
@@ -168,12 +157,11 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="好的"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
|
||||
assert result == {"title": "请帮我写测试"}
|
||||
|
||||
def test_sync_generate_title_respects_fallback_truncation(self):
|
||||
"""Sync fallback path still respects max_chars truncation rules."""
|
||||
_set_test_title_config(max_chars=50)
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
state = {
|
||||
@@ -182,7 +170,7 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="回复"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
|
||||
assert result["title"].endswith("...")
|
||||
assert result["title"].startswith("这是一个非常长的问题描述")
|
||||
|
||||
|
||||
@@ -154,8 +154,7 @@ class TestStreamUsageIntegration:
|
||||
"""Test that stream() emits usage_metadata in messages-tuple and end events."""
|
||||
|
||||
def _make_client(self):
|
||||
with patch("deerflow.client.get_app_config", return_value=_mock_app_config()):
|
||||
return DeerFlowClient()
|
||||
return DeerFlowClient()
|
||||
|
||||
def test_stream_emits_usage_in_messages_tuple(self):
|
||||
"""messages-tuple AI event should include usage_metadata when present."""
|
||||
|
||||
@@ -56,27 +56,25 @@ def _make_minimal_config(tools):
|
||||
return config
|
||||
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
|
||||
def test_no_duplicates_returned(mock_reset, mock_bash):
|
||||
"""get_available_tools() never returns two tools with the same name."""
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
cfg = _make_minimal_config([])
|
||||
|
||||
# Patch the builtin tools so we control exactly what comes back.
|
||||
with patch("deerflow.tools.tools.BUILTIN_TOOLS", [_tool_alpha, _tool_alpha_dup, _tool_beta]):
|
||||
result = get_available_tools(include_mcp=False)
|
||||
result = get_available_tools(include_mcp=False, app_config=cfg)
|
||||
|
||||
names = [t.name for t in result]
|
||||
assert len(names) == len(set(names)), f"Duplicate names detected: {names}"
|
||||
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||
def test_first_occurrence_wins(mock_reset, mock_bash):
|
||||
"""When duplicates exist, the first occurrence is kept."""
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
cfg = _make_minimal_config([])
|
||||
|
||||
sentinel_alpha = MagicMock(spec=BaseTool, name="_sentinel")
|
||||
sentinel_alpha.name = _tool_alpha.name # same name
|
||||
@@ -84,23 +82,22 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||
sentinel_alpha_dup.name = _tool_alpha.name # same name — should be dropped
|
||||
|
||||
with patch("deerflow.tools.tools.BUILTIN_TOOLS", [sentinel_alpha, sentinel_alpha_dup, _tool_beta]):
|
||||
result = get_available_tools(include_mcp=False)
|
||||
result = get_available_tools(include_mcp=False, app_config=cfg)
|
||||
|
||||
returned_alpha = next(t for t in result if t.name == _tool_alpha.name)
|
||||
assert returned_alpha is sentinel_alpha
|
||||
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
|
||||
def test_duplicate_triggers_warning(mock_reset, mock_bash, caplog):
|
||||
"""A warning is logged for every skipped duplicate."""
|
||||
import logging
|
||||
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
cfg = _make_minimal_config([])
|
||||
|
||||
with patch("deerflow.tools.tools.BUILTIN_TOOLS", [_tool_alpha, _tool_alpha_dup]):
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.tools.tools"):
|
||||
get_available_tools(include_mcp=False)
|
||||
get_available_tools(include_mcp=False, app_config=cfg)
|
||||
|
||||
assert any("Duplicate tool name" in r.message for r in caplog.records), "Expected a duplicate-tool warning in log output"
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool as langchain_tool
|
||||
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig
|
||||
from deerflow.tools.builtins.tool_search import (
|
||||
DeferredToolRegistry,
|
||||
get_deferred_registry,
|
||||
@@ -64,12 +64,12 @@ class TestToolSearchConfig:
|
||||
config = ToolSearchConfig(enabled=True)
|
||||
assert config.enabled is True
|
||||
|
||||
def test_load_from_dict(self):
|
||||
config = load_tool_search_config_from_dict({"enabled": True})
|
||||
def test_validate_from_dict(self):
|
||||
config = ToolSearchConfig.model_validate({"enabled": True})
|
||||
assert config.enabled is True
|
||||
|
||||
def test_load_from_empty_dict(self):
|
||||
config = load_tool_search_config_from_dict({})
|
||||
def test_validate_from_empty_dict(self):
|
||||
config = ToolSearchConfig.model_validate({})
|
||||
assert config.enabled is False
|
||||
|
||||
|
||||
@@ -266,48 +266,42 @@ class TestToolSearchTool:
|
||||
|
||||
|
||||
class TestDeferredToolsPromptSection:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_app_config(self, monkeypatch):
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Provide a minimal AppConfig mock so tests don't need config.yaml."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.tool_search = ToolSearchConfig() # disabled by default
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: mock_config)
|
||||
config = MagicMock()
|
||||
config.tool_search = ToolSearchConfig() # disabled by default
|
||||
return config
|
||||
|
||||
def test_empty_when_disabled(self):
|
||||
def test_empty_when_disabled(self, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
|
||||
# tool_search.enabled defaults to False
|
||||
section = get_deferred_tools_prompt_section()
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert section == ""
|
||||
|
||||
def test_empty_when_enabled_but_no_registry(self, monkeypatch):
|
||||
def test_empty_when_enabled_but_no_registry(self, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
section = get_deferred_tools_prompt_section()
|
||||
mock_config.tool_search = ToolSearchConfig(enabled=True)
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert section == ""
|
||||
|
||||
def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
|
||||
def test_empty_when_enabled_but_empty_registry(self, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
mock_config.tool_search = ToolSearchConfig(enabled=True)
|
||||
set_deferred_registry(DeferredToolRegistry())
|
||||
section = get_deferred_tools_prompt_section()
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert section == ""
|
||||
|
||||
def test_lists_tool_names(self, registry, monkeypatch):
|
||||
def test_lists_tool_names(self, registry, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
mock_config.tool_search = ToolSearchConfig(enabled=True)
|
||||
set_deferred_registry(registry)
|
||||
section = get_deferred_tools_prompt_section()
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert "<available-deferred-tools>" in section
|
||||
assert "</available-deferred-tools>" in section
|
||||
assert "github_create_issue" in section
|
||||
|
||||
@@ -13,7 +13,10 @@ from unittest.mock import MagicMock
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
THREAD_ID = "thread-abc123"
|
||||
|
||||
@@ -23,13 +26,20 @@ THREAD_ID = "thread-abc123"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _middleware(tmp_path: Path) -> UploadsMiddleware:
|
||||
return UploadsMiddleware(base_dir=str(tmp_path))
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock:
|
||||
rt = MagicMock()
|
||||
rt.context = {"thread_id": thread_id}
|
||||
rt.context = _make_context(thread_id or "")
|
||||
return rt
|
||||
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
DEFAULT_USER_ID,
|
||||
CurrentUser,
|
||||
DEFAULT_USER_ID,
|
||||
get_current_user,
|
||||
get_effective_user_id,
|
||||
require_current_user,
|
||||
@@ -100,7 +100,6 @@ def test_effective_user_id_returns_user_id_when_set():
|
||||
def test_effective_user_id_coerces_to_str():
|
||||
"""User.id might be a UUID object; must come back as str."""
|
||||
import uuid
|
||||
|
||||
uid = uuid.uuid4()
|
||||
|
||||
user = SimpleNamespace(id=uid)
|
||||
|
||||
Reference in New Issue
Block a user