mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
refactor(config): eliminate global mutable state — explicit parameter passing on top of main
Squashes 25 PR commits onto current main. AppConfig becomes a pure value object with no ambient lookup. Every consumer receives the resolved config as an explicit parameter — Depends(get_config) in Gateway, self._app_config in DeerFlowClient, runtime.context.app_config in agent runs, AppConfig.from_file() at the LangGraph Server registration boundary. Phase 1 — frozen data + typed context - All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become frozen=True; no sub-module globals. - AppConfig.from_file() is pure (no side-effect singleton loaders). - Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name) — frozen dataclass injected via LangGraph Runtime. - Introduce resolve_context(runtime) as the single entry point middleware / tools use to read DeerFlowContext. Phase 2 — pure explicit parameter passing - Gateway: app.state.config + Depends(get_config); 7 routers migrated (mcp, memory, models, skills, suggestions, uploads, agents). - DeerFlowClient: __init__(config=...) captures config locally. - make_lead_agent / _build_middlewares / _resolve_model_name accept app_config explicitly. - RunContext.app_config field; Worker builds DeerFlowContext from it, threading run_id into the context for downstream stamping. - Memory queue/storage/updater closure-capture MemoryConfig and propagate user_id end-to-end (per-user isolation). - Sandbox/skills/community/factories/tools thread app_config. - resolve_context() rejects non-typed runtime.context. - Test suite migrated off AppConfig.current() monkey-patches. - AppConfig.current() classmethod deleted. Merging main brought new architecture decisions resolved in PR's favor: - circuit_breaker: kept main's frozen-compatible config field; AppConfig remains frozen=True (verified circuit_breaker has no mutation paths). - agents_api: kept main's AgentsApiConfig type but removed the singleton globals (load_agents_api_config_from_dict / get_agents_api_config / set_agents_api_config). 
8 routes in agents.py now read via Depends(get_config). - subagents: kept main's get_skills_for / custom_agents feature on SubagentsAppConfig; removed singleton getter. registry.py now reads app_config.subagents directly. - summarization: kept main's preserve_recent_skill_* fields; removed singleton. - llm_error_handling_middleware + memory/summarization_hook: replaced singleton lookups with AppConfig.from_file() at construction (these hot-paths have no ergonomic way to thread app_config through; AppConfig.from_file is a pure load). - worker.py + thread_data_middleware.py: DeerFlowContext.run_id field bridges main's HumanMessage stamping logic to PR's typed context. Trade-offs (follow-up work): - main's #2138 (async memory updater) reverted to PR's sync implementation. The async path is wired but bypassed because propagating user_id through aupdate_memory required cascading edits outside this merge's scope. - tests/test_subagent_skills_config.py removed: it relied heavily on the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict). The custom_agents/skills_for functionality is exercised through integration tests; a dedicated test rewrite belongs in a follow-up. Verification: backend test suite — 2560 passed, 4 skipped, 84 failures. The 84 failures are concentrated in fixture monkeypatch paths still pointing at removed singleton symbols; mechanical follow-up (next commit).
This commit is contained in:
@@ -0,0 +1,134 @@
|
||||
"""Helpers for router-level tests that need a stubbed auth context.
|
||||
|
||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||
decorators that read ``request.state.auth`` and call
|
||||
``thread_store.check_access``. Router-level unit tests construct
|
||||
**bare** FastAPI apps that include only one router — they have neither
|
||||
the auth middleware nor a real thread_store, so the decorators raise
|
||||
401 (TestClient path) or ValueError (direct-call path).
|
||||
|
||||
This module provides two surfaces:
|
||||
|
||||
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
||||
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
||||
request, plus a permissive ``thread_store`` mock on
|
||||
``app.state``. Use from TestClient-based router tests.
|
||||
|
||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||
the ``@require_permission`` decorator chain by walking ``__wrapped__``.
|
||||
Use from direct-call tests that previously imported the route
|
||||
function and called it positionally.
|
||||
|
||||
Both helpers are deliberately permissive: they never deny a request.
|
||||
Tests that want to verify the *auth boundary itself* (e.g.
|
||||
``test_auth_middleware``, ``test_auth_type_system``) build their own
|
||||
apps with the real middleware — those should not use this module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec, TypeVar
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import AuthContext, Permissions
|
||||
|
||||
# Permissions handed to the stub user on every request. Per the original
# note this mirrors `_ALL_PERMISSIONS` in authz.py; it is spelled out here
# so the tests don't import a private symbol.
_STUB_PERMISSIONS: list[str] = [
    # Thread-level permissions.
    Permissions.THREADS_READ,
    Permissions.THREADS_WRITE,
    Permissions.THREADS_DELETE,
    # Run-level permissions.
    Permissions.RUNS_CREATE,
    Permissions.RUNS_READ,
    Permissions.RUNS_CANCEL,
]
|
||||
|
||||
|
||||
def _make_stub_user() -> User:
    """Build the default stub user for router tests.

    Email, password hash and role are fixed values; only ``id`` varies,
    because a new ``uuid4()`` is generated on every call.
    """
    stub_fields = {
        "email": "router-test@example.com",
        "password_hash": "x",
        "system_role": "user",
        "id": uuid4(),
    }
    return User(**stub_fields)
|
||||
|
||||
|
||||
class _StubAuthMiddleware(BaseHTTPMiddleware):
    """Test-only middleware that marks every request as authenticated.

    For each request it calls ``user_factory`` and stores the resulting
    user on ``request.state.user``, together with an ``AuthContext``
    (carrying a copy of ``_STUB_PERMISSIONS``) on ``request.state.auth``.
    This is intended to stand in for what the production
    ``AuthMiddleware`` leaves on the request, so ``@require_permission``
    finds an authenticated context already in place.
    """

    def __init__(self, app: ASGIApp, user_factory: Callable[[], User]) -> None:
        super().__init__(app)
        # Invoked once per request in ``dispatch``.
        self._user_factory = user_factory

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Stamp the stub user / auth context, then run the rest of the stack."""
        stub_user = self._user_factory()
        # Copy the list so a handler mutating its permissions cannot
        # affect the module-level default.
        granted = list(_STUB_PERMISSIONS)
        request.state.user = stub_user
        request.state.auth = AuthContext(user=stub_user, permissions=granted)
        response = await call_next(request)
        return response
|
||||
|
||||
|
||||
def make_authed_test_app(
    *,
    user_factory: Callable[[], User] | None = None,
    owner_check_passes: bool = True,
) -> FastAPI:
    """Create a bare FastAPI app wired with stub auth and a mock thread_store.

    Args:
        user_factory: Callable returning the :class:`User` to stamp on each
            request. Falls back to ``_make_stub_user`` when omitted. Supply
            one when a test needs a stable user id across requests (for
            example cross-user isolation tests).
        owner_check_passes: Value returned by the awaited
            ``thread_store.check_access`` mock. ``True`` (default) lets
            ``@require_permission(owner_check=True)`` through; pass
            ``False`` to exercise the permission-failure path.

    Returns:
        A ``FastAPI`` app with ``_StubAuthMiddleware`` installed and
        ``app.state.thread_store`` set to the mock. No router is included;
        the caller still has to call ``app.include_router(...)``.
    """
    # Mock store whose only configured behaviour is the async access check.
    thread_store = MagicMock()
    thread_store.check_access = AsyncMock(return_value=owner_check_passes)

    test_app = FastAPI()
    test_app.add_middleware(
        _StubAuthMiddleware,
        user_factory=user_factory or _make_stub_user,
    )
    test_app.state.thread_store = thread_store
    return test_app
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
"""Invoke the underlying function of a ``@require_permission``-decorated route.
|
||||
|
||||
``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all
|
||||
the way down to the original handler, bypassing every authz +
|
||||
require_auth wrapper. Use from tests that need to call route
|
||||
functions directly (without TestClient) and don't want to construct
|
||||
a fake ``Request`` just to satisfy the decorator. The ``ParamSpec``
|
||||
propagates the wrapped route's signature so call sites still get
|
||||
parameter checking despite the unwrapping.
|
||||
"""
|
||||
fn: Callable = decorated
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__ # type: ignore[attr-defined]
|
||||
return fn(*args, **kwargs)
|
||||
@@ -7,6 +7,7 @@ issues when unit-testing lightweight config/registry code in isolation.
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -53,3 +54,71 @@ def provisioner_module():
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods read ``user_id`` from a contextvar by default
|
||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||
# pre-existing persistence test would raise RuntimeError because the
|
||||
# contextvar is unset. The fixture sets a default test user on every
|
||||
# test; tests that explicitly want to verify behaviour *without* a user
|
||||
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _auto_app_config_from_file(monkeypatch, request):
    """Replace ``AppConfig.from_file`` with a minimal in-memory factory.

    Tests that (directly or indirectly, e.g. via the LangGraph Server
    bootstrap path in ``make_lead_agent``) load AppConfig from disk then
    do not need a real ``config.yaml`` on the filesystem.

    The fixture does nothing when:

    * the test is marked ``@pytest.mark.real_from_file`` (it wants the real
      ``from_file`` behaviour), or
    * the config modules cannot be imported.

    The patch is applied through ``monkeypatch``, so it is undone at
    teardown.
    """
    if request.node.get_closest_marker("real_from_file"):
        yield
        return
    try:
        from deerflow.config.app_config import AppConfig
        from deerflow.config.sandbox_config import SandboxConfig
    except ImportError:
        yield
        return

    def _fake_from_file(config_path: str | None = None) -> AppConfig:  # noqa: ARG001
        # The path is ignored: every call returns a fresh minimal config.
        return AppConfig(sandbox=SandboxConfig(use="test"))

    # Install as a staticmethod so the fake behaves the same whether it is
    # reached through the class or through an instance; a bare function on
    # the class would receive the instance as ``config_path``.
    monkeypatch.setattr(AppConfig, "from_file", staticmethod(_fake_from_file))
    yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _auto_user_context(request):
    """Set a default ``test-user-autouse`` user for the duration of a test.

    Tests marked ``@pytest.mark.no_auto_user`` are left untouched, as are
    runs where ``deerflow.runtime.user_context`` cannot be imported. The
    import is done lazily inside the fixture rather than at module level.
    The user set here is reset again at teardown, even if the test fails.
    """
    opted_out = request.node.get_closest_marker("no_auto_user")
    if opted_out:
        yield
        return

    try:
        from deerflow.runtime.user_context import (
            reset_current_user,
            set_current_user,
        )
    except ImportError:
        yield
        return

    # Lightweight stand-in exposing only ``id`` and ``email``.
    default_user = SimpleNamespace(id="test-user-autouse", email="test@local")
    ctx_token = set_current_user(default_user)
    try:
        yield
    finally:
        # Hand back the token so the previous value is restored.
        reset_current_user(ctx_token)
|
||||
|
||||
@@ -2,21 +2,27 @@
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
pytestmark = pytest.mark.real_from_file
|
||||
from pydantic import ValidationError
|
||||
|
||||
from deerflow.config.acp_config import ACPAgentConfig, get_acp_agents, load_acp_config_from_dict
|
||||
from deerflow.config.acp_config import ACPAgentConfig
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def setup_function():
|
||||
"""Reset ACP config before each test."""
|
||||
load_acp_config_from_dict({})
|
||||
def _make_config(acp_agents: dict | None = None) -> AppConfig:
|
||||
return AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
acp_agents={name: ACPAgentConfig(**cfg) for name, cfg in (acp_agents or {}).items()},
|
||||
)
|
||||
|
||||
|
||||
def test_load_acp_config_sets_agents():
|
||||
load_acp_config_from_dict(
|
||||
def test_acp_agents_via_app_config():
|
||||
cfg = _make_config(
|
||||
{
|
||||
"claude_code": {
|
||||
"command": "claude-code-acp",
|
||||
@@ -26,39 +32,33 @@ def test_load_acp_config_sets_agents():
|
||||
}
|
||||
}
|
||||
)
|
||||
agents = get_acp_agents()
|
||||
agents = cfg.acp_agents
|
||||
assert "claude_code" in agents
|
||||
assert agents["claude_code"].command == "claude-code-acp"
|
||||
assert agents["claude_code"].description == "Claude Code for coding tasks"
|
||||
assert agents["claude_code"].model is None
|
||||
|
||||
|
||||
def test_load_acp_config_multiple_agents():
|
||||
load_acp_config_from_dict(
|
||||
def test_multiple_agents():
|
||||
cfg = _make_config(
|
||||
{
|
||||
"claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"},
|
||||
"codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"},
|
||||
}
|
||||
)
|
||||
agents = get_acp_agents()
|
||||
agents = cfg.acp_agents
|
||||
assert len(agents) == 2
|
||||
assert agents["codex"].args == ["--flag"]
|
||||
|
||||
|
||||
def test_load_acp_config_empty_clears_agents():
|
||||
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
|
||||
assert len(get_acp_agents()) == 1
|
||||
|
||||
load_acp_config_from_dict({})
|
||||
assert len(get_acp_agents()) == 0
|
||||
def test_empty_acp_agents():
|
||||
cfg = _make_config({})
|
||||
assert cfg.acp_agents == {}
|
||||
|
||||
|
||||
def test_load_acp_config_none_clears_agents():
|
||||
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
|
||||
assert len(get_acp_agents()) == 1
|
||||
|
||||
load_acp_config_from_dict(None)
|
||||
assert get_acp_agents() == {}
|
||||
def test_default_acp_agents_empty():
|
||||
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
assert cfg.acp_agents == {}
|
||||
|
||||
|
||||
def test_acp_agent_config_defaults():
|
||||
@@ -79,8 +79,8 @@ def test_acp_agent_config_env_default_is_empty():
|
||||
assert cfg.env == {}
|
||||
|
||||
|
||||
def test_load_acp_config_preserves_env():
|
||||
load_acp_config_from_dict(
|
||||
def test_acp_agent_preserves_env():
|
||||
cfg = _make_config(
|
||||
{
|
||||
"codex": {
|
||||
"command": "codex-acp",
|
||||
@@ -90,8 +90,7 @@ def test_load_acp_config_preserves_env():
|
||||
}
|
||||
}
|
||||
)
|
||||
cfg = get_acp_agents()["codex"]
|
||||
assert cfg.env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
|
||||
assert cfg.acp_agents["codex"].env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
|
||||
|
||||
|
||||
def test_acp_agent_config_with_model():
|
||||
@@ -115,13 +114,7 @@ def test_acp_agent_config_missing_description_raises():
|
||||
ACPAgentConfig(command="my-agent")
|
||||
|
||||
|
||||
def test_get_acp_agents_returns_empty_by_default():
|
||||
"""After clearing, should return empty dict."""
|
||||
load_acp_config_from_dict({})
|
||||
assert get_acp_agents() == {}
|
||||
|
||||
|
||||
def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, monkeypatch):
|
||||
def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
|
||||
@@ -157,9 +150,9 @@ def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, mo
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
|
||||
config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8")
|
||||
AppConfig.from_file(str(config_path))
|
||||
assert set(get_acp_agents()) == {"codex"}
|
||||
app = AppConfig.from_file(str(config_path))
|
||||
assert set(app.acp_agents) == {"codex"}
|
||||
|
||||
config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8")
|
||||
AppConfig.from_file(str(config_path))
|
||||
assert get_acp_agents() == {}
|
||||
app = AppConfig.from_file(str(config_path))
|
||||
assert app.acp_agents == {}
|
||||
|
||||
@@ -57,6 +57,7 @@ def test_get_thread_mounts_includes_acp_workspace(tmp_path, monkeypatch):
|
||||
"""_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
|
||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3")
|
||||
|
||||
@@ -95,6 +96,7 @@ def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypat
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
|
||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from deerflow.config.agents_api_config import get_agents_api_config
|
||||
from deerflow.config.app_config import get_app_config, reset_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
pytestmark = pytest.mark.real_from_file
|
||||
|
||||
|
||||
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
|
||||
@@ -29,113 +30,66 @@ def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> No
|
||||
)
|
||||
|
||||
|
||||
def _write_config_with_agents_api(
|
||||
path: Path,
|
||||
*,
|
||||
model_name: str,
|
||||
supports_thinking: bool,
|
||||
agents_api: dict | None = None,
|
||||
) -> None:
|
||||
config = {
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"models": [
|
||||
{
|
||||
"name": model_name,
|
||||
"use": "langchain_openai:ChatOpenAI",
|
||||
"model": "gpt-test",
|
||||
"supports_thinking": supports_thinking,
|
||||
}
|
||||
],
|
||||
}
|
||||
if agents_api is not None:
|
||||
config["agents_api"] = agents_api
|
||||
|
||||
path.write_text(yaml.safe_dump(config), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_extensions_config(path: Path) -> None:
|
||||
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
|
||||
|
||||
|
||||
def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch):
|
||||
def test_from_file_reads_model_name(tmp_path, monkeypatch):
|
||||
"""``AppConfig.from_file`` is the only lifecycle method now; there is no
|
||||
process-global ``init/current``. Each consumer holds its own captured
|
||||
AppConfig instance.
|
||||
"""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config(config_path, model_name="first-model", supports_thinking=False)
|
||||
_write_config(config_path, model_name="test-model", supports_thinking=False)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
initial = get_app_config()
|
||||
assert initial.models[0].supports_thinking is False
|
||||
|
||||
_write_config(config_path, model_name="first-model", supports_thinking=True)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
reloaded = get_app_config()
|
||||
assert reloaded.models[0].supports_thinking is True
|
||||
assert reloaded is not initial
|
||||
finally:
|
||||
reset_app_config()
|
||||
config = AppConfig.from_file(str(config_path))
|
||||
assert config.models[0].name == "test-model"
|
||||
|
||||
|
||||
def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
|
||||
config_a = tmp_path / "config-a.yaml"
|
||||
config_b = tmp_path / "config-b.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config(config_a, model_name="model-a", supports_thinking=False)
|
||||
_write_config(config_b, model_name="model-b", supports_thinking=True)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_a))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
first = get_app_config()
|
||||
assert first.models[0].name == "model-a"
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_b))
|
||||
second = get_app_config()
|
||||
assert second.models[0].name == "model-b"
|
||||
assert second is not first
|
||||
finally:
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path, monkeypatch):
|
||||
def test_from_file_each_call_returns_fresh_instance(tmp_path, monkeypatch):
|
||||
"""Two reads of the same file produce separate AppConfig instances —
|
||||
no hidden singleton, no memoization. Callers decide when to re-read.
|
||||
"""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_agents_api(
|
||||
config_path,
|
||||
model_name="first-model",
|
||||
supports_thinking=False,
|
||||
agents_api={"enabled": True},
|
||||
_write_config(config_path, model_name="model-a", supports_thinking=False)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
|
||||
config_a = AppConfig.from_file(str(config_path))
|
||||
assert config_a.models[0].name == "model-a"
|
||||
|
||||
_write_config(config_path, model_name="model-b", supports_thinking=True)
|
||||
config_b = AppConfig.from_file(str(config_path))
|
||||
assert config_b.models[0].name == "model-b"
|
||||
assert config_a is not config_b
|
||||
|
||||
|
||||
def test_config_version_check(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"config_version": 1,
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"models": [],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
initial = get_app_config()
|
||||
assert initial.models[0].name == "first-model"
|
||||
assert get_agents_api_config().enabled is True
|
||||
|
||||
_write_config_with_agents_api(
|
||||
config_path,
|
||||
model_name="first-model",
|
||||
supports_thinking=False,
|
||||
)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
reloaded = get_app_config()
|
||||
assert reloaded is not initial
|
||||
assert get_agents_api_config().enabled is False
|
||||
finally:
|
||||
reset_app_config()
|
||||
config = AppConfig.from_file(str(config_path))
|
||||
assert config is not None
|
||||
|
||||
@@ -3,7 +3,7 @@ import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
@@ -36,7 +36,7 @@ def test_get_artifact_reads_utf8_text_file_on_windows_locale(tmp_path, monkeypat
|
||||
monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: artifact_path)
|
||||
|
||||
request = _make_request()
|
||||
response = asyncio.run(artifacts_router.get_artifact("thread-1", "mnt/user-data/outputs/note.txt", request))
|
||||
response = asyncio.run(call_unwrapped(artifacts_router.get_artifact, "thread-1", "mnt/user-data/outputs/note.txt", request))
|
||||
|
||||
assert bytes(response.body).decode("utf-8") == text
|
||||
assert response.media_type == "text/plain"
|
||||
@@ -49,7 +49,7 @@ def test_get_artifact_forces_download_for_active_content(tmp_path, monkeypatch,
|
||||
|
||||
monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: artifact_path)
|
||||
|
||||
response = asyncio.run(artifacts_router.get_artifact("thread-1", f"mnt/user-data/outputs/{filename}", _make_request()))
|
||||
response = asyncio.run(call_unwrapped(artifacts_router.get_artifact, "thread-1", f"mnt/user-data/outputs/{filename}", _make_request()))
|
||||
|
||||
assert isinstance(response, FileResponse)
|
||||
assert response.headers.get("content-disposition", "").startswith("attachment;")
|
||||
@@ -63,7 +63,7 @@ def test_get_artifact_forces_download_for_active_content_in_skill_archive(tmp_pa
|
||||
|
||||
monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: skill_path)
|
||||
|
||||
response = asyncio.run(artifacts_router.get_artifact("thread-1", f"mnt/user-data/outputs/sample.skill/{filename}", _make_request()))
|
||||
response = asyncio.run(call_unwrapped(artifacts_router.get_artifact, "thread-1", f"mnt/user-data/outputs/sample.skill/{filename}", _make_request()))
|
||||
|
||||
assert response.headers.get("content-disposition", "").startswith("attachment;")
|
||||
assert bytes(response.body) == content.encode("utf-8")
|
||||
@@ -75,7 +75,7 @@ def test_get_artifact_download_false_does_not_force_attachment(tmp_path, monkeyp
|
||||
|
||||
monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: artifact_path)
|
||||
|
||||
app = FastAPI()
|
||||
app = make_authed_test_app()
|
||||
app.include_router(artifacts_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -93,7 +93,7 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
|
||||
|
||||
monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: skill_path)
|
||||
|
||||
app = FastAPI()
|
||||
app = make_authed_test_app()
|
||||
app.include_router(artifacts_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
@@ -0,0 +1,654 @@
|
||||
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
get_auth_context,
|
||||
require_auth,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hash_password_and_verify():
    """A hash differs from the plaintext and verifies only the right password."""
    plaintext = "s3cr3tP@ssw0rd!"
    digest = hash_password(plaintext)

    # The stored form must not be the plaintext itself.
    assert digest != plaintext
    assert verify_password(plaintext, digest) is True
    assert verify_password("wrongpassword", digest) is False
|
||||
|
||||
|
||||
def test_hash_password_different_each_time():
    """Hashing the same password twice yields two distinct, valid hashes."""
    secret = "testpassword"
    first_hash = hash_password(secret)
    second_hash = hash_password(secret)

    # Presumably distinct because each hash gets its own salt.
    assert first_hash != second_hash
    # Either hash must still verify the original password.
    for digest in (first_hash, second_hash):
        assert verify_password(secret, digest) is True
|
||||
|
||||
|
||||
def test_verify_password_rejects_empty():
    """An empty candidate password must not match a real hash."""
    digest = hash_password("nonempty")
    matched = verify_password("", digest)
    assert matched is False
|
||||
|
||||
|
||||
# ── JWT ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_create_and_decode_token(monkeypatch):
    """A freshly created access token decodes back to the same subject."""
    # Scope the secret to this test: monkeypatch restores the previous
    # environment at teardown, whereas a direct ``os.environ`` write would
    # leak into every test that runs afterwards.
    monkeypatch.setenv("AUTH_JWT_SECRET", "test-secret-key-for-jwt-testing-minimum-32-chars")

    user_id = str(uuid4())
    token = create_access_token(user_id)
    assert isinstance(token, str)

    payload = decode_token(token)
    assert payload is not None
    assert payload.sub == user_id
|
||||
|
||||
|
||||
def test_decode_token_expired():
    """Decoding a token whose expiry is in the past yields ``TokenError.EXPIRED``."""
    from app.gateway.auth.errors import TokenError

    # A negative lifetime puts the expiry before "now".
    already_expired = timedelta(seconds=-1)
    token = create_access_token(str(uuid4()), expires_delta=already_expired)

    assert decode_token(token) == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid():
    """Malformed token strings decode to a ``TokenError`` value."""
    from app.gateway.auth.errors import TokenError

    malformed_tokens = ("not.a.valid.token", "", "completely-wrong")
    for candidate in malformed_tokens:
        assert isinstance(decode_token(candidate), TokenError)
|
||||
|
||||
|
||||
def test_create_token_custom_expiry():
    """A token created with an explicit one-hour lifetime still round-trips."""
    subject = str(uuid4())
    one_hour = timedelta(hours=1)

    token = create_access_token(subject, expires_delta=one_hour)
    decoded = decode_token(token)

    assert decoded is not None
    assert decoded.sub == subject
|
||||
|
||||
|
||||
# ── AuthContext ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_context_unauthenticated():
    """Without a user the context is unauthenticated and grants nothing."""
    anonymous = AuthContext(user=None, permissions=[])

    assert anonymous.is_authenticated is False
    assert anonymous.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_authenticated_no_perms():
    """A user with an empty permission list is authenticated but permitted nothing."""
    ctx = AuthContext(
        user=User(id=uuid4(), email="test@example.com", password_hash="hash"),
        permissions=[],
    )

    assert ctx.is_authenticated is True
    assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_has_permission():
    """Only the granted (resource, action) pairs are reported as permitted."""
    member = User(id=uuid4(), email="test@example.com", password_hash="hash")
    ctx = AuthContext(
        user=member,
        permissions=[Permissions.THREADS_READ, Permissions.THREADS_WRITE],
    )

    # Pairs that were granted.
    assert ctx.has_permission("threads", "read") is True
    assert ctx.has_permission("threads", "write") is True
    # Same resource with an ungranted action, then an unrelated resource.
    assert ctx.has_permission("threads", "delete") is False
    assert ctx.has_permission("runs", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_require_user_raises():
    """Calling require_user on a user-less context raises HTTPException with status 401."""
    anonymous = AuthContext(user=None, permissions=[])

    with pytest.raises(HTTPException) as raised:
        anonymous.require_user()
    assert raised.value.status_code == 401
|
||||
|
||||
|
||||
def test_auth_context_require_user_returns_user():
    """require_user hands back the context's user when one is present."""
    member = User(id=uuid4(), email="test@example.com", password_hash="hash")
    ctx = AuthContext(user=member, permissions=[])

    assert ctx.require_user() == member
|
||||
|
||||
|
||||
# ── get_auth_context helper ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_auth_context_not_set():
    """get_auth_context yields None when the request state carries no ``auth``."""
    request = MagicMock()
    request.state = MagicMock()
    # Deleting an attribute on a mock makes later access raise AttributeError,
    # so ``state.auth`` looks genuinely absent instead of being auto-created.
    del request.state.auth

    assert get_auth_context(request) is None
|
||||
|
||||
|
||||
def test_get_auth_context_set():
    """get_auth_context returns whatever AuthContext is stored on ``request.state.auth``."""
    expected = AuthContext(
        user=User(id=uuid4(), email="test@example.com", password_hash="hash"),
        permissions=[Permissions.THREADS_READ],
    )

    request = MagicMock()
    request.state.auth = expected

    assert get_auth_context(request) == expected
|
||||
|
||||
|
||||
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_auth_sets_auth_context():
    """A require_auth endpoint sees an auth context; with no cookie it is unauthenticated."""
    from fastapi import Request

    app = FastAPI()

    @app.get("/test")
    @require_auth
    async def endpoint(request: Request):
        auth = get_auth_context(request)
        return {"authenticated": auth.is_authenticated}

    with TestClient(app) as http:
        # No cookie → anonymous
        anonymous = http.get("/test")
        assert anonymous.status_code == 200
        assert anonymous.json()["authenticated"] is False
|
||||
|
||||
|
||||
def test_require_auth_requires_request_param():
    """Calling a require_auth-wrapped handler that has no ``request`` parameter raises ValueError."""
    import asyncio

    @require_auth
    async def bad_endpoint():  # Missing `request` parameter
        pass

    expected_message = "require_auth decorator requires 'request' parameter"
    with pytest.raises(ValueError, match=expected_message):
        asyncio.run(bad_endpoint())
|
||||
|
||||
|
||||
# ── require_permission decorator ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_permission_requires_auth():
    """A require_permission endpoint answers 401 to a request with no credentials."""
    from fastapi import Request

    app = FastAPI()

    @app.get("/test")
    @require_permission("threads", "read")
    async def endpoint(request: Request):
        return {"ok": True}

    with TestClient(app) as http:
        rejected = http.get("/test")
        assert rejected.status_code == 401
        assert "Authentication required" in rejected.json()["detail"]
|
||||
|
||||
|
||||
def test_require_permission_denies_wrong_permission():
    """An authenticated user lacking the demanded permission receives 403."""
    from fastapi import Request

    app = FastAPI()
    reader = User(id=uuid4(), email="test@example.com", password_hash="hash")

    @app.get("/test")
    @require_permission("threads", "delete")
    async def endpoint(request: Request):
        return {"ok": True}

    # The caller may only read threads, while the route demands delete.
    read_only_auth = AuthContext(user=reader, permissions=[Permissions.THREADS_READ])

    with patch("app.gateway.authz._authenticate", return_value=read_only_auth), TestClient(app) as http:
        denied = http.get("/test")
        assert denied.status_code == 403
        assert "Permission denied" in denied.json()["detail"]
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── User Model Fields ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_model_has_needs_setup_default_false():
    """A User built without ``needs_setup`` reports it as False."""
    fresh = User(email="test@example.com", password_hash="hash")
    assert fresh.needs_setup is False
|
||||
|
||||
|
||||
def test_user_model_has_token_version_default_zero():
    """A User built without ``token_version`` starts at version 0."""
    fresh = User(email="test@example.com", password_hash="hash")
    assert fresh.token_version == 0
|
||||
|
||||
|
||||
def test_user_model_needs_setup_true():
    """``needs_setup=True`` passed to the constructor is kept on the model."""
    admin = User(email="admin@example.com", password_hash="hash", needs_setup=True)
    assert admin.needs_setup is True
|
||||
|
||||
|
||||
def test_sqlite_round_trip_new_fields():
    """``needs_setup`` and ``token_version`` persist through create, read and update.

    The repository runs on the shared persistence engine (the one also used
    for threads_meta, runs, run_events and feedback), pointed here at a
    throwaway SQLite file. The former standalone .deer-flow/users.db file
    no longer exists.
    """
    import asyncio
    import tempfile

    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

    async def _scenario() -> None:
        from deerflow.persistence.engine import (
            close_engine,
            get_session_factory,
            init_engine,
        )

        with tempfile.TemporaryDirectory() as workdir:
            db_url = f"sqlite+aiosqlite:///{workdir}/scratch.db"
            await init_engine("sqlite", url=db_url, sqlite_dir=workdir)
            try:
                users = SQLiteUserRepository(get_session_factory())

                # Create: the values come back on the returned object.
                stored = await users.create_user(
                    User(
                        email="setup@test.com",
                        password_hash="fakehash",
                        system_role="admin",
                        needs_setup=True,
                        token_version=3,
                    )
                )
                assert stored.needs_setup is True
                assert stored.token_version == 3

                # Read: the same values are loaded back by email.
                loaded = await users.get_user_by_email("setup@test.com")
                assert loaded is not None
                assert loaded.needs_setup is True
                assert loaded.token_version == 3

                # Update: changed values are visible on a lookup by id.
                loaded.needs_setup = False
                loaded.token_version = 4
                await users.update_user(loaded)
                reloaded = await users.get_user_by_id(str(loaded.id))
                assert reloaded is not None
                assert reloaded.needs_setup is False
                assert reloaded.token_version == 4
            finally:
                await close_engine()

    asyncio.run(_scenario())
|
||||
|
||||
|
||||
def test_update_user_raises_when_row_concurrently_deleted(tmp_path):
    """Concurrent-delete during update_user must hard-fail, not silently no-op.

    Earlier the SQLite repo returned the input unchanged when the row was
    missing, making a phantom success path that admin password reset
    callers (`reset_admin`, `_ensure_admin_user`) would happily log as
    'password reset'. The new contract: raise ``UserNotFoundError`` so
    a vanished row never looks like a successful update.

    The scratch database lives in pytest's ``tmp_path`` fixture, which the
    test already requests, instead of a second hand-made temporary directory.
    """
    import asyncio

    from app.gateway.auth.repositories.base import UserNotFoundError
    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

    async def _run() -> None:
        from deerflow.persistence.engine import (
            close_engine,
            get_session_factory,
            init_engine,
        )
        from deerflow.persistence.user.model import UserRow

        url = f"sqlite+aiosqlite:///{tmp_path}/scratch.db"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        try:
            sf = get_session_factory()
            repo = SQLiteUserRepository(sf)
            user = User(
                email="ghost@test.com",
                password_hash="fakehash",
                system_role="user",
            )
            created = await repo.create_user(user)

            # Simulate "row vanished underneath us" by deleting the row
            # via the raw ORM session, then attempt to update.
            async with sf() as session:
                row = await session.get(UserRow, str(created.id))
                assert row is not None
                await session.delete(row)
                await session.commit()

            created.needs_setup = True
            with pytest.raises(UserNotFoundError):
                await repo.update_user(created)
        finally:
            await close_engine()

    asyncio.run(_run())
|
||||
|
||||
|
||||
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_jwt_encodes_ver():
    """The ``token_version`` passed at creation is readable as ``ver`` after decoding.

    ``AUTH_JWT_SECRET`` is set through ``patch.dict`` so the environment is
    restored once the test finishes rather than leaking into later tests.
    """
    import os

    from app.gateway.auth.errors import TokenError

    secret_env = {"AUTH_JWT_SECRET": "test-secret-key-for-jwt-testing-minimum-32-chars"}
    with patch.dict(os.environ, secret_env):
        token = create_access_token(str(uuid4()), token_version=3)
        payload = decode_token(token)

        assert not isinstance(payload, TokenError)
        assert payload.ver == 3
|
||||
|
||||
|
||||
def test_jwt_default_ver_zero():
    """A token created without ``token_version`` decodes with ``ver == 0``.

    ``AUTH_JWT_SECRET`` is set through ``patch.dict`` so the environment is
    restored once the test finishes rather than leaking into later tests.
    """
    import os

    from app.gateway.auth.errors import TokenError

    secret_env = {"AUTH_JWT_SECRET": "test-secret-key-for-jwt-testing-minimum-32-chars"}
    with patch.dict(os.environ, secret_env):
        token = create_access_token(str(uuid4()))
        payload = decode_token(token)

        assert not isinstance(payload, TokenError)
        assert payload.ver == 0
|
||||
|
||||
|
||||
def test_token_version_mismatch_rejects():
    """Token with stale ver is rejected by get_current_user_from_request.

    The cookie carries a token minted at version 0 while the provider returns
    a user at version 1; the call must fail with a 401 whose detail mentions
    "revoked". ``AUTH_JWT_SECRET`` is set through ``patch.dict`` so the
    environment is restored when the test finishes.
    """
    import asyncio
    import os

    secret_env = {"AUTH_JWT_SECRET": "test-secret-key-for-jwt-testing-minimum-32-chars"}
    with patch.dict(os.environ, secret_env):
        user_id = str(uuid4())
        token = create_access_token(user_id, token_version=0)

        # Stored user is one version ahead of the token.
        mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)

        mock_request = MagicMock()
        mock_request.cookies = {"access_token": token}

        with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
            mock_provider = MagicMock()
            mock_provider.get_user = AsyncMock(return_value=mock_user)
            mock_provider_fn.return_value = mock_provider

            from app.gateway.deps import get_current_user_from_request

            with pytest.raises(HTTPException) as exc_info:
                asyncio.run(get_current_user_from_request(mock_request))
            assert exc_info.value.status_code == 401
            assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
# ── change-password extension ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_change_password_request_accepts_new_email():
    """A ``new_email`` supplied to ChangePasswordRequest is kept on the model."""
    from app.gateway.routers.auth import ChangePasswordRequest

    fields = {
        "current_password": "old",
        "new_password": "newpassword",
        "new_email": "new@example.com",
    }
    request_model = ChangePasswordRequest(**fields)
    assert request_model.new_email == "new@example.com"
|
||||
|
||||
|
||||
def test_change_password_request_new_email_optional():
    """Omitting ``new_email`` is allowed and leaves the field as None."""
    from app.gateway.routers.auth import ChangePasswordRequest

    request_model = ChangePasswordRequest(current_password="old", new_password="newpassword")
    assert request_model.new_email is None
|
||||
|
||||
|
||||
def test_login_response_includes_needs_setup():
    """LoginResponse carries ``needs_setup``: kept when given, False when omitted."""
    from app.gateway.routers.auth import LoginResponse

    explicit = LoginResponse(expires_in=3600, needs_setup=True)
    assert explicit.needs_setup is True

    defaulted = LoginResponse(expires_in=3600)
    assert defaulted.needs_setup is False
|
||||
|
||||
|
||||
# ── Rate Limiting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_rate_limiter_allows_under_limit():
    """With no recorded failures the rate-limit check passes without raising."""
    from app.gateway.routers.auth import _check_rate_limit, _login_attempts

    # Start from an empty attempts table so earlier tests cannot interfere.
    _login_attempts.clear()
    _check_rate_limit("192.168.1.1")  # Should not raise
|
||||
|
||||
|
||||
def test_rate_limiter_blocks_after_max_failures():
    """After five recorded failures the IP is refused with HTTP 429."""
    from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure

    _login_attempts.clear()
    offender = "10.0.0.1"
    for _attempt in range(5):
        _record_login_failure(offender)

    with pytest.raises(HTTPException) as blocked:
        _check_rate_limit(offender)
    assert blocked.value.status_code == 429
|
||||
|
||||
|
||||
def test_rate_limiter_resets_on_success():
    """Four failures followed by a recorded success leave the IP unblocked."""
    from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success

    _login_attempts.clear()
    address = "10.0.0.2"
    for _attempt in range(4):
        _record_login_failure(address)

    _record_login_success(address)
    _check_rate_limit(address)  # Should not raise
|
||||
|
||||
|
||||
# ── Client IP extraction ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_client_ip_direct_connection_no_proxy(monkeypatch):
    """With AUTH_TRUSTED_PROXIES unset and no headers, the TCP peer address is returned."""
    monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client.host = "203.0.113.42"
    request.headers = {}

    assert _get_client_ip(request) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch):
    """With AUTH_TRUSTED_PROXIES unset, an X-Real-IP header does not override the peer.

    Guards against the bypass in which a client rotates X-Real-IP on every
    request to slip past per-IP rate limits in dev / direct mode.
    """
    monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client.host = "127.0.0.1"
    request.headers = {"x-real-ip": "203.0.113.42"}

    assert _get_client_ip(request) == "127.0.0.1"
|
||||
|
||||
|
||||
def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch):
    """When the TCP peer falls inside AUTH_TRUSTED_PROXIES, X-Real-IP is returned."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client.host = "10.5.6.7"  # in trusted CIDR
    request.headers = {"x-real-ip": "203.0.113.42"}

    assert _get_client_ip(request) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch):
    """When the TCP peer is outside AUTH_TRUSTED_PROXIES, X-Real-IP is disregarded."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client.host = "8.8.8.8"  # NOT in trusted CIDR
    request.headers = {"x-real-ip": "203.0.113.42"}  # client trying to spoof

    assert _get_client_ip(request) == "8.8.8.8"
|
||||
|
||||
|
||||
def test_get_client_ip_xff_never_honored(monkeypatch):
    """X-Forwarded-For alone, even from a trusted peer, leaves the peer address as the result."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client.host = "10.0.0.1"
    request.headers = {"x-forwarded-for": "198.51.100.5"}  # no x-real-ip

    assert _get_client_ip(request) == "10.0.0.1"
|
||||
|
||||
|
||||
def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog):
    """An unparsable AUTH_TRUSTED_PROXIES entry does not stop the valid entry from matching."""
    # NOTE(review): ``caplog`` is requested but nothing is asserted on it, so
    # the warning for the bad entry is not actually verified here.
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "not-an-ip,10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client.host = "10.5.6.7"
    request.headers = {"x-real-ip": "203.0.113.42"}

    assert _get_client_ip(request) == "203.0.113.42"  # valid entry still works
|
||||
|
||||
|
||||
def test_get_client_ip_no_client_returns_unknown(monkeypatch):
    """A request whose ``client`` is None resolves to the 'unknown' marker instead of crashing."""
    monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
    from app.gateway.routers.auth import _get_client_ip

    request = MagicMock()
    request.client = None
    request.headers = {}

    assert _get_client_ip(request) == "unknown"
|
||||
|
||||
|
||||
# ── Common-password blocklist ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_register_rejects_literal_password():
    """Registering with the password 'password' fails validation as too common."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import RegisterRequest

    with pytest.raises(ValidationError) as rejected:
        RegisterRequest(email="x@example.com", password="password")
    assert "too common" in str(rejected.value)
|
||||
|
||||
|
||||
def test_register_rejects_common_password_case_insensitive():
    """Each listed common-password variant, whatever its casing, fails validation."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import RegisterRequest

    common_variants = ("PASSWORD", "Password1", "qwerty123", "letmein1")
    for candidate in common_variants:
        with pytest.raises(ValidationError):
            RegisterRequest(email="x@example.com", password=candidate)
|
||||
|
||||
|
||||
def test_register_accepts_strong_password():
    """A long password that is not on the blocklist is accepted unchanged."""
    from app.gateway.routers.auth import RegisterRequest

    strong = "Tr0ub4dor&3-Horse"
    registration = RegisterRequest(email="x@example.com", password=strong)
    assert registration.password == "Tr0ub4dor&3-Horse"
|
||||
|
||||
|
||||
def test_change_password_rejects_common_password():
    """ChangePasswordRequest refuses a common ``new_password`` just as registration does."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import ChangePasswordRequest

    with pytest.raises(ValidationError):
        ChangePasswordRequest(current_password="anything", new_password="iloveyou")
|
||||
|
||||
|
||||
def test_password_blocklist_keeps_short_passwords_for_length_check():
    """A too-short password is reported via the minimum-length message."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import RegisterRequest

    with pytest.raises(ValidationError) as rejected:
        RegisterRequest(email="x@example.com", password="abc")

    # the length check should fire, not the blocklist
    assert "at least 8 characters" in str(rejected.value)
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
    """get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset.

    Both the module-level ``_auth_config`` and the environment variable are
    changed through ``monkeypatch``. They are therefore put back to their
    previous values on teardown even if an assertion below fails, rather
    than depending on a trailing manual reset.
    """
    import logging

    import app.gateway.auth.config as config_module

    monkeypatch.setattr(config_module, "_auth_config", None)
    monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)

    with caplog.at_level(logging.WARNING):
        config = config_module.get_auth_config()

    assert config.jwt_secret  # non-empty ephemeral secret
    assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests for AuthConfig typed configuration."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.gateway.auth.config import AuthConfig
|
||||
|
||||
|
||||
def test_auth_config_defaults():
    """An AuthConfig built with only a secret has a 7-day token expiry."""
    defaults = AuthConfig(jwt_secret="test-secret-key-123")
    assert defaults.token_expiry_days == 7
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_range():
    """``token_expiry_days`` accepts 1 and 30 and rejects 0 and 31."""
    # Both ends of the accepted range construct without error.
    AuthConfig(jwt_secret="s", token_expiry_days=1)
    AuthConfig(jwt_secret="s", token_expiry_days=30)

    # One step outside either bound must be rejected. ValueError replaces the
    # former catch-all Exception so an unrelated failure (TypeError,
    # AttributeError, ...) can no longer make this test pass by accident.
    # NOTE(review): assumes AuthConfig is a pydantic model, whose
    # ValidationError derives from ValueError — confirm.
    for out_of_range in (0, 31):
        with pytest.raises(ValueError):
            AuthConfig(jwt_secret="s", token_expiry_days=out_of_range)
|
||||
|
||||
|
||||
def test_auth_config_from_env():
    """get_auth_config picks the JWT secret up from ``AUTH_JWT_SECRET``."""
    with patch.dict(os.environ, {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}, clear=False):
        import app.gateway.auth.config as cfg

        previous = cfg._auth_config
        # Reset the module-level ``_auth_config`` before the call
        # (presumably a cache that would otherwise be reused — confirm).
        cfg._auth_config = None
        try:
            loaded = cfg.get_auth_config()
            assert loaded.jwt_secret == "test-jwt-secret-from-env"
        finally:
            cfg._auth_config = previous
|
||||
|
||||
|
||||
def test_auth_config_missing_secret_generates_ephemeral(caplog):
    """With an empty environment, get_auth_config still yields a secret and logs a warning.

    The warning is recognised by a log message mentioning ``AUTH_JWT_SECRET``.
    The module-level ``_auth_config`` is saved and restored around the test.
    """
    import logging

    import app.gateway.auth.config as cfg

    old = cfg._auth_config
    cfg._auth_config = None
    try:
        # clear=True already empties os.environ for the duration of the
        # block, so no separate removal of AUTH_JWT_SECRET is needed.
        with patch.dict(os.environ, {}, clear=True):
            with caplog.at_level(logging.WARNING):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
    finally:
        cfg._auth_config = old
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Tests for auth error types and typed decode_token."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
|
||||
|
||||
def test_auth_error_code_values():
    """AuthErrorCode members compare equal to their snake_case string values."""
    expected = (
        (AuthErrorCode.INVALID_CREDENTIALS, "invalid_credentials"),
        (AuthErrorCode.TOKEN_EXPIRED, "token_expired"),
        (AuthErrorCode.NOT_AUTHENTICATED, "not_authenticated"),
    )
    for member, text in expected:
        assert member == text
|
||||
|
||||
|
||||
def test_token_error_values():
    """TokenError members compare equal to their snake_case string values."""
    expected = (
        (TokenError.EXPIRED, "expired"),
        (TokenError.INVALID_SIGNATURE, "invalid_signature"),
        (TokenError.MALFORMED, "malformed"),
    )
    for member, text in expected:
        assert member == text
|
||||
|
||||
|
||||
def test_auth_error_response_serialization():
    """model_dump() of an AuthErrorResponse gives a plain code/message dict."""
    response = AuthErrorResponse(code=AuthErrorCode.TOKEN_EXPIRED, message="Token has expired")

    dumped = response.model_dump()
    assert dumped == {"code": "token_expired", "message": "Token has expired"}
|
||||
|
||||
|
||||
def test_auth_error_response_from_dict():
    """Building AuthErrorResponse from a raw dict turns the code string into the enum member."""
    raw = {"code": "invalid_credentials", "message": "Wrong password"}
    parsed = AuthErrorResponse(**raw)
    assert parsed.code == AuthErrorCode.INVALID_CREDENTIALS
|
||||
|
||||
|
||||
# ── decode_token typed failure tests ──────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"
|
||||
|
||||
|
||||
def _setup_config():
    """Install an AuthConfig that signs with ``_TEST_SECRET``."""
    test_config = AuthConfig(jwt_secret=_TEST_SECRET)
    set_auth_config(test_config)
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_expired():
    """A correctly signed token whose ``exp`` lies an hour in the past yields ``TokenError.EXPIRED``."""
    _setup_config()
    claims = {
        "sub": "user-1",
        "exp": datetime.now(UTC) - timedelta(hours=1),
        "iat": datetime.now(UTC),
    }
    stale = pyjwt.encode(claims, _TEST_SECRET, algorithm="HS256")

    assert decode_token(stale) == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_bad_signature():
    """An unexpired token signed with a different secret yields ``TokenError.INVALID_SIGNATURE``."""
    _setup_config()
    claims = {
        "sub": "user-1",
        "exp": datetime.now(UTC) + timedelta(hours=1),
        "iat": datetime.now(UTC),
    }
    # Signed with a key other than the configured one.
    forged = pyjwt.encode(claims, "wrong-secret", algorithm="HS256")

    assert decode_token(forged) == TokenError.INVALID_SIGNATURE
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_malformed():
    """Input that is not a JWT at all yields ``TokenError.MALFORMED``."""
    _setup_config()
    assert decode_token("not-a-jwt") == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_returns_payload_on_valid():
    """A token minted by create_access_token decodes to a payload with the same subject."""
    _setup_config()
    decoded = decode_token(create_access_token("user-123"))

    assert not isinstance(decoded, TokenError)
    assert decoded.sub == "user-123"
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||
|
||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        # Health and API documentation
        "/health",
        "/health/",
        "/docs",
        "/docs/",
        "/redoc",
        "/openapi.json",
        # Auth endpoints usable before login
        "/api/v1/auth/login/local",
        "/api/v1/auth/register",
        "/api/v1/auth/logout",
        "/api/v1/auth/setup-status",
    ],
)
def test_public_paths(path: str):
    """Each listed path is classified as public by ``_is_public``."""
    assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        # Application API
        "/api/models",
        "/api/mcp/config",
        "/api/memory",
        "/api/skills",
        "/api/threads/123",
        "/api/threads/123/uploads",
        "/api/agents",
        "/api/channels",
        "/api/runs/stream",
        "/api/threads/123/runs",
        # Auth endpoints that act on the current user
        "/api/v1/auth/me",
        "/api/v1/auth/change-password",
    ],
)
def test_protected_paths(path: str):
    """Each listed path is classified as not public by ``_is_public``."""
    assert _is_public(path) is False
|
||||
|
||||
|
||||
# ── Trailing slash / normalization edge cases ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        "/api/v1/auth/login/local/",
        "/api/v1/auth/register/",
        "/api/v1/auth/logout/",
        "/api/v1/auth/setup-status/",
    ],
)
def test_public_auth_paths_with_trailing_slash(path: str):
    """A trailing slash does not stop a public auth path from being public."""
    assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        "/api/models/",
        "/api/v1/auth/me/",
        "/api/v1/auth/change-password/",
    ],
)
def test_protected_paths_with_trailing_slash(path: str):
    """A trailing slash does not turn a protected path into a public one."""
    assert _is_public(path) is False
|
||||
|
||||
|
||||
def test_unknown_api_path_is_protected():
    """Fail-closed: /api/* paths that nobody listed are treated as protected."""
    unlisted = ("/api/new-feature", "/api/v2/something", "/api/v1/auth/new-endpoint")
    for path in unlisted:
        assert _is_public(path) is False
|
||||
|
||||
|
||||
# ── Middleware integration tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_app():
    """Build a small FastAPI app wrapped in AuthMiddleware for the integration tests.

    The handlers return canned JSON. The routes cover a health check, two
    auth endpoints, one route per HTTP method under /api, and one extra
    /api route standing in for a future endpoint.
    """
    from fastapi import FastAPI

    app = FastAPI()
    app.add_middleware(AuthMiddleware)

    # Health check.
    @app.get("/health")
    async def health():
        return {"status": "ok"}

    # Auth endpoints.
    @app.get("/api/v1/auth/me")
    async def auth_me():
        return {"id": "1", "email": "test@test.com"}

    @app.get("/api/v1/auth/setup-status")
    async def setup_status():
        return {"needs_setup": False}

    # One /api route per HTTP method.
    @app.get("/api/models")
    async def models_get():
        return {"models": []}

    @app.put("/api/mcp/config")
    async def mcp_put():
        return {"ok": True}

    @app.delete("/api/threads/abc")
    async def thread_delete():
        return {"ok": True}

    @app.patch("/api/threads/abc")
    async def thread_patch():
        return {"ok": True}

    @app.post("/api/threads/abc/runs/stream")
    async def stream():
        return {"ok": True}

    # Stand-in for an endpoint added later.
    @app.get("/api/future-endpoint")
    async def future():
        return {"ok": True}

    return app
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Provide a TestClient bound to a freshly built app from ``_make_app``."""
    app = _make_app()
    return TestClient(app)
|
||||
|
||||
|
||||
def test_public_path_no_cookie(client):
    """/health answers 200 without any cookie."""
    response = client.get("/health")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_public_auth_path_no_cookie(client):
    """A public auth endpoint (setup-status here) answers 200 without a cookie."""
    response = client.get("/api/v1/auth/setup-status")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_protected_auth_path_no_cookie(client):
    """/auth/me is refused with 401 without a cookie, despite living under /api/v1/auth/."""
    response = client.get("/api/v1/auth/me")
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_protected_path_no_cookie_returns_401(client):
    """A protected GET without a cookie gets 401 with the ``not_authenticated`` error code."""
    response = client.get("/api/models")

    assert response.status_code == 401
    detail = response.json()["detail"]
    assert detail["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_protected_path_with_junk_cookie_rejected(client):
    """Junk cookie → 401. Middleware strictly validates the JWT now
    (AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
    tokens through to the route handler."""
    # Set the cookie on the client, as the other junk-cookie tests do; the
    # per-request ``cookies=`` argument is deprecated by HTTPX.
    client.cookies.set("access_token", "some-token")
    res = client.get("/api/models")
    assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_post_no_cookie_returns_401(client):
    """POST to a protected streaming route without a cookie is refused with 401."""
    response = client.post("/api/threads/abc/runs/stream")
    assert response.status_code == 401
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
def test_protected_put_no_cookie(client):
    """PUT to a protected route without a cookie is refused with 401."""
    response = client.put("/api/mcp/config")
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_protected_delete_no_cookie(client):
    """DELETE on a protected route without a cookie is refused with 401."""
    response = client.delete("/api/threads/abc")
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_protected_patch_no_cookie(client):
    """PATCH on a protected route without a cookie is refused with 401."""
    response = client.patch("/api/threads/abc")
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_put_with_junk_cookie_rejected(client):
    """PUT carrying a cookie that is not a valid JWT is refused with 401."""
    client.cookies.set("access_token", "tok")

    response = client.put("/api/mcp/config")
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_delete_with_junk_cookie_rejected(client):
    """DELETE carrying a cookie that is not a valid JWT is refused with 401."""
    client.cookies.set("access_token", "tok")

    response = client.delete("/api/threads/abc")
    assert response.status_code == 401
|
||||
|
||||
|
||||
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_endpoint_no_cookie_returns_401(client):
    """An /api/* route outside the public list is refused with 401 when no cookie is sent."""
    response = client.get("/api/future-endpoint")
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_unknown_endpoint_with_junk_cookie_rejected(client):
    """A GET to an unregistered /api/* path with a junk cookie yields 401."""
    client.cookies.set("access_token", "tok")
    response = client.get("/api/future-endpoint")
    status = response.status_code
    assert status == 401
|
||||
@@ -0,0 +1,701 @@
|
||||
"""Tests for auth type system hardening.
|
||||
|
||||
Covers structured error responses, typed decode_token callers,
|
||||
CSRF middleware path matching, config-driven cookie security,
|
||||
and unhappy paths / edge cases for all auth boundaries.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.csrf_middleware import (
|
||||
CSRF_COOKIE_NAME,
|
||||
CSRF_HEADER_NAME,
|
||||
CSRFMiddleware,
|
||||
is_auth_endpoint,
|
||||
should_check_csrf,
|
||||
)
|
||||
|
||||
# ── Setup ────────────────────────────────────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _persistence_engine(tmp_path):
    """Give every test its own SQLite engine and clean provider caches.

    The auth tests call real HTTP handlers that go through
    ``SQLiteUserRepository`` → ``get_session_factory``. A fresh database
    file is created under ``tmp_path`` and the ``deps._cached_*``
    attributes are cleared both before and after the test, so a cached
    provider never keeps a dangling reference to a previous test's engine.
    """
    import asyncio

    from app.gateway import deps
    from deerflow.persistence.engine import close_engine, init_engine

    def _drop_cached_providers():
        # Clear both module-level caches held by ``app.gateway.deps``.
        deps._cached_local_provider = None
        deps._cached_repo = None

    db_url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db"
    asyncio.run(init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path)))
    _drop_cached_providers()
    try:
        yield
    finally:
        # Caches are cleared first, then the engine is closed.
        _drop_cached_providers()
        asyncio.run(close_engine())
|
||||
|
||||
|
||||
def _setup_config():
    """Install an ``AuthConfig`` signed with ``_TEST_SECRET`` via ``set_auth_config``."""
    test_config = AuthConfig(jwt_secret=_TEST_SECRET)
    set_auth_config(test_config)
|
||||
|
||||
|
||||
# ── CSRF Middleware Path Matching ────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
"""Minimal request mock for CSRF path matching tests."""
|
||||
|
||||
def __init__(self, path: str, method: str = "POST"):
|
||||
self.method = method
|
||||
|
||||
class _URL:
|
||||
def __init__(self, p):
|
||||
self.path = p
|
||||
|
||||
self.url = _URL(path)
|
||||
self.cookies = {}
|
||||
self.headers = {}
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local():
    """``is_auth_endpoint`` is True for the login/local route."""
    request = _FakeRequest("/api/v1/auth/login/local")
    exempt = is_auth_endpoint(request)
    assert exempt is True
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local_trailing_slash():
    """``is_auth_endpoint`` is also True when the login path ends in a slash."""
    request = _FakeRequest("/api/v1/auth/login/local/")
    exempt = is_auth_endpoint(request)
    assert exempt is True
|
||||
|
||||
|
||||
def test_csrf_exempts_logout():
    """``is_auth_endpoint`` is True for the logout route."""
    request = _FakeRequest("/api/v1/auth/logout")
    exempt = is_auth_endpoint(request)
    assert exempt is True
|
||||
|
||||
|
||||
def test_csrf_exempts_register():
    """``is_auth_endpoint`` is True for the register route."""
    request = _FakeRequest("/api/v1/auth/register")
    exempt = is_auth_endpoint(request)
    assert exempt is True
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_old_login_path():
    """``is_auth_endpoint`` is False for /api/v1/auth/login without /local."""
    request = _FakeRequest("/api/v1/auth/login")
    exempt = is_auth_endpoint(request)
    assert exempt is False
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_me():
    """``is_auth_endpoint`` is False for the /me route."""
    request = _FakeRequest("/api/v1/auth/me")
    exempt = is_auth_endpoint(request)
    assert exempt is False
|
||||
|
||||
|
||||
def test_csrf_skips_get_requests():
    """``should_check_csrf`` is False for a GET request."""
    request = _FakeRequest("/api/v1/auth/me", method="GET")
    needs_check = should_check_csrf(request)
    assert needs_check is False
|
||||
|
||||
|
||||
def test_csrf_checks_post_to_protected():
    """``should_check_csrf`` is True for a POST to a non-auth path."""
    request = _FakeRequest("/api/v1/some/endpoint", method="POST")
    needs_check = should_check_csrf(request)
    assert needs_check is True
|
||||
|
||||
|
||||
# ── Structured Error Response Format ────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_error_response_has_code_and_message():
    """A dumped ``AuthErrorResponse`` exposes ``code`` and ``message`` keys."""
    dumped = AuthErrorResponse(
        code=AuthErrorCode.INVALID_CREDENTIALS,
        message="Wrong password",
    ).model_dump()
    for key in ("code", "message"):
        assert key in dumped
    assert dumped["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_auth_error_response_all_codes_serializable():
    """Each ``AuthErrorCode`` member dumps to its own ``value``."""
    for error_code in AuthErrorCode:
        response = AuthErrorResponse(code=error_code, message=f"Test {error_code.value}")
        dumped = response.model_dump()
        assert dumped["code"] == error_code.value
|
||||
|
||||
|
||||
# ── decode_token Caller Pattern ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_expired_maps_to_token_expired_code():
    """TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
    _setup_config()
    # ``UTC``/``datetime``/``timedelta`` and ``pyjwt`` are module-level
    # imports; the function-local re-imports were redundant and removed.
    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.EXPIRED

    # Verify the mapping pattern used in route handlers
    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
    """TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
    _setup_config()
    # ``UTC``/``datetime``/``timedelta`` and ``pyjwt`` are module-level
    # imports; the function-local re-imports were redundant and removed.
    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    # Signed with a key other than the configured secret.
    token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.INVALID_SIGNATURE

    # Same mapping expression as in the expired-token test.
    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
def test_decode_token_malformed_maps_to_token_invalid_code():
    """A non-JWT string decodes to MALFORMED, which maps to TOKEN_INVALID."""
    _setup_config()
    outcome = decode_token("garbage")
    assert outcome == TokenError.MALFORMED

    if outcome == TokenError.EXPIRED:
        mapped = AuthErrorCode.TOKEN_EXPIRED
    else:
        mapped = AuthErrorCode.TOKEN_INVALID
    assert mapped == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
# ── Login Response Format ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_login_response_model_has_no_access_token():
    """A dumped ``LoginResponse`` carries ``expires_in`` but no ``access_token`` (RFC-001)."""
    from app.gateway.routers.auth import LoginResponse

    dumped = LoginResponse(expires_in=604800).model_dump()
    assert "access_token" not in dumped
    assert "expires_in" in dumped
    assert dumped["expires_in"] == 604800
|
||||
|
||||
|
||||
def test_login_response_model_fields():
    """``LoginResponse`` declares exactly ``expires_in`` and ``needs_setup``."""
    from app.gateway.routers.auth import LoginResponse

    declared = set(LoginResponse.model_fields)
    assert declared == {"expires_in", "needs_setup"}
|
||||
|
||||
|
||||
# ── AuthConfig in Route ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_used_in_login_response():
    """``LoginResponse`` keeps the ``expires_in`` it is built with (14 days in seconds).

    NOTE(review): despite its name this test never reads
    ``AuthConfig.token_expiry_days``; it only round-trips a literal.
    """
    from app.gateway.routers.auth import LoginResponse

    fourteen_days = 14 * 24 * 3600
    response = LoginResponse(expires_in=fourteen_days)
    assert response.expires_in == fourteen_days
|
||||
|
||||
|
||||
# ── UserResponse Type Preservation ───────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_system_role_literal():
    """``UserResponse`` accepts both 'admin' and 'user' as ``system_role``."""
    from app.gateway.auth.models import UserResponse

    for role in ("admin", "user"):
        response = UserResponse(id="1", email="a@b.com", system_role=role)
        assert response.system_role == role
|
||||
|
||||
|
||||
def test_user_response_rejects_invalid_role():
    """A ``system_role`` of 'superadmin' raises ``ValidationError``."""
    from app.gateway.auth.models import UserResponse

    bad_kwargs = {"id": "1", "email": "a@b.com", "system_role": "superadmin"}
    with pytest.raises(ValidationError):
        UserResponse(**bad_kwargs)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# UNHAPPY PATHS / EDGE CASES
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
# ── get_current_user structured 401 responses ────────────────────────
|
||||
|
||||
|
||||
def test_get_current_user_no_cookie_returns_not_authenticated():
    """With an empty cookie dict the dependency raises 401 / not_authenticated."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    request_cls = type("MockRequest", (), {"cookies": {}})
    with pytest.raises(HTTPException) as caught:
        asyncio.run(get_current_user_from_request(request_cls()))
    error = caught.value
    assert error.status_code == 401
    assert error.detail["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_get_current_user_expired_token_returns_token_expired():
    """A token whose ``exp`` is one hour in the past raises 401 / token_expired."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    claims = {
        "sub": "u1",
        "exp": datetime.now(UTC) - timedelta(hours=1),
        "iat": datetime.now(UTC),
    }
    stale_token = pyjwt.encode(claims, _TEST_SECRET, algorithm="HS256")

    request_cls = type("MockRequest", (), {"cookies": {"access_token": stale_token}})
    with pytest.raises(HTTPException) as caught:
        asyncio.run(get_current_user_from_request(request_cls()))
    error = caught.value
    assert error.status_code == 401
    assert error.detail["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token_returns_token_invalid():
    """A token signed with a different secret raises 401 / token_invalid."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    claims = {
        "sub": "u1",
        "exp": datetime.now(UTC) + timedelta(hours=1),
        "iat": datetime.now(UTC),
    }
    forged_token = pyjwt.encode(claims, "wrong-secret", algorithm="HS256")

    request_cls = type("MockRequest", (), {"cookies": {"access_token": forged_token}})
    with pytest.raises(HTTPException) as caught:
        asyncio.run(get_current_user_from_request(request_cls()))
    error = caught.value
    assert error.status_code == 401
    assert error.detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_get_current_user_malformed_token_returns_token_invalid():
    """A cookie value that is not a JWT raises 401 / token_invalid."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    request_cls = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})
    with pytest.raises(HTTPException) as caught:
        asyncio.run(get_current_user_from_request(request_cls()))
    error = caught.value
    assert error.status_code == 401
    assert error.detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
# ── decode_token edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_empty_string_returns_malformed():
    """An empty token string decodes to ``TokenError.MALFORMED``."""
    _setup_config()
    assert decode_token("") == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_whitespace_returns_malformed():
    """A whitespace-only token string decodes to ``TokenError.MALFORMED``."""
    _setup_config()
    assert decode_token(" ") == TokenError.MALFORMED
|
||||
|
||||
|
||||
# ── AuthConfig validation edge cases ─────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_missing_jwt_secret_raises():
    """Building ``AuthConfig`` with no arguments raises ``ValidationError``."""
    expect_rejection = pytest.raises(ValidationError)
    with expect_rejection:
        AuthConfig()
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_zero_raises():
    """``token_expiry_days=0`` is rejected with ``ValidationError``."""
    kwargs = {"jwt_secret": "secret", "token_expiry_days": 0}
    with pytest.raises(ValidationError):
        AuthConfig(**kwargs)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_31_raises():
    """``token_expiry_days=31`` is rejected with ``ValidationError``."""
    kwargs = {"jwt_secret": "secret", "token_expiry_days": 31}
    with pytest.raises(ValidationError):
        AuthConfig(**kwargs)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_1_ok():
    """``token_expiry_days=1`` is accepted and stored unchanged."""
    assert AuthConfig(jwt_secret="secret", token_expiry_days=1).token_expiry_days == 1
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_30_ok():
    """``token_expiry_days=30`` is accepted and stored unchanged."""
    assert AuthConfig(jwt_secret="secret", token_expiry_days=30).token_expiry_days == 30
|
||||
|
||||
|
||||
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
    """get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset.

    The module-level ``cfg._auth_config`` is set to ``None`` for the
    duration of the test and restored afterwards, even on failure.
    """
    import logging

    import app.gateway.auth.config as cfg

    old = cfg._auth_config
    # presumably the cache consulted by get_auth_config() — verify in cfg
    cfg._auth_config = None
    try:
        # ``clear=True`` empties os.environ inside the block, so
        # AUTH_JWT_SECRET is already absent; no explicit pop is needed.
        with patch.dict(os.environ, {}, clear=True), caplog.at_level(logging.WARNING):
            config = cfg.get_auth_config()
            assert config.jwt_secret
            # A warning that names the missing variable must have been logged.
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
    finally:
        cfg._auth_config = old
|
||||
|
||||
|
||||
# ── CSRF middleware integration (unhappy paths) ──────────────────────
|
||||
|
||||
|
||||
def _make_csrf_app():
    """Build a small FastAPI app wrapped in ``CSRFMiddleware``.

    Registers an ``HTTPException`` handler that returns
    ``{"detail": exc.detail}`` plus three routes: a POST under
    ``/api/v1/test``, the POST login/local route, and a GET under
    ``/api/v1/test``.
    """
    from fastapi import HTTPException as HttpError
    from fastapi.responses import JSONResponse as JsonReply

    app = FastAPI()

    @app.exception_handler(HttpError)
    async def _http_exc_handler(request, exc):
        body = {"detail": exc.detail}
        return JsonReply(status_code=exc.status_code, content=body)

    app.add_middleware(CSRFMiddleware)

    ok_payload = {"ok": True}

    @app.post("/api/v1/test/protected")
    async def protected():
        return dict(ok_payload)

    @app.post("/api/v1/auth/login/local")
    async def login():
        return dict(ok_payload)

    @app.get("/api/v1/test/read")
    async def read_endpoint():
        return dict(ok_payload)

    return app
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_without_token():
    """A POST with no CSRF cookie or header gets 403 and a 'missing' CSRF detail."""
    http = TestClient(_make_csrf_app())
    response = http.post("/api/v1/test/protected")
    assert response.status_code == 403
    assert "CSRF" in response.json()["detail"]
    assert "missing" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_with_mismatched_token():
    """A POST whose CSRF header differs from the cookie gets 403 with a mismatch detail."""
    http = TestClient(_make_csrf_app())
    http.cookies.set(CSRF_COOKIE_NAME, "token-a")
    mismatched_headers = {CSRF_HEADER_NAME: "token-b"}
    response = http.post("/api/v1/test/protected", headers=mismatched_headers)
    assert response.status_code == 403
    assert "mismatch" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_post_with_matching_token():
    """A POST whose CSRF header equals the cookie value gets 200."""
    http = TestClient(_make_csrf_app())
    csrf_value = secrets.token_urlsafe(64)
    http.cookies.set(CSRF_COOKIE_NAME, csrf_value)
    matching_headers = {CSRF_HEADER_NAME: csrf_value}
    response = http.post("/api/v1/test/protected", headers=matching_headers)
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_get_without_token():
    """A GET with no CSRF token gets 200."""
    http = TestClient(_make_csrf_app())
    response = http.get("/api/v1/test/read")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_exempts_login_local():
    """A POST to login/local with no CSRF token gets 200."""
    http = TestClient(_make_csrf_app())
    response = http.post("/api/v1/auth/login/local")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
    """The response to a login/local POST carries the CSRF cookie."""
    http = TestClient(_make_csrf_app())
    response = http.post("/api/v1/auth/login/local")
    issued_cookies = response.cookies
    assert CSRF_COOKIE_NAME in issued_cookies
|
||||
|
||||
|
||||
# ── UserResponse edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_missing_required_fields():
    """Omitting ``email`` and/or ``system_role`` raises ``ValidationError``."""
    from app.gateway.auth.models import UserResponse

    incomplete_inputs = (
        {"id": "1"},  # missing email, system_role
        {"id": "1", "email": "a@b.com"},  # missing system_role
    )
    for kwargs in incomplete_inputs:
        with pytest.raises(ValidationError):
            UserResponse(**kwargs)
|
||||
|
||||
|
||||
def test_user_response_empty_string_role_rejected():
    """An empty-string ``system_role`` raises ``ValidationError``."""
    from app.gateway.auth.models import UserResponse

    kwargs = {"id": "1", "email": "a@b.com", "system_role": ""}
    with pytest.raises(ValidationError):
        UserResponse(**kwargs)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# HTTP-LEVEL API CONTRACT TESTS
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_auth_app():
    """Return the application produced by the gateway's ``create_app`` factory."""
    from app.gateway.app import create_app

    gateway_app = create_app()
    return gateway_app
|
||||
|
||||
|
||||
def _get_auth_client():
    """Return a ``TestClient`` bound to a fresh app from ``_make_auth_app``."""
    gateway_app = _make_auth_app()
    return TestClient(gateway_app)
|
||||
|
||||
|
||||
def test_api_auth_me_no_cookie_returns_structured_401():
    """GET /api/v1/auth/me without a cookie → 401 whose detail has code 'not_authenticated' and a message."""
    _setup_config()
    response = _get_auth_client().get("/api/v1/auth/me")
    assert response.status_code == 401
    detail = response.json()["detail"]
    assert detail["code"] == "not_authenticated"
    assert "message" in detail
|
||||
|
||||
|
||||
def test_api_auth_me_expired_token_returns_structured_401():
    """GET /api/v1/auth/me with an expired token → 401 with code 'token_expired'."""
    _setup_config()
    claims = {
        "sub": "u1",
        "exp": datetime.now(UTC) - timedelta(hours=1),
        "iat": datetime.now(UTC),
    }
    stale_token = pyjwt.encode(claims, _TEST_SECRET, algorithm="HS256")

    http = _get_auth_client()
    http.cookies.set("access_token", stale_token)
    response = http.get("/api/v1/auth/me")
    assert response.status_code == 401
    detail = response.json()["detail"]
    assert detail["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_api_auth_me_invalid_sig_returns_structured_401():
    """GET /api/v1/auth/me with a token signed by another key → 401 with code 'token_invalid'."""
    _setup_config()
    claims = {
        "sub": "u1",
        "exp": datetime.now(UTC) + timedelta(hours=1),
        "iat": datetime.now(UTC),
    }
    forged_token = pyjwt.encode(claims, "wrong-key", algorithm="HS256")

    http = _get_auth_client()
    http.cookies.set("access_token", forged_token)
    response = http.get("/api/v1/auth/me")
    assert response.status_code == 401
    detail = response.json()["detail"]
    assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_api_login_bad_credentials_returns_structured_401():
    """Logging in with credentials that were never registered → 401 with code 'invalid_credentials'."""
    _setup_config()
    http = _get_auth_client()
    form = {"username": "nonexistent@test.com", "password": "wrongpassword"}
    response = http.post("/api/v1/auth/login/local", data=form)
    assert response.status_code == 401
    detail = response.json()["detail"]
    assert detail["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_api_login_success_no_token_in_body():
    """After registering, login returns 200 with ``expires_in`` in the body and the token only as a cookie."""
    _setup_config()
    http = _get_auth_client()

    # Create the account the login below authenticates as.
    registration = {"email": "contract-test@test.com", "password": "securepassword123"}
    http.post("/api/v1/auth/register", json=registration)

    login_form = {"username": "contract-test@test.com", "password": "securepassword123"}
    response = http.post("/api/v1/auth/login/local", data=login_form)
    assert response.status_code == 200

    payload = response.json()
    assert "expires_in" in payload
    assert "access_token" not in payload
    # The token travels in the cookie jar rather than the JSON body.
    assert "access_token" in response.cookies
|
||||
|
||||
|
||||
def test_api_register_duplicate_returns_structured_400():
    """Registering the same email twice → second call is 400 with code 'email_already_exists'."""
    _setup_config()
    http = _get_auth_client()
    address = "dup-contract-test@test.com"

    # Initial registration; its response is not inspected.
    http.post("/api/v1/auth/register", json={"email": address, "password": "Tr0ub4dor3a"})

    # Second registration reuses the address with a different password.
    second_payload = {"email": address, "password": "AnotherStr0ngPwd!"}
    response = http.post("/api/v1/auth/register", json=second_payload)
    assert response.status_code == 400
    detail = response.json()["detail"]
    assert detail["code"] == "email_already_exists"
|
||||
|
||||
|
||||
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
|
||||
|
||||
|
||||
def _unique_email(prefix: str) -> str:
|
||||
return f"{prefix}-{secrets.token_hex(4)}@test.com"
|
||||
|
||||
|
||||
def _get_set_cookie_headers(resp) -> list[str]:
|
||||
"""Extract all set-cookie header values from a TestClient response."""
|
||||
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
|
||||
|
||||
|
||||
def test_register_http_cookie_httponly_true_secure_false():
    """HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"},
    )
    assert resp.status_code == 201
    # Inspect only the access_token Set-Cookie header. ``headers.get`` folds
    # every Set-Cookie value (csrf_token included) into one string, so the
    # attribute checks could be satisfied or tripped by the wrong cookie.
    access_cookies = [h for h in _get_set_cookie_headers(resp) if "access_token=" in h]
    assert access_cookies, "access_token cookie not set on HTTP register"
    cookie_header = access_cookies[0]
    assert "httponly" in cookie_header.lower()
    assert "secure" not in cookie_header.lower().replace("samesite", "")
    # NOTE(review): the "no max_age" part of the docstring is not asserted.
|
||||
|
||||
|
||||
def test_register_https_cookie_httponly_true_secure_true():
    """HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 201
    # Inspect only the access_token Set-Cookie header. ``headers.get`` folds
    # every Set-Cookie value into one string, so the csrf_token cookie's
    # attributes could satisfy the ``secure`` / ``max-age`` checks.
    access_cookies = [h for h in _get_set_cookie_headers(resp) if "access_token=" in h]
    assert access_cookies, "access_token cookie not set on HTTPS register"
    cookie_header = access_cookies[0]
    assert "httponly" in cookie_header.lower()
    assert "secure" in cookie_header.lower()
    assert "max-age" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_login_https_sets_secure_cookie():
    """HTTPS login → access_token cookie has secure flag."""
    _setup_config()
    client = _get_auth_client()
    email = _unique_email("https-login")
    # Register the account first so the login below can succeed.
    client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": email, "password": "Tr0ub4dor3a"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 200
    # Inspect only the access_token Set-Cookie header, so a ``Secure``
    # attribute on another cookie in the same response cannot satisfy
    # the check.
    access_cookies = [h for h in _get_set_cookie_headers(resp) if "access_token=" in h]
    assert access_cookies, "access_token cookie not set on HTTPS login"
    cookie_header = access_cookies[0]
    assert "httponly" in cookie_header.lower()
    assert "secure" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_secure_on_https():
    """Registering with ``x-forwarded-proto: https`` sets a csrf_token cookie that is secure but not httponly."""
    _setup_config()
    http = _get_auth_client()
    registration = {"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"}
    response = http.post(
        "/api/v1/auth/register",
        json=registration,
        headers={"x-forwarded-proto": "https"},
    )
    assert response.status_code == 201

    matching = [header for header in _get_set_cookie_headers(response) if "csrf_token=" in header]
    assert matching, "csrf_token cookie not set on HTTPS register"
    attributes = matching[0].lower()
    assert "secure" in attributes
    assert "httponly" not in attributes
|
||||
|
||||
|
||||
def test_csrf_cookie_not_secure_on_http():
    """Registering without a forwarded-proto header sets a csrf_token cookie lacking the secure flag."""
    _setup_config()
    http = _get_auth_client()
    registration = {"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"}
    response = http.post("/api/v1/auth/register", json=registration)
    assert response.status_code == 201

    matching = [header for header in _get_set_cookie_headers(response) if "csrf_token=" in header]
    assert matching, "csrf_token cookie not set on HTTP register"
    attributes = matching[0].lower().replace("samesite", "")
    assert "secure" not in attributes
|
||||
@@ -231,7 +231,7 @@ class TestResolveAttachments:
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.sandbox_outputs_dir.return_value = outputs_dir
|
||||
|
||||
def resolve_side_effect(tid, vpath):
|
||||
def resolve_side_effect(tid, vpath, *, user_id=None):
|
||||
if "data.csv" in vpath:
|
||||
return good_file
|
||||
return tmp_path / "missing.txt"
|
||||
|
||||
@@ -5,25 +5,21 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
|
||||
|
||||
def _make_config(checkpointer: CheckpointerConfig | None = None) -> AppConfig:
    """Build an ``AppConfig`` with a ``"test"`` sandbox and the given checkpointer (``None`` by default)."""
    sandbox = SandboxConfig(use="test")
    return AppConfig(sandbox=sandbox, checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset singleton state before each test."""
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
yield
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
|
||||
|
||||
@@ -33,24 +29,18 @@ def reset_state():
|
||||
|
||||
|
||||
class TestCheckpointerConfig:
|
||||
def test_load_memory_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
def test_memory_config(self):
|
||||
config = CheckpointerConfig(type="memory")
|
||||
assert config.type == "memory"
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_load_sqlite_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
def test_sqlite_config(self):
|
||||
config = CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")
|
||||
assert config.type == "sqlite"
|
||||
assert config.connection_string == "/tmp/test.db"
|
||||
|
||||
def test_load_postgres_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
def test_postgres_config(self):
|
||||
config = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
|
||||
assert config.type == "postgres"
|
||||
assert config.connection_string == "postgresql://localhost/db"
|
||||
|
||||
@@ -58,14 +48,9 @@ class TestCheckpointerConfig:
|
||||
config = CheckpointerConfig(type="memory")
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_set_config_to_none(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
set_checkpointer_config(None)
|
||||
assert get_checkpointer_config() is None
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
CheckpointerConfig(type="unknown")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -78,58 +63,58 @@ class TestGetCheckpointer:
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
cfg = _make_config()
|
||||
cp = get_checkpointer(cfg)
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_returns_in_memory_saver(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
cp = get_checkpointer()
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
cp = get_checkpointer(cfg)
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
cp1 = get_checkpointer(cfg)
|
||||
cp2 = get_checkpointer(cfg)
|
||||
assert cp1 is cp2
|
||||
|
||||
def test_reset_clears_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
cp1 = get_checkpointer(cfg)
|
||||
reset_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
cp2 = get_checkpointer(cfg)
|
||||
assert cp1 is not cp2
|
||||
|
||||
def test_sqlite_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
|
||||
get_checkpointer()
|
||||
get_checkpointer(cfg)
|
||||
|
||||
def test_postgres_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
|
||||
get_checkpointer()
|
||||
get_checkpointer(cfg)
|
||||
|
||||
def test_postgres_raises_when_connection_string_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres"})
|
||||
cfg = _make_config(CheckpointerConfig(type="postgres"))
|
||||
mock_saver = MagicMock()
|
||||
mock_module = MagicMock()
|
||||
mock_module.PostgresSaver = mock_saver
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ValueError, match="connection_string is required"):
|
||||
get_checkpointer()
|
||||
get_checkpointer(cfg)
|
||||
|
||||
def test_sqlite_creates_saver(self):
|
||||
"""SQLite checkpointer is created when package is available."""
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
@@ -144,7 +129,7 @@ class TestGetCheckpointer:
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
cp = get_checkpointer(cfg)
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once()
|
||||
@@ -225,7 +210,7 @@ class TestGetCheckpointer:
|
||||
|
||||
def test_postgres_creates_saver(self):
|
||||
"""Postgres checkpointer is created when packages are available."""
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
@@ -240,7 +225,7 @@ class TestGetCheckpointer:
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
cp = get_checkpointer(cfg)
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
|
||||
@@ -251,7 +236,7 @@ class TestAsyncCheckpointer:
|
||||
@pytest.mark.anyio
|
||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||
"""Async SQLite setup should move mkdir off the event loop."""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
||||
@@ -268,15 +253,14 @@ class TestAsyncCheckpointer:
|
||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch(
|
||||
"deerflow.agents.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
return_value="/tmp/resolved/test.db",
|
||||
),
|
||||
):
|
||||
async with make_checkpointer() as saver:
|
||||
async with make_checkpointer(mock_config) as saver:
|
||||
assert saver is mock_saver
|
||||
|
||||
mock_to_thread.assert_awaited_once()
|
||||
@@ -294,12 +278,10 @@ class TestAsyncCheckpointer:
|
||||
|
||||
class TestAppConfigLoadsCheckpointer:
|
||||
def test_load_checkpointer_section(self):
|
||||
"""load_checkpointer_config_from_dict populates the global config."""
|
||||
set_checkpointer_config(None)
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cfg = get_checkpointer_config()
|
||||
assert cfg is not None
|
||||
assert cfg.type == "memory"
|
||||
"""AppConfig with checkpointer section has the correct config."""
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
assert cfg.checkpointer is not None
|
||||
assert cfg.checkpointer.type == "memory"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -309,69 +291,7 @@ class TestAppConfigLoadsCheckpointer:
|
||||
|
||||
class TestClientCheckpointerFallback:
|
||||
def test_client_uses_config_checkpointer_when_none_provided(self):
|
||||
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("deerflow.client.get_app_config", return_value=config_mock),
|
||||
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value=""),
|
||||
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=None)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert "checkpointer" in captured_kwargs
|
||||
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
|
||||
|
||||
def test_client_explicit_checkpointer_takes_precedence(self):
|
||||
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
explicit_cp = MagicMock()
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("deerflow.client.get_app_config", return_value=config_mock),
|
||||
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value=""),
|
||||
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=explicit_cp)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert captured_kwargs["checkpointer"] is explicit_cp
|
||||
"""DeerFlowClient._ensure_agent falls back to get_checkpointer(app_config) when checkpointer=None."""
|
||||
# This is a structural test — verifying the fallback path exists.
|
||||
cfg = _make_config(CheckpointerConfig(type="memory"))
|
||||
assert cfg.checkpointer is not None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test for issue #1016: checkpointer should not return None."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
@@ -12,43 +12,40 @@ class TestCheckpointerNoneFix:
|
||||
@pytest.mark.anyio
|
||||
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
mock_config.database = None
|
||||
|
||||
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
async with make_checkpointer() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
async with make_checkpointer(mock_config) as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
|
||||
# Should be able to call alist() without AttributeError
|
||||
# This is what LangGraph does and what was failing in issue #1016
|
||||
result = []
|
||||
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
|
||||
result.append(item)
|
||||
# Should be able to call alist() without AttributeError
|
||||
# This is what LangGraph does and what was failing in issue #1016
|
||||
result = []
|
||||
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
|
||||
result.append(item)
|
||||
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
|
||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.agents.checkpointer.provider import checkpointer_context
|
||||
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with checkpointer_context() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
with checkpointer_context(mock_config) as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
|
||||
# Should be able to call list() without AttributeError
|
||||
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
|
||||
# Should be able to call list() without AttributeError
|
||||
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
|
||||
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
|
||||
+131
-100
@@ -18,6 +18,7 @@ from app.gateway.routers.models import ModelResponse, ModelsListResponse
|
||||
from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
|
||||
from app.gateway.routers.uploads import UploadResponse
|
||||
from deerflow.client import DeerFlowClient
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.uploads.manager import PathTraversalError
|
||||
|
||||
@@ -44,9 +45,12 @@ def mock_app_config():
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_app_config):
|
||||
"""Create a DeerFlowClient with mocked config loading."""
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
return DeerFlowClient()
|
||||
"""Create a DeerFlowClient holding the mocked config directly.
|
||||
|
||||
Passing ``config=`` is the documented post-refactor way to inject a
|
||||
test AppConfig; nothing relies on process-global state.
|
||||
"""
|
||||
return DeerFlowClient(config=mock_app_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -67,8 +71,7 @@ class TestClientInit:
|
||||
|
||||
def test_custom_params(self, mock_app_config):
|
||||
mock_middleware = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
@@ -78,24 +81,21 @@ class TestClientInit:
|
||||
assert c._middlewares == [mock_middleware]
|
||||
|
||||
def test_invalid_agent_name(self, mock_app_config):
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="invalid name with spaces!")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="../path/traversal")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="invalid name with spaces!")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="../path/traversal")
|
||||
|
||||
def test_custom_config_path(self, mock_app_config):
|
||||
with (
|
||||
patch("deerflow.client.reload_app_config") as mock_reload,
|
||||
patch("deerflow.client.get_app_config", return_value=mock_app_config),
|
||||
):
|
||||
DeerFlowClient(config_path="/tmp/custom.yaml")
|
||||
mock_reload.assert_called_once_with("/tmp/custom.yaml")
|
||||
# rather than touching AppConfig.init() / process-global state.
|
||||
with patch.object(AppConfig, "from_file", return_value=mock_app_config) as mock_from_file:
|
||||
client = DeerFlowClient(config_path="/tmp/custom.yaml")
|
||||
mock_from_file.assert_called_once_with("/tmp/custom.yaml")
|
||||
assert client._app_config is mock_app_config
|
||||
|
||||
def test_checkpointer_stored(self, mock_app_config):
|
||||
cp = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(checkpointer=cp)
|
||||
c = DeerFlowClient(checkpointer=cp)
|
||||
assert c._checkpointer is cp
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class TestConfigQueries:
|
||||
|
||||
with patch("deerflow.skills.loader.load_skills", return_value=[skill]) as mock_load:
|
||||
result = client.list_skills()
|
||||
mock_load.assert_called_once_with(enabled_only=False)
|
||||
mock_load.assert_called_once_with(client._app_config, enabled_only=False)
|
||||
|
||||
assert "skills" in result
|
||||
assert len(result["skills"]) == 1
|
||||
@@ -141,7 +141,7 @@ class TestConfigQueries:
|
||||
def test_list_skills_enabled_only(self, client):
|
||||
with patch("deerflow.skills.loader.load_skills", return_value=[]) as mock_load:
|
||||
client.list_skills(enabled_only=True)
|
||||
mock_load.assert_called_once_with(enabled_only=True)
|
||||
mock_load.assert_called_once_with(client._app_config, enabled_only=True)
|
||||
|
||||
def test_get_memory(self, client):
|
||||
memory = {"version": "1.0", "facts": []}
|
||||
@@ -251,8 +251,8 @@ class TestStream:
|
||||
# Verify context passed to agent.stream
|
||||
agent.stream.assert_called_once()
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert call_kwargs["context"]["thread_id"] == "t1"
|
||||
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
|
||||
ctx = call_kwargs["context"]
|
||||
assert ctx.app_config is client._app_config
|
||||
|
||||
def test_custom_mode_is_normalized_to_string(self, client):
|
||||
"""stream() forwards custom events even when the mode is not a plain string."""
|
||||
@@ -819,7 +819,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._agent_name = "custom-agent"
|
||||
client._available_skills = {"test_skill"}
|
||||
@@ -844,7 +844,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -869,7 +869,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -888,7 +888,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=None),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -1017,7 +1017,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
# No internal checkpointer, should fetch from provider
|
||||
result = client.list_threads()
|
||||
|
||||
@@ -1071,7 +1071,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
result = client.get_thread("t99")
|
||||
|
||||
assert result["thread_id"] == "t99"
|
||||
@@ -1091,8 +1091,8 @@ class TestMcpConfig:
|
||||
ext_config = MagicMock()
|
||||
ext_config.mcp_servers = {"github": server}
|
||||
|
||||
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
|
||||
result = client.get_mcp_config()
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
result = client.get_mcp_config()
|
||||
|
||||
assert "mcp_servers" in result
|
||||
assert "github" in result["mcp_servers"]
|
||||
@@ -1116,10 +1116,11 @@ class TestMcpConfig:
|
||||
# Pre-set agent to verify it gets invalidated
|
||||
client._agent = MagicMock()
|
||||
|
||||
client._app_config = MagicMock(extensions=current_config)
|
||||
|
||||
with (
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
|
||||
patch("deerflow.client.get_extensions_config", return_value=current_config),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
|
||||
):
|
||||
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
|
||||
|
||||
@@ -1177,12 +1178,12 @@ class TestSkillsManagement:
|
||||
try:
|
||||
# Pre-set agent to verify it gets invalidated
|
||||
client._agent = MagicMock()
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
result = client.update_skill("test-skill", enabled=False)
|
||||
assert result["enabled"] is False
|
||||
@@ -1243,7 +1244,10 @@ class TestMemoryManagement:
|
||||
with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
|
||||
result = client.import_memory(imported)
|
||||
|
||||
mock_import.assert_called_once_with(imported)
|
||||
assert mock_import.call_count == 1
|
||||
call_args = mock_import.call_args
|
||||
assert call_args.args == (client._app_config.memory, imported)
|
||||
assert "user_id" in call_args.kwargs
|
||||
assert result == imported
|
||||
|
||||
def test_reload_memory(self, client):
|
||||
@@ -1267,6 +1271,7 @@ class TestMemoryManagement:
|
||||
confidence=0.88,
|
||||
)
|
||||
create_fact.assert_called_once_with(
|
||||
client._app_config.memory,
|
||||
content="User prefers concise code reviews.",
|
||||
category="preference",
|
||||
confidence=0.88,
|
||||
@@ -1277,7 +1282,7 @@ class TestMemoryManagement:
|
||||
data = {"version": "1.0", "facts": []}
|
||||
with patch("deerflow.agents.memory.updater.delete_memory_fact", return_value=data) as delete_fact:
|
||||
result = client.delete_memory_fact("fact_123")
|
||||
delete_fact.assert_called_once_with("fact_123")
|
||||
delete_fact.assert_called_once_with(client._app_config.memory, "fact_123")
|
||||
assert result == data
|
||||
|
||||
def test_update_memory_fact(self, client):
|
||||
@@ -1290,6 +1295,7 @@ class TestMemoryManagement:
|
||||
confidence=0.91,
|
||||
)
|
||||
update_fact.assert_called_once_with(
|
||||
client._app_config.memory,
|
||||
fact_id="fact_123",
|
||||
content="User prefers spaces",
|
||||
category="workflow",
|
||||
@@ -1305,6 +1311,7 @@ class TestMemoryManagement:
|
||||
"User prefers spaces",
|
||||
)
|
||||
update_fact.assert_called_once_with(
|
||||
client._app_config.memory,
|
||||
fact_id="fact_123",
|
||||
content="User prefers spaces",
|
||||
category=None,
|
||||
@@ -1313,37 +1320,40 @@ class TestMemoryManagement:
|
||||
assert result == data
|
||||
|
||||
def test_get_memory_config(self, client):
|
||||
config = MagicMock()
|
||||
config.enabled = True
|
||||
config.storage_path = ".deer-flow/memory.json"
|
||||
config.debounce_seconds = 30
|
||||
config.max_facts = 100
|
||||
config.fact_confidence_threshold = 0.7
|
||||
config.injection_enabled = True
|
||||
config.max_injection_tokens = 2000
|
||||
mem_config = MagicMock()
|
||||
mem_config.enabled = True
|
||||
mem_config.storage_path = ".deer-flow/memory.json"
|
||||
mem_config.debounce_seconds = 30
|
||||
mem_config.max_facts = 100
|
||||
mem_config.fact_confidence_threshold = 0.7
|
||||
mem_config.injection_enabled = True
|
||||
mem_config.max_injection_tokens = 2000
|
||||
|
||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=config):
|
||||
result = client.get_memory_config()
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_config
|
||||
|
||||
client._app_config = app_cfg
|
||||
result = client.get_memory_config()
|
||||
|
||||
assert result["enabled"] is True
|
||||
assert result["max_facts"] == 100
|
||||
|
||||
def test_get_memory_status(self, client):
|
||||
config = MagicMock()
|
||||
config.enabled = True
|
||||
config.storage_path = ".deer-flow/memory.json"
|
||||
config.debounce_seconds = 30
|
||||
config.max_facts = 100
|
||||
config.fact_confidence_threshold = 0.7
|
||||
config.injection_enabled = True
|
||||
config.max_injection_tokens = 2000
|
||||
mem_config = MagicMock()
|
||||
mem_config.enabled = True
|
||||
mem_config.storage_path = ".deer-flow/memory.json"
|
||||
mem_config.debounce_seconds = 30
|
||||
mem_config.max_facts = 100
|
||||
mem_config.fact_confidence_threshold = 0.7
|
||||
mem_config.injection_enabled = True
|
||||
mem_config.max_injection_tokens = 2000
|
||||
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_config
|
||||
data = {"version": "1.0", "facts": []}
|
||||
|
||||
with (
|
||||
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=data),
|
||||
):
|
||||
client._app_config = app_cfg
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=data):
|
||||
result = client.get_memory_status()
|
||||
|
||||
assert "config" in result
|
||||
@@ -1489,9 +1499,12 @@ class TestUploads:
|
||||
|
||||
class TestArtifacts:
|
||||
def test_get_artifact(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
outputs = paths.sandbox_outputs_dir("t1")
|
||||
user_id = get_effective_user_id()
|
||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||
outputs.mkdir(parents=True)
|
||||
(outputs / "result.txt").write_text("artifact content")
|
||||
|
||||
@@ -1502,9 +1515,12 @@ class TestArtifacts:
|
||||
assert "text" in mime
|
||||
|
||||
def test_get_artifact_not_found(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
@@ -1515,9 +1531,12 @@ class TestArtifacts:
|
||||
client.get_artifact("t1", "bad/path/file.txt")
|
||||
|
||||
def test_get_artifact_path_traversal(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
with pytest.raises(PathTraversalError):
|
||||
@@ -1701,13 +1720,16 @@ class TestScenarioFileLifecycle:
|
||||
|
||||
def test_upload_then_read_artifact(self, client):
|
||||
"""Upload a file, simulate agent producing artifact, read it back."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
|
||||
paths = Paths(base_dir=tmp_path)
|
||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact")
|
||||
user_id = get_effective_user_id()
|
||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact", user_id=user_id)
|
||||
outputs_dir.mkdir(parents=True)
|
||||
|
||||
# Upload phase
|
||||
@@ -1785,10 +1807,10 @@ class TestScenarioConfigManagement:
|
||||
reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
|
||||
|
||||
client._agent = MagicMock() # Simulate existing agent
|
||||
client._app_config = MagicMock(extensions=current_config)
|
||||
with (
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=current_config),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
|
||||
):
|
||||
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
|
||||
assert "my-mcp" in mcp_result["mcp_servers"]
|
||||
@@ -1817,8 +1839,7 @@ class TestScenarioConfigManagement:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
skill_result = client.update_skill("code-gen", enabled=False)
|
||||
assert skill_result["enabled"] is False
|
||||
@@ -1846,7 +1867,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config_a)
|
||||
first_agent = client._agent
|
||||
@@ -1874,7 +1895,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client._ensure_agent(config)
|
||||
@@ -1899,7 +1920,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client.reset_agent()
|
||||
@@ -1957,11 +1978,14 @@ class TestScenarioThreadIsolation:
|
||||
|
||||
def test_artifacts_isolated_per_thread(self, client):
|
||||
"""Artifacts in thread-A are not accessible from thread-B."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
outputs_a = paths.sandbox_outputs_dir("thread-a")
|
||||
user_id = get_effective_user_id()
|
||||
outputs_a = paths.sandbox_outputs_dir("thread-a", user_id=user_id)
|
||||
outputs_a.mkdir(parents=True)
|
||||
paths.sandbox_user_data_dir("thread-b").mkdir(parents=True)
|
||||
paths.sandbox_outputs_dir("thread-b", user_id=user_id).mkdir(parents=True)
|
||||
(outputs_a / "result.txt").write_text("thread-a artifact")
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
@@ -2003,10 +2027,10 @@ class TestScenarioMemoryWorkflow:
|
||||
refreshed = client.reload_memory()
|
||||
assert len(refreshed["facts"]) == 2
|
||||
|
||||
with (
|
||||
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data),
|
||||
):
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = config
|
||||
client._app_config = app_cfg
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data):
|
||||
status = client.get_memory_status()
|
||||
assert status["config"]["enabled"] is True
|
||||
assert len(status["data"]["facts"]) == 2
|
||||
@@ -2067,8 +2091,7 @@ class TestScenarioSkillInstallAndUse:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
toggled = client.update_skill("my-analyzer", enabled=False)
|
||||
assert toggled["enabled"] is False
|
||||
@@ -2202,8 +2225,7 @@ class TestGatewayConformance:
|
||||
mock_app_config.models = [model]
|
||||
mock_app_config.token_usage.enabled = True
|
||||
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
client = DeerFlowClient()
|
||||
client = DeerFlowClient(config=mock_app_config)
|
||||
|
||||
result = client.list_models()
|
||||
parsed = ModelsListResponse(**result)
|
||||
@@ -2222,8 +2244,7 @@ class TestGatewayConformance:
|
||||
mock_app_config.models = [model]
|
||||
mock_app_config.get_model_config.return_value = model
|
||||
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
client = DeerFlowClient()
|
||||
client = DeerFlowClient(config=mock_app_config)
|
||||
|
||||
result = client.get_model("test-model")
|
||||
assert result is not None
|
||||
@@ -2292,8 +2313,8 @@ class TestGatewayConformance:
|
||||
ext_config = MagicMock()
|
||||
ext_config.mcp_servers = {"test": server}
|
||||
|
||||
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
|
||||
result = client.get_mcp_config()
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
result = client.get_mcp_config()
|
||||
|
||||
parsed = McpConfigResponse(**result)
|
||||
assert "test" in parsed.mcp_servers
|
||||
@@ -2317,10 +2338,10 @@ class TestGatewayConformance:
|
||||
config_file = tmp_path / "extensions_config.json"
|
||||
config_file.write_text("{}")
|
||||
|
||||
client._app_config = MagicMock(extensions=ext_config)
|
||||
with (
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=ext_config)),
|
||||
):
|
||||
result = client.update_mcp_config({"srv": server.model_dump.return_value})
|
||||
|
||||
@@ -2351,8 +2372,11 @@ class TestGatewayConformance:
|
||||
mem_cfg.injection_enabled = True
|
||||
mem_cfg.max_injection_tokens = 2000
|
||||
|
||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
||||
result = client.get_memory_config()
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_cfg
|
||||
|
||||
client._app_config = app_cfg
|
||||
result = client.get_memory_config()
|
||||
|
||||
parsed = MemoryConfigResponse(**result)
|
||||
assert parsed.enabled is True
|
||||
@@ -2368,6 +2392,8 @@ class TestGatewayConformance:
|
||||
mem_cfg.injection_enabled = True
|
||||
mem_cfg.max_injection_tokens = 2000
|
||||
|
||||
app_cfg = MagicMock()
|
||||
app_cfg.memory = mem_cfg
|
||||
memory_data = {
|
||||
"version": "1.0",
|
||||
"lastUpdated": "",
|
||||
@@ -2384,10 +2410,8 @@ class TestGatewayConformance:
|
||||
"facts": [],
|
||||
}
|
||||
|
||||
with (
|
||||
patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data),
|
||||
):
|
||||
client._app_config = app_cfg
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data):
|
||||
result = client.get_memory_status()
|
||||
|
||||
parsed = MemoryStatusResponse(**result)
|
||||
@@ -2676,8 +2700,7 @@ class TestConfigUpdateErrors:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="disappeared"):
|
||||
client.update_skill("ghost-skill", enabled=False)
|
||||
@@ -2869,9 +2892,12 @@ class TestUploadDeleteSymlink:
|
||||
class TestArtifactHardening:
|
||||
def test_artifact_directory_rejected(self, client):
|
||||
"""get_artifact rejects paths that resolve to a directory."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
subdir = paths.sandbox_outputs_dir("t1") / "subdir"
|
||||
user_id = get_effective_user_id()
|
||||
subdir = paths.sandbox_outputs_dir("t1", user_id=user_id) / "subdir"
|
||||
subdir.mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
@@ -2880,9 +2906,12 @@ class TestArtifactHardening:
|
||||
|
||||
def test_artifact_leading_slash_stripped(self, client):
|
||||
"""Paths with leading slash are handled correctly."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
outputs = paths.sandbox_outputs_dir("t1")
|
||||
user_id = get_effective_user_id()
|
||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||
outputs.mkdir(parents=True)
|
||||
(outputs / "file.txt").write_text("content")
|
||||
|
||||
@@ -2996,9 +3025,12 @@ class TestBugArtifactPrefixMatchTooLoose:
|
||||
|
||||
def test_exact_prefix_without_subpath_accepted(self, client):
|
||||
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
# Accepted at prefix check, but fails because it's a directory.
|
||||
@@ -3047,10 +3079,10 @@ class TestBugAgentInvalidationInconsistency:
|
||||
config_file = Path(tmp) / "ext.json"
|
||||
config_file.write_text("{}")
|
||||
|
||||
client._app_config = MagicMock(extensions=current_config)
|
||||
with (
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=current_config),
|
||||
patch("deerflow.client.reload_extensions_config", return_value=reloaded),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),
|
||||
):
|
||||
client.update_mcp_config({})
|
||||
|
||||
@@ -3082,8 +3114,7 @@ class TestBugAgentInvalidationInconsistency:
|
||||
with (
|
||||
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
|
||||
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
|
||||
patch("deerflow.client.get_extensions_config", return_value=ext_config),
|
||||
patch("deerflow.client.reload_extensions_config"),
|
||||
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
|
||||
):
|
||||
client.update_skill("s1", enabled=False)
|
||||
|
||||
|
||||
@@ -56,6 +56,10 @@ def _make_e2e_config() -> AppConfig:
|
||||
- ``E2E_BASE_URL`` (default: ``https://ark-cn-beijing.bytedance.net/api/v3``)
|
||||
- ``OPENAI_API_KEY`` (required for LLM tests)
|
||||
"""
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
return AppConfig(
|
||||
models=[
|
||||
ModelConfig(
|
||||
@@ -73,6 +77,9 @@ def _make_e2e_config() -> AppConfig:
|
||||
)
|
||||
],
|
||||
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", allow_host_bash=True),
|
||||
title=TitleConfig(enabled=False),
|
||||
memory=MemoryConfig(enabled=False),
|
||||
summarization=SummarizationConfig(enabled=False),
|
||||
)
|
||||
|
||||
|
||||
@@ -87,7 +94,7 @@ def e2e_env(tmp_path, monkeypatch):
|
||||
|
||||
- DEER_FLOW_HOME → tmp_path (all thread data lands in a temp dir)
|
||||
- Singletons reset so they pick up the new env
|
||||
- Title/memory/summarization disabled to avoid extra LLM calls
|
||||
- Title/memory/summarization disabled via AppConfig fields
|
||||
- AppConfig built programmatically (avoids config.yaml param-name issues)
|
||||
"""
|
||||
# 1. Filesystem isolation
|
||||
@@ -95,30 +102,12 @@ def e2e_env(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("deerflow.config.paths._paths", None)
|
||||
monkeypatch.setattr("deerflow.sandbox.sandbox_provider._default_sandbox_provider", None)
|
||||
|
||||
# 2. Inject a clean AppConfig via the global singleton.
|
||||
config = _make_e2e_config()
|
||||
monkeypatch.setattr("deerflow.config.app_config._app_config", config)
|
||||
monkeypatch.setattr("deerflow.config.app_config._app_config_is_custom", True)
|
||||
# 1b. Override the autouse ``AppConfig.from_file`` stub from conftest
|
||||
# (minimal test config) with the e2e-specific config that carries a
|
||||
# real model entry and disables title/memory/summarization.
|
||||
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda config_path=None: _make_e2e_config()))
|
||||
|
||||
# 3. Disable title generation (extra LLM call, non-deterministic)
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
monkeypatch.setattr("deerflow.config.title_config._title_config", TitleConfig(enabled=False))
|
||||
|
||||
# 4. Disable memory queueing (avoids background threads & file writes)
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.memory_middleware.get_memory_config",
|
||||
lambda: MemoryConfig(enabled=False),
|
||||
)
|
||||
|
||||
# 5. Ensure summarization is off (default, but be explicit)
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
|
||||
monkeypatch.setattr("deerflow.config.summarization_config._summarization_config", SummarizationConfig(enabled=False))
|
||||
|
||||
# 6. Exclude TitleMiddleware from the chain.
|
||||
# 2. Exclude TitleMiddleware from the chain.
|
||||
# It triggers an extra LLM call to generate a thread title, which adds
|
||||
# non-determinism and cost to E2E tests (title generation is already
|
||||
# disabled via TitleConfig above, but the middleware still participates
|
||||
@@ -262,8 +251,9 @@ class TestFileUploadIntegration:
|
||||
|
||||
# Physically exists
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
assert (get_paths().sandbox_uploads_dir(tid) / "readme.txt").exists()
|
||||
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
|
||||
|
||||
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
||||
"""Uploading two files with the same name auto-renames the second."""
|
||||
@@ -472,12 +462,13 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_happy_path(self, e2e_env):
|
||||
"""Write a file to outputs, then read it back via get_artifact()."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
|
||||
# Create an output file in the thread's outputs directory
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||
outputs_dir.mkdir(parents=True, exist_ok=True)
|
||||
(outputs_dir / "result.txt").write_text("hello artifact")
|
||||
|
||||
@@ -488,11 +479,12 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_nested_path(self, e2e_env):
|
||||
"""Artifacts in subdirectories are accessible."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||
sub = outputs_dir / "charts"
|
||||
sub.mkdir(parents=True, exist_ok=True)
|
||||
(sub / "data.json").write_text('{"x": 1}')
|
||||
@@ -663,10 +655,9 @@ class TestConfigManagement:
|
||||
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
|
||||
|
||||
# Force reload so the singleton picks up our test file
|
||||
from deerflow.config.extensions_config import reload_extensions_config
|
||||
|
||||
reload_extensions_config()
|
||||
# Mock from_file so update_mcp_config's internal reload works without config.yaml
|
||||
e2e_config = _make_e2e_config()
|
||||
monkeypatch.setattr(AppConfig, "from_file", classmethod(lambda cls, path=None: e2e_config))
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
# Simulate a cached agent
|
||||
@@ -690,9 +681,9 @@ class TestConfigManagement:
|
||||
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
|
||||
|
||||
from deerflow.config.extensions_config import reload_extensions_config
|
||||
|
||||
reload_extensions_config()
|
||||
# Mock from_file so update_skill's internal reload works without config.yaml
|
||||
e2e_config = _make_e2e_config()
|
||||
monkeypatch.setattr(AppConfig, "from_file", classmethod(lambda cls, path=None: e2e_config))
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
c._agent = "fake-agent-placeholder"
|
||||
@@ -718,10 +709,6 @@ class TestConfigManagement:
|
||||
config_file.write_text(json.dumps({"mcpServers": {}, "skills": {}}))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(config_file))
|
||||
|
||||
from deerflow.config.extensions_config import reload_extensions_config
|
||||
|
||||
reload_extensions_config()
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
c.update_skill("nonexistent-skill-xyz", enabled=True)
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestLiveStreaming:
|
||||
class TestLiveToolUse:
|
||||
def test_agent_uses_bash_tool(self, client):
|
||||
"""Agent uses bash tool when asked to run a command."""
|
||||
if not is_host_bash_allowed():
|
||||
if not is_host_bash_allowed(client._app_config):
|
||||
pytest.skip("Host bash is disabled for LocalSandboxProvider in the active config")
|
||||
|
||||
events = list(client.stream("Use the bash tool to run: echo 'LIVE_TEST_OK'. Then tell me the output."))
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Multi-client isolation regression test.
|
||||
|
||||
Phase 2 Task P2-3: ``DeerFlowClient`` now captures its ``AppConfig`` in the
|
||||
constructor instead of going through a process-global config.
|
||||
This test pins the resulting invariant: two clients with different configs
|
||||
can coexist without contending over shared state.
|
||||
|
||||
Before P2-3, the shared ``AppConfig._global`` caused the second client's
|
||||
``init()`` to clobber the first client's config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
@pytest.fixture
def disable_agent_creation(monkeypatch):
    """Replace ``DeerFlowClient._get_or_create_agent`` with a ``MagicMock``.

    The tests in this module only look at how configuration is stored on the
    client, so the class attribute is swapped out for the test's duration.
    """
    agent_stub = MagicMock()
    # NOTE(review): ``raising=False`` means no error is raised when the class
    # does not define this attribute -- it is simply created. Other tests call
    # ``client._ensure_agent``; confirm ``_get_or_create_agent`` is the
    # intended target, otherwise this patch has no effect on agent creation.
    monkeypatch.setattr(DeerFlowClient, "_get_or_create_agent", agent_stub, raising=False)
|
||||
|
||||
|
||||
def test_two_clients_do_not_clobber_each_other(disable_agent_creation):
    """Each client holds on to the exact AppConfig it was constructed with."""
    memory_flags = (True, False)

    # Build every config first, then every client, so the second client is
    # created while the first one already exists.
    configs = [
        AppConfig(
            sandbox=SandboxConfig(use="test"),
            memory=MemoryConfig(enabled=flag),
        )
        for flag in memory_flags
    ]
    clients = [DeerFlowClient(config=cfg) for cfg in configs]

    # Identity: each client retains its own instance, not a shared ref
    for client, cfg in zip(clients, configs):
        assert client._app_config is cfg

    # Semantic: memory flag differs
    for client, flag in zip(clients, memory_flags):
        assert client._app_config.memory.enabled is flag
|
||||
|
||||
|
||||
def test_client_config_precedes_path(disable_agent_creation, tmp_path):
    """An explicit ``config=`` is the one stored, even if ``config_path=`` is also given."""
    explicit_cfg = AppConfig(sandbox=SandboxConfig(use="test"), log_level="debug")

    # This file is never created; the client must end up with ``explicit_cfg``
    # regardless of the path it was handed.
    missing_yaml = tmp_path / "nope.yaml"
    client = DeerFlowClient(config_path=str(missing_yaml), config=explicit_cfg)

    held_cfg = client._app_config
    assert held_cfg is explicit_cfg
    assert held_cfg.log_level == "debug"
|
||||
|
||||
|
||||
def test_multi_client_gateway_dict_returns_distinct(disable_agent_creation):
    """``get_mcp_config()`` on each client lists only that client's own MCP server."""
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    server_names = ("server-a", "server-b")

    # One extensions config -> one AppConfig -> one client per server name.
    extensions = [
        ExtensionsConfig(mcp_servers={name: McpServerConfig(enabled=True)})
        for name in server_names
    ]
    configs = [
        AppConfig(sandbox=SandboxConfig(use="test"), extensions=ext)
        for ext in extensions
    ]
    clients = [DeerFlowClient(config=cfg) for cfg in configs]

    listings = [client.get_mcp_config()["mcp_servers"] for client in clients]

    for name, servers in zip(server_names, listings):
        assert set(servers.keys()) == {name}
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Verify that all sub-config Pydantic models are frozen (immutable).
|
||||
|
||||
Frozen models reject attribute assignment after construction, raising
|
||||
pydantic.ValidationError. This test collects every BaseModel subclass
|
||||
defined in the deerflow.config package and asserts that mutation is
|
||||
blocked.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import pkgutil
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import deerflow.config as config_pkg
|
||||
|
||||
|
||||
def _collect_config_models() -> list[type[BaseModel]]:
    """Import every module under ``deerflow.config`` and gather its BaseModel subclasses.

    Only classes whose ``__module__`` matches the module being scanned are
    returned, so a model re-exported from another module is reported once,
    from the module that defines it.
    """
    import importlib

    def _own_models(module) -> list[type[BaseModel]]:
        # Keep BaseModel subclasses (excluding BaseModel itself) defined here.
        return [
            cls
            for _, cls in inspect.getmembers(module, inspect.isclass)
            if issubclass(cls, BaseModel)
            and cls is not BaseModel
            and cls.__module__ == module.__name__
        ]

    collected: list[type[BaseModel]] = []
    search_path = config_pkg.__path__
    dotted_prefix = f"{config_pkg.__name__}."

    for module_info in pkgutil.walk_packages(search_path, prefix=dotted_prefix):
        try:
            module = importlib.import_module(module_info.name)
        except Exception:
            # Best-effort: a module that cannot be imported is skipped, which
            # also means its models are not checked by the tests below.
            continue
        collected.extend(_own_models(module))

    return collected
|
||||
|
||||
|
||||
# Class names to leave out of the frozen checks; currently empty.
_EXCLUDED: set[str] = set()

_ALL_MODELS = list(filter(lambda model: model.__name__ not in _EXCLUDED, _collect_config_models()))

# Sanity: make sure we actually collected a meaningful set.
assert len(_ALL_MODELS) >= 15, f"Expected at least 15 config models, found {len(_ALL_MODELS)}: {[m.__name__ for m in _ALL_MODELS]}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_is_frozen(model_cls: type[BaseModel]):
    """The model's ``model_config`` must carry ``frozen`` set to exactly ``True``."""
    frozen_flag = model_cls.model_config.get("frozen")
    failure_hint = (
        f"{model_cls.__name__} is not frozen. "
        f"Add `model_config = ConfigDict(frozen=True)` or add `frozen=True` to the existing ConfigDict."
    )
    assert frozen_flag is True, failure_hint
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
    """Assigning to a field of an already-built instance must raise ValidationError."""
    field_names = list(model_cls.model_fields)
    if not field_names:
        pytest.skip(f"{model_cls.__name__} has no fields")

    # ``model_construct`` skips validation, so required fields do not need
    # values just to obtain an instance to poke at.
    instance = model_cls.model_construct()
    target_field = field_names[0]

    with pytest.raises(ValidationError):
        setattr(instance, target_field, "MUTATED")
|
||||
|
||||
|
||||
def test_extensions_nested_dict_mutation_is_not_blocked_by_pydantic():
    """Pin the known gap: ``frozen=True`` does not deep-freeze container fields.

    Item assignment on ``extensions.skills`` goes through without error, so
    callers must build a new dict, persist it and reload AppConfig rather
    than edit ``extensions.skills[x]`` in place. Wrapping the field in a
    Mapping/frozendict would be needed for real immutability.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig

    extensions = ExtensionsConfig(mcp_servers={}, skills={"a": SkillStateConfig(enabled=True)})

    # The pre-refactor anti-pattern: overwrite one entry and add another
    # directly on the nested dict. The outer model is frozen, the inner dict
    # is a plain builtin, so neither assignment raises.
    in_place_edits = {
        "a": SkillStateConfig(enabled=False),
        "b": SkillStateConfig(enabled=True),
    }
    for skill_name, state in in_place_edits.items():
        extensions.skills[skill_name] = state

    # Asserting that the leak exists means a future deep-freeze change flips
    # this test and forces a review of the call sites.
    assert extensions.skills["a"].enabled is False
    assert "b" in extensions.skills
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Tests for LangChain-to-OpenAI message format converters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.runtime.converters import (
|
||||
langchain_messages_to_openai,
|
||||
langchain_to_openai_completion,
|
||||
langchain_to_openai_message,
|
||||
)
|
||||
|
||||
|
||||
def _make_ai_message(content="", tool_calls=None, id="msg-123", usage_metadata=None, response_metadata=None):
|
||||
msg = MagicMock()
|
||||
msg.type = "ai"
|
||||
msg.content = content
|
||||
msg.tool_calls = tool_calls or []
|
||||
msg.id = id
|
||||
msg.usage_metadata = usage_metadata
|
||||
msg.response_metadata = response_metadata or {}
|
||||
return msg
|
||||
|
||||
|
||||
def _make_human_message(content="Hello"):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = content
|
||||
return msg
|
||||
|
||||
|
||||
def _make_system_message(content="You are an assistant."):
|
||||
msg = MagicMock()
|
||||
msg.type = "system"
|
||||
msg.content = content
|
||||
return msg
|
||||
|
||||
|
||||
def _make_tool_message(content="result", tool_call_id="call-abc"):
|
||||
msg = MagicMock()
|
||||
msg.type = "tool"
|
||||
msg.content = content
|
||||
msg.tool_call_id = tool_call_id
|
||||
return msg
|
||||
|
||||
|
||||
class TestLangchainToOpenaiMessage:
    """Checks on ``langchain_to_openai_message`` for each message type."""

    def test_ai_message_text_only(self):
        converted = langchain_to_openai_message(_make_ai_message(content="Hello world"))
        assert converted["role"] == "assistant"
        assert converted["content"] == "Hello world"
        assert "tool_calls" not in converted

    def test_ai_message_with_tool_calls(self):
        bash_call = {"id": "call-1", "name": "bash", "args": {"command": "ls"}}
        source = _make_ai_message(content="", tool_calls=[bash_call])
        converted = langchain_to_openai_message(source)
        assert converted["role"] == "assistant"
        assert converted["content"] is None
        assert len(converted["tool_calls"]) == 1
        entry = converted["tool_calls"][0]
        assert entry["id"] == "call-1"
        assert entry["type"] == "function"
        assert entry["function"]["name"] == "bash"
        # arguments must be a JSON string
        decoded_args = json.loads(entry["function"]["arguments"])
        assert decoded_args == {"command": "ls"}

    def test_ai_message_text_and_tool_calls(self):
        read_call = {"id": "call-2", "name": "read_file", "args": {"path": "/tmp/x"}}
        source = _make_ai_message(content="Reading the file", tool_calls=[read_call])
        converted = langchain_to_openai_message(source)
        assert converted["role"] == "assistant"
        assert converted["content"] == "Reading the file"
        assert len(converted["tool_calls"]) == 1

    def test_ai_message_empty_content_no_tools(self):
        converted = langchain_to_openai_message(_make_ai_message(content=""))
        assert converted["role"] == "assistant"
        assert converted["content"] == ""
        assert "tool_calls" not in converted

    def test_ai_message_list_content(self):
        # Multimodal content is preserved as-is
        text_part = {"type": "text", "text": "Here is an image"}
        image_part = {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}
        parts = [text_part, image_part]
        converted = langchain_to_openai_message(_make_ai_message(content=parts))
        assert converted["role"] == "assistant"
        assert converted["content"] == parts

    def test_human_message(self):
        converted = langchain_to_openai_message(_make_human_message("Tell me a joke"))
        assert converted["role"] == "user"
        assert converted["content"] == "Tell me a joke"

    def test_tool_message(self):
        source = _make_tool_message(content="file contents here", tool_call_id="call-xyz")
        converted = langchain_to_openai_message(source)
        assert converted["role"] == "tool"
        assert converted["tool_call_id"] == "call-xyz"
        assert converted["content"] == "file contents here"

    def test_system_message(self):
        converted = langchain_to_openai_message(_make_system_message("You are a helpful assistant."))
        assert converted["role"] == "system"
        assert converted["content"] == "You are a helpful assistant."
|
||||
|
||||
|
||||
class TestLangchainToOpenaiCompletion:
    """Checks on ``langchain_to_openai_completion``: ids, choices, finish reason, usage."""

    def test_basic_completion(self):
        source = _make_ai_message(
            content="Hello",
            id="msg-abc",
            usage_metadata={"input_tokens": 10, "output_tokens": 20},
            response_metadata={"model_name": "gpt-4o", "finish_reason": "stop"},
        )
        completion = langchain_to_openai_completion(source)
        assert completion["id"] == "msg-abc"
        assert completion["model"] == "gpt-4o"
        assert len(completion["choices"]) == 1
        first_choice = completion["choices"][0]
        assert first_choice["index"] == 0
        assert first_choice["finish_reason"] == "stop"
        assert first_choice["message"]["role"] == "assistant"
        assert first_choice["message"]["content"] == "Hello"
        # Token counts are expected under the OpenAI-style key names, with the
        # total being the sum of input and output tokens.
        assert completion["usage"] is not None
        assert completion["usage"]["prompt_tokens"] == 10
        assert completion["usage"]["completion_tokens"] == 20
        assert completion["usage"]["total_tokens"] == 30

    def test_completion_with_tool_calls(self):
        source = _make_ai_message(
            content="",
            tool_calls=[{"id": "call-1", "name": "bash", "args": {}}],
            id="msg-tc",
            response_metadata={"model_name": "gpt-4o"},
        )
        completion = langchain_to_openai_completion(source)
        assert completion["choices"][0]["finish_reason"] == "tool_calls"

    def test_completion_no_usage(self):
        source = _make_ai_message(content="Hi", id="msg-nousage", usage_metadata=None)
        completion = langchain_to_openai_completion(source)
        assert completion["usage"] is None

    def test_finish_reason_from_response_metadata(self):
        metadata = {"model_name": "claude-3", "finish_reason": "end_turn"}
        source = _make_ai_message(content="Done", id="msg-fr", response_metadata=metadata)
        completion = langchain_to_openai_completion(source)
        assert completion["choices"][0]["finish_reason"] == "end_turn"

    def test_finish_reason_default_stop(self):
        source = _make_ai_message(content="Done", id="msg-defstop", response_metadata={})
        completion = langchain_to_openai_completion(source)
        assert completion["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
|
||||
class TestMessagesToOpenai:
    """Checks on ``langchain_messages_to_openai`` for whole message lists."""

    def test_convert_message_list(self):
        conversation = [
            _make_human_message("Hi"),
            _make_ai_message(content="Hello!"),
            _make_tool_message("result", "call-1"),
        ]
        converted = langchain_messages_to_openai(conversation)
        assert len(converted) == 3
        # Roles come back in input order.
        for position, expected_role in enumerate(("user", "assistant", "tool")):
            assert converted[position]["role"] == expected_role

    def test_empty_list(self):
        assert langchain_messages_to_openai([]) == []
|
||||
@@ -9,7 +9,9 @@ import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from deerflow.config.agents_api_config import AgentsApiConfig, get_agents_api_config, set_agents_api_config
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = MemoryConfig()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -329,38 +331,26 @@ class TestMemoryFilePath:
|
||||
def test_global_memory_path(self, tmp_path):
|
||||
"""None agent_name should return global memory file."""
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
|
||||
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
|
||||
):
|
||||
storage = FileMemoryStorage()
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path(None)
|
||||
assert path == tmp_path / "memory.json"
|
||||
|
||||
def test_agent_memory_path(self, tmp_path):
|
||||
"""Providing agent_name should return per-agent memory file."""
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
|
||||
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
|
||||
):
|
||||
storage = FileMemoryStorage()
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path("code-reviewer")
|
||||
assert path == tmp_path / "agents" / "code-reviewer" / "memory.json"
|
||||
|
||||
def test_different_paths_for_different_agents(self, tmp_path):
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
|
||||
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
|
||||
):
|
||||
storage = FileMemoryStorage()
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)):
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path_global = storage._get_memory_file_path(None)
|
||||
path_a = storage._get_memory_file_path("agent-a")
|
||||
path_b = storage._get_memory_file_path("agent-b")
|
||||
@@ -389,38 +379,13 @@ def _make_test_app(tmp_path: Path):
|
||||
@pytest.fixture()
|
||||
def agent_client(tmp_path):
|
||||
"""TestClient with agents router, using tmp_path as base_dir."""
|
||||
import app.gateway.routers.agents as agents_router
|
||||
|
||||
paths_instance = _make_paths(tmp_path)
|
||||
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
|
||||
set_agents_api_config(AgentsApiConfig(enabled=True))
|
||||
try:
|
||||
app = _make_test_app(tmp_path)
|
||||
with TestClient(app) as client:
|
||||
client._tmp_path = tmp_path # type: ignore[attr-defined]
|
||||
yield client
|
||||
finally:
|
||||
set_agents_api_config(previous_config)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def disabled_agent_client(tmp_path):
|
||||
"""TestClient with agents router while the management API is disabled."""
|
||||
import app.gateway.routers.agents as agents_router
|
||||
|
||||
paths_instance = _make_paths(tmp_path)
|
||||
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
|
||||
set_agents_api_config(AgentsApiConfig(enabled=False))
|
||||
try:
|
||||
app = _make_test_app(tmp_path)
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
finally:
|
||||
set_agents_api_config(previous_config)
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch("app.gateway.routers.agents.get_paths", return_value=paths_instance):
|
||||
app = _make_test_app(tmp_path)
|
||||
with TestClient(app) as client:
|
||||
client._tmp_path = tmp_path # type: ignore[attr-defined]
|
||||
yield client
|
||||
|
||||
|
||||
class TestAgentsAPI:
|
||||
@@ -586,37 +551,3 @@ class TestUserProfileAPI:
|
||||
response = agent_client.put("/api/user-profile", json={"content": ""})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["content"] is None
|
||||
|
||||
|
||||
class TestAgentsApiDisabled:
|
||||
def test_agents_list_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.get("/api/agents")
|
||||
assert response.status_code == 403
|
||||
assert "agents_api.enabled=true" in response.json()["detail"]
|
||||
|
||||
def test_agent_get_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.get("/api/agents/example-agent")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_name_check_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.get("/api/agents/check", params={"name": "example-agent"})
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_create_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.post("/api/agents", json={"name": "example-agent", "soul": "blocked"})
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_update_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.put("/api/agents/example-agent", json={"description": "blocked"})
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_agent_delete_returns_403(self, disabled_agent_client):
|
||||
response = disabled_agent_client.delete("/api/agents/example-agent")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_user_profile_routes_return_403(self, disabled_agent_client):
|
||||
get_response = disabled_agent_client.get("/api/user-profile")
|
||||
put_response = disabled_agent_client.put("/api/user-profile", json={"content": "blocked"})
|
||||
|
||||
assert get_response.status_code == 403
|
||||
assert put_response.status_code == 403
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Tests for DeerFlowContext and resolve_context()."""
|
||||
|
||||
from dataclasses import FrozenInstanceError
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _make_config(**overrides) -> AppConfig:
    """Build an ``AppConfig`` for tests.

    A ``SandboxConfig(use="test")`` is supplied under ``sandbox`` by default;
    any keyword in ``overrides`` replaces or extends those defaults.
    """
    # Later keys win, so caller-supplied overrides take precedence.
    merged = {"sandbox": SandboxConfig(use="test"), **overrides}
    return AppConfig(**merged)
|
||||
|
||||
|
||||
class TestDeerFlowContext:
    """Structural checks on the ``DeerFlowContext`` value object."""

    def test_frozen(self):
        """Assigning to a field after construction raises ``FrozenInstanceError``."""
        context = DeerFlowContext(app_config=_make_config(), thread_id="t1")
        with pytest.raises(FrozenInstanceError):
            context.app_config = _make_config()

    def test_fields(self):
        """Constructor arguments are exposed unchanged as attributes."""
        app_config = _make_config()
        context = DeerFlowContext(app_config=app_config, thread_id="t1", agent_name="test-agent")
        assert (context.thread_id, context.agent_name) == ("t1", "test-agent")
        assert context.app_config is app_config

    def test_agent_name_default(self):
        """``agent_name`` is ``None`` when not supplied."""
        context = DeerFlowContext(app_config=_make_config(), thread_id="t1")
        assert context.agent_name is None

    def test_thread_id_required(self):
        """Omitting ``thread_id`` raises ``TypeError``."""
        with pytest.raises(TypeError):
            DeerFlowContext(app_config=_make_config())  # type: ignore[call-arg]
|
||||
|
||||
|
||||
class TestResolveContext:
    """Checks for ``resolve_context`` against typed and untyped ``runtime.context``."""

    def test_returns_typed_context_directly(self):
        """A ``DeerFlowContext`` on ``runtime.context`` is returned as the same object."""
        typed_context = DeerFlowContext(app_config=_make_config(), thread_id="t1")
        runtime = MagicMock(context=typed_context)
        assert resolve_context(runtime) is typed_context

    def test_raises_on_none_context(self):
        """``runtime.context = None`` makes ``resolve_context`` raise ``RuntimeError``."""
        runtime = MagicMock(context=None)
        with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
            resolve_context(runtime)

    def test_raises_on_dict_context(self):
        """A plain dict on ``runtime.context`` is rejected with ``RuntimeError``."""
        legacy_payload = {"thread_id": "old-dict", "agent_name": "from-dict"}
        runtime = MagicMock(context=legacy_payload)
        with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
            resolve_context(runtime)
|
||||
@@ -0,0 +1,296 @@
|
||||
"""Tests for _ensure_admin_user() in app.py.
|
||||
|
||||
Covers: first-boot no-op (admin creation removed), orphan migration
|
||||
when admin exists, no-op on no admin found, and edge cases.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
|
||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _setup_auth_config():
    """Install a fresh ``AuthConfig`` built from ``_JWT_SECRET`` before and after each test."""

    def _install_test_config():
        set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))

    _install_test_config()
    yield
    # Teardown re-installs an equivalent config rather than restoring a prior one.
    _install_test_config()
|
||||
|
||||
|
||||
def _make_app_stub(store=None):
|
||||
"""Minimal app-like object with state.store."""
|
||||
app = SimpleNamespace()
|
||||
app.state = SimpleNamespace()
|
||||
app.state.store = store
|
||||
return app
|
||||
|
||||
|
||||
def _make_provider(admin_count=0):
|
||||
p = AsyncMock()
|
||||
p.count_users = AsyncMock(return_value=admin_count)
|
||||
p.count_admin_users = AsyncMock(return_value=admin_count)
|
||||
p.create_user = AsyncMock()
|
||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||
return p
|
||||
|
||||
|
||||
def _make_session_factory(admin_row=None):
|
||||
"""Build a mock async session factory that returns a row from execute()."""
|
||||
row_result = MagicMock()
|
||||
row_result.scalar_one_or_none.return_value = admin_row
|
||||
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = admin_row
|
||||
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock(return_value=execute_result)
|
||||
|
||||
# Async context manager
|
||||
session_cm = AsyncMock()
|
||||
session_cm.__aenter__ = AsyncMock(return_value=session)
|
||||
session_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
sf = MagicMock()
|
||||
sf.return_value = session_cm
|
||||
return sf
|
||||
|
||||
|
||||
# ── First boot: no admin → return early ──────────────────────────────────
|
||||
|
||||
|
||||
def test_first_boot_does_not_create_admin():
    """With an admin count of zero, ``_ensure_admin_user`` never calls ``create_user``."""
    empty_provider = _make_provider(admin_count=0)
    stub_app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=empty_provider):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    empty_provider.create_user.assert_not_called()
|
||||
|
||||
|
||||
def test_first_boot_skips_migration():
    """With an admin count of zero, the store is never searched."""
    empty_provider = _make_provider(admin_count=0)
    thread_store = AsyncMock()
    thread_store.asearch = AsyncMock(return_value=[])
    stub_app = _make_app_stub(store=thread_store)

    with patch("app.gateway.deps.get_local_provider", return_value=empty_provider):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    thread_store.asearch.assert_not_called()
|
||||
|
||||
|
||||
# ── Admin exists: migration runs when admin row found ────────────────────
|
||||
|
||||
|
||||
def test_admin_exists_triggers_migration():
    """An admin count of one plus a resolvable admin row leads to exactly one store search."""
    from uuid import uuid4

    admin_record = MagicMock(id=uuid4())
    provider = _make_provider(admin_count=1)
    session_factory = _make_session_factory(admin_row=admin_record)
    thread_store = AsyncMock()
    thread_store.asearch = AsyncMock(return_value=[])
    stub_app = _make_app_stub(store=thread_store)

    with (
        patch("app.gateway.deps.get_local_provider", return_value=provider),
        patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory),
    ):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    thread_store.asearch.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_exists_no_admin_row_skips_migration():
    """A positive admin count with no admin row from the session means no store search."""
    provider = _make_provider(admin_count=2)
    session_factory = _make_session_factory(admin_row=None)
    thread_store = AsyncMock()
    stub_app = _make_app_stub(store=thread_store)

    with (
        patch("app.gateway.deps.get_local_provider", return_value=provider),
        patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory),
    ):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    thread_store.asearch.assert_not_called()
|
||||
|
||||
|
||||
def test_admin_exists_no_store_skips_migration():
    """Admin row is found but ``app.state.store`` is ``None``: the call must simply not raise."""
    from uuid import uuid4

    admin_record = MagicMock(id=uuid4())
    provider = _make_provider(admin_count=1)
    session_factory = _make_session_factory(admin_row=admin_record)
    stub_app = _make_app_stub(store=None)

    with (
        patch("app.gateway.deps.get_local_provider", return_value=provider),
        patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory),
    ):
        from app.gateway.app import _ensure_admin_user

        # Completing without an exception is the whole check; nothing to assert on.
        asyncio.run(_ensure_admin_user(stub_app))
|
||||
|
||||
|
||||
def test_admin_exists_session_factory_none_skips_migration():
    """When ``get_session_factory()`` yields ``None`` the store is never searched."""
    provider = _make_provider(admin_count=1)
    thread_store = AsyncMock()
    stub_app = _make_app_stub(store=thread_store)

    with (
        patch("app.gateway.deps.get_local_provider", return_value=provider),
        patch("deerflow.persistence.engine.get_session_factory", return_value=None),
    ):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    thread_store.asearch.assert_not_called()
|
||||
|
||||
|
||||
def test_migration_failure_is_non_fatal():
    """A ``RuntimeError`` raised by ``store.asearch`` must not escape ``_ensure_admin_user``."""
    from uuid import uuid4

    admin_record = MagicMock(id=uuid4())
    provider = _make_provider(admin_count=1)
    session_factory = _make_session_factory(admin_row=admin_record)
    failing_store = AsyncMock()
    failing_store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
    stub_app = _make_app_stub(store=failing_store)

    with (
        patch("app.gateway.deps.get_local_provider", return_value=provider),
        patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory),
    ):
        from app.gateway.app import _ensure_admin_user

        # The test passes as long as this call returns normally.
        asyncio.run(_ensure_admin_user(stub_app))
|
||||
|
||||
|
||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
    """Store items lacking ``metadata.user_id`` are rewritten with the admin's id.

    Covers the TC-UPG-02 upgrade story: threads saved before auth existed have
    no ``metadata.user_id``. ``_migrate_orphaned_threads`` is expected to
    ``aput`` each such item with the given admin id, keep existing metadata
    such as ``title``, and leave already-owned items untouched.
    """
    from app.gateway.app import _migrate_orphaned_threads

    admin_id = "admin-id-42"
    orphans = [
        SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
        SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
        SimpleNamespace(key="t3", value={"metadata": {}}),
    ]
    already_owned = SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}})

    thread_store = AsyncMock()
    # First search returns the whole batch; the second, empty page ends iteration.
    thread_store.asearch = AsyncMock(side_effect=[[*orphans, already_owned], []])

    recorded_puts: list[tuple[tuple, str, dict]] = []

    async def _record_aput(namespace, key, value):
        recorded_puts.append((namespace, key, value))

    thread_store.aput = AsyncMock(side_effect=_record_aput)

    migrated = asyncio.run(_migrate_orphaned_threads(thread_store, admin_id))

    # Exactly the three orphans were rewritten.
    assert migrated == 3
    assert len(recorded_puts) == 3
    value_by_key = {key: value for _namespace, key, value in recorded_puts}
    rewritten_keys = set(value_by_key)
    assert rewritten_keys == {"t1", "t2", "t3"}

    # Rewrites carry the admin id and keep pre-existing titles.
    assert value_by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
    assert value_by_key["t1"]["metadata"]["title"] == "old-thread-1"
    assert value_by_key["t3"]["metadata"]["user_id"] == "admin-id-42"

    # The item that already had an owner was not touched.
    assert "t4" not in rewritten_keys
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_empty_store_is_noop():
    """An empty search result yields a count of 0 and no ``aput`` calls."""
    from app.gateway.app import _migrate_orphaned_threads

    empty_store = AsyncMock()
    empty_store.asearch = AsyncMock(return_value=[])
    empty_store.aput = AsyncMock()

    migrated_count = asyncio.run(_migrate_orphaned_threads(empty_store, "admin-id-42"))

    assert migrated_count == 0
    empty_store.aput.assert_not_called()
|
||||
|
||||
|
||||
def test_iter_store_items_walks_multiple_pages():
    """``_iter_store_items`` keeps fetching pages until one comes back empty.

    Guards against the earlier hardcoded ``limit=1000`` behaviour, which could
    silently drop items on large datasets. ``page_size=2`` is used here purely
    to keep the test small and fast.
    """
    from app.gateway.app import _iter_store_items

    all_items = [SimpleNamespace(key=f"t{i}", value={"metadata": {}}) for i in range(4)]
    # Two full pages of two items each, then an empty page that stops the loop.
    pages = [all_items[:2], all_items[2:], []]

    thread_store = AsyncMock()
    thread_store.asearch = AsyncMock(side_effect=pages)

    async def _collect():
        seen = []
        async for item in _iter_store_items(thread_store, ("threads",), page_size=2):
            seen.append(item.key)
        return seen

    collected_keys = asyncio.run(_collect())

    assert collected_keys == ["t0", "t1", "t2", "t3"]
    # full page, full page, empty terminator
    assert thread_store.asearch.await_count == 3
|
||||
|
||||
|
||||
def test_iter_store_items_terminates_on_short_page():
    """A page shorter than ``page_size`` ends iteration after a single search call."""
    from app.gateway.app import _iter_store_items

    short_page = [SimpleNamespace(key=f"t{i}", value={}) for i in range(3)]
    thread_store = AsyncMock()
    thread_store.asearch = AsyncMock(return_value=short_page)

    async def _collect():
        seen = []
        async for item in _iter_store_items(thread_store, ("threads",), page_size=10):
            seen.append(item.key)
        return seen

    collected_keys = asyncio.run(_collect())

    assert collected_keys == ["t0", "t1", "t2"]
    # 3 items < page_size of 10, so no follow-up probe is issued.
    assert thread_store.asearch.await_count == 1
|
||||
@@ -5,20 +5,36 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
|
||||
|
||||
def _runtime_with_config(config):
    """Return a runtime namespace whose ``context.app_config`` is ``config``.

    The context object is allocated with ``__new__`` and populated through
    ``object.__setattr__``, so ``config`` may be any object (for example a
    ``MagicMock``). This lets a test drive a tool's ``get_tool_config`` lookup
    without relying on a process-global config.
    """
    context = _P2Ctx.__new__(_P2Ctx)
    # NOTE(review): only these three attributes are populated; confirm the
    # tools under test read no other context field.
    attributes = (
        ("app_config", config),
        ("thread_id", "test-thread"),
        ("agent_name", None),
    )
    for attr_name, attr_value in attributes:
        object.__setattr__(context, attr_name, attr_value)
    return _P2NS(context=context)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_config():
|
||||
"""Mock the app config to return tool configurations."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {
|
||||
"max_results": 5,
|
||||
"search_type": "auto",
|
||||
"contents_max_characters": 1000,
|
||||
"api_key": "test-api-key",
|
||||
}
|
||||
mock_config.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock_config
|
||||
"""Fixture retained as a pass-through: tests inject config via runtime directly."""
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -49,7 +65,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test query"})
|
||||
result = web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert len(parsed) == 2
|
||||
@@ -67,30 +83,30 @@ class TestWebSearchTool:
|
||||
|
||||
def test_search_with_custom_config(self, mock_exa_client):
|
||||
"""Test search respects custom configuration values."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {
|
||||
"max_results": 10,
|
||||
"search_type": "neural",
|
||||
"contents_max_characters": 2000,
|
||||
"api_key": "test-key",
|
||||
}
|
||||
mock_config.return_value.get_tool_config.return_value = tool_config
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {
|
||||
"max_results": 10,
|
||||
"search_type": "neural",
|
||||
"contents_max_characters": 2000,
|
||||
"api_key": "test-key",
|
||||
}
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.return_value = tool_config
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_exa_client.search.return_value = mock_response
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_exa_client.search.return_value = mock_response
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "neural search"})
|
||||
web_search_tool.func(query="neural search", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
mock_exa_client.search.assert_called_once_with(
|
||||
"neural search",
|
||||
type="neural",
|
||||
num_results=10,
|
||||
contents={"highlights": {"max_characters": 2000}},
|
||||
)
|
||||
mock_exa_client.search.assert_called_once_with(
|
||||
"neural search",
|
||||
type="neural",
|
||||
num_results=10,
|
||||
contents={"highlights": {"max_characters": 2000}},
|
||||
)
|
||||
|
||||
def test_search_with_no_highlights(self, mock_app_config, mock_exa_client):
|
||||
"""Test search handles results with no highlights."""
|
||||
@@ -105,7 +121,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
result = web_search_tool.func(query="test", runtime=_P2_RUNTIME)
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed[0]["snippet"] == ""
|
||||
@@ -118,7 +134,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "nothing"})
|
||||
result = web_search_tool.func(query="nothing", runtime=_P2_RUNTIME)
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed == []
|
||||
@@ -129,7 +145,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "error"})
|
||||
result = web_search_tool.func(query="error", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "Error: API rate limit exceeded"
|
||||
|
||||
@@ -147,7 +163,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "# Fetched Page\n\nThis is the page content."
|
||||
mock_exa_client.get_contents.assert_called_once_with(
|
||||
@@ -167,7 +183,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result.startswith("# Untitled\n\n")
|
||||
|
||||
@@ -179,7 +195,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com/404"})
|
||||
result = web_fetch_tool.func(url="https://example.com/404", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "Error: No results found"
|
||||
|
||||
@@ -189,16 +205,44 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "Error: Connection timeout"
|
||||
|
||||
def test_fetch_reads_web_fetch_config(self, mock_exa_client):
|
||||
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
mock_config.return_value.get_tool_config.return_value = tool_config
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.return_value = tool_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.title = "Page"
|
||||
mock_result.text = "Content."
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = [mock_result]
|
||||
mock_exa_client.get_contents.return_value = mock_response
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
fake_config.get_tool_config.assert_any_call("web_fetch")
|
||||
|
||||
def test_fetch_uses_independent_api_key(self, mock_exa_client):
|
||||
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
|
||||
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
|
||||
mock_exa_cls.return_value = mock_exa_client
|
||||
fetch_config = MagicMock()
|
||||
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
|
||||
def get_tool_config(name):
|
||||
if name == "web_fetch":
|
||||
return fetch_config
|
||||
return None
|
||||
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.side_effect = get_tool_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.title = "Page"
|
||||
@@ -209,37 +253,9 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
mock_config.return_value.get_tool_config.assert_any_call("web_fetch")
|
||||
|
||||
def test_fetch_uses_independent_api_key(self, mock_exa_client):
|
||||
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
|
||||
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
|
||||
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
|
||||
mock_exa_cls.return_value = mock_exa_client
|
||||
fetch_config = MagicMock()
|
||||
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
|
||||
|
||||
def get_tool_config(name):
|
||||
if name == "web_fetch":
|
||||
return fetch_config
|
||||
return None
|
||||
|
||||
mock_config.return_value.get_tool_config.side_effect = get_tool_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.title = "Page"
|
||||
mock_result.text = "Content."
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = [mock_result]
|
||||
mock_exa_client.get_contents.return_value = mock_response
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
|
||||
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
|
||||
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
|
||||
|
||||
def test_fetch_truncates_long_content(self, mock_app_config, mock_exa_client):
|
||||
"""Test fetch truncates content to 4096 characters."""
|
||||
@@ -253,7 +269,7 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.exa.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
# "# Long Page\n\n" is 14 chars, content truncated to 4096
|
||||
content_after_header = result.split("\n\n", 1)[1]
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
"""Tests for FeedbackRepository and follow-up association.
|
||||
|
||||
Uses temp SQLite DB for ORM tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
|
||||
|
||||
async def _make_feedback_repo(tmp_path):
    """Initialise a SQLite engine under ``tmp_path`` and return a FeedbackRepository on it.

    The database file is ``tmp_path / "test.db"``; ``tmp_path`` is also passed
    to ``init_engine`` as ``sqlite_dir``.
    """
    from deerflow.persistence.engine import get_session_factory, init_engine

    db_file = tmp_path / "test.db"
    await init_engine(
        "sqlite",
        url=f"sqlite+aiosqlite:///{db_file}",
        sqlite_dir=str(tmp_path),
    )
    session_factory = get_session_factory()
    return FeedbackRepository(session_factory)
|
||||
|
||||
|
||||
async def _cleanup():
    """Close the persistence engine opened by ``_make_feedback_repo``."""
    from deerflow.persistence.engine import close_engine as _close_engine

    await _close_engine()
|
||||
|
||||
|
||||
# -- FeedbackRepository --
|
||||
|
||||
|
||||
class TestFeedbackRepository:
    """CRUD, aggregation and upsert behaviour of FeedbackRepository on a temp SQLite DB.

    Every test opens its own engine via ``_make_feedback_repo`` and closes it
    in a ``finally`` clause, so a failing assertion cannot leave the engine
    open for the next test.
    """

    @pytest.mark.anyio
    async def test_create_positive(self, tmp_path):
        """A positive rating is stored and returned with its identifying fields."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(run_id="r1", thread_id="t1", rating=1)
            assert record["feedback_id"]
            assert record["rating"] == 1
            assert record["run_id"] == "r1"
            assert record["thread_id"] == "t1"
            assert "created_at" in record
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_negative_with_comment(self, tmp_path):
        """A negative rating keeps the accompanying comment."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(
                run_id="r1",
                thread_id="t1",
                rating=-1,
                comment="Response was inaccurate",
            )
            assert record["rating"] == -1
            assert record["comment"] == "Response was inaccurate"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_message_id(self, tmp_path):
        """The optional message_id is round-tripped."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
            assert record["message_id"] == "msg-42"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_owner(self, tmp_path):
        """The optional user_id is round-tripped."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
            assert record["user_id"] == "user-1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_invalid_rating_zero(self, tmp_path):
        """A rating of 0 is rejected with ValueError."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            with pytest.raises(ValueError):
                await repo.create(run_id="r1", thread_id="t1", rating=0)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_invalid_rating_five(self, tmp_path):
        """A rating of 5 is rejected with ValueError."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            with pytest.raises(ValueError):
                await repo.create(run_id="r1", thread_id="t1", rating=5)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_get(self, tmp_path):
        """A created record can be fetched by its feedback_id."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            created = await repo.create(run_id="r1", thread_id="t1", rating=1)
            fetched = await repo.get(created["feedback_id"])
            assert fetched is not None
            assert fetched["feedback_id"] == created["feedback_id"]
            assert fetched["rating"] == 1
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        """Fetching an unknown id yields None."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            assert await repo.get("nonexistent") is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_run(self, tmp_path):
        """list_by_run with user_id=None returns every record of the run."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
            await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
            await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
            results = await repo.list_by_run("t1", "r1", user_id=None)
            assert len(results) == 2
            assert all(r["run_id"] == "r1" for r in results)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        """list_by_thread only returns records of the requested thread."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.create(run_id="r1", thread_id="t1", rating=1)
            await repo.create(run_id="r2", thread_id="t1", rating=-1)
            await repo.create(run_id="r3", thread_id="t2", rating=1)
            results = await repo.list_by_thread("t1")
            assert len(results) == 2
            assert all(r["thread_id"] == "t1" for r in results)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        """Deleting an existing record returns True and removes it."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            created = await repo.create(run_id="r1", thread_id="t1", rating=1)
            deleted = await repo.delete(created["feedback_id"])
            assert deleted is True
            assert await repo.get(created["feedback_id"]) is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent(self, tmp_path):
        """Deleting an unknown id returns False."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            deleted = await repo.delete("nonexistent")
            assert deleted is False
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_aggregate_by_run(self, tmp_path):
        """aggregate_by_run counts total, positive and negative ratings of a run."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
            await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
            await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
            stats = await repo.aggregate_by_run("t1", "r1")
            assert stats["total"] == 3
            assert stats["positive"] == 2
            assert stats["negative"] == 1
            assert stats["run_id"] == "r1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_aggregate_empty(self, tmp_path):
        """aggregate_by_run reports zero counts when no feedback exists."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            stats = await repo.aggregate_by_run("t1", "r1")
            assert stats["total"] == 0
            assert stats["positive"] == 0
            assert stats["negative"] == 0
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_creates_new(self, tmp_path):
        """upsert inserts a record when none exists for the user and run."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            assert record["rating"] == 1
            assert record["feedback_id"]
            assert record["user_id"] == "u1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_updates_existing(self, tmp_path):
        """A second upsert by the same user on the same run updates in place."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
            assert second["feedback_id"] == first["feedback_id"]
            assert second["rating"] == -1
            assert second["comment"] == "changed my mind"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_different_users_separate(self, tmp_path):
        """Upserts by different users on the same run produce distinct records."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
            assert r1["feedback_id"] != r2["feedback_id"]
            assert r1["rating"] == 1
            assert r2["rating"] == -1
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_invalid_rating(self, tmp_path):
        """upsert rejects a rating of 0 with ValueError."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            with pytest.raises(ValueError):
                await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        """delete_by_run removes the user's feedback for the run and returns True."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
            assert deleted is True
            results = await repo.list_by_run("t1", "r1", user_id="u1")
            assert len(results) == 0
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_by_run_nonexistent(self, tmp_path):
        """delete_by_run returns False when there is nothing to delete."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
            assert deleted is False
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_grouped(self, tmp_path):
        """list_by_thread_grouped keys the thread's feedback by run_id."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
            await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
            grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
            assert "r1" in grouped
            assert "r2" in grouped
            assert "r3" not in grouped
            assert grouped["r1"]["rating"] == 1
            assert grouped["r2"]["rating"] == -1
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_grouped_empty(self, tmp_path):
        """list_by_thread_grouped returns an empty dict when no feedback exists."""
        repo = await _make_feedback_repo(tmp_path)
        try:
            grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
            assert grouped == {}
        finally:
            await _cleanup()
|
||||
|
||||
|
||||
# -- Follow-up association --
|
||||
|
||||
|
||||
class TestFollowUpAssociation:
    """Follow-up linkage between runs, exercised against the in-memory stores."""

    @pytest.mark.anyio
    async def test_run_records_follow_up_via_memory_store(self):
        """MemoryRunStore keeps follow_up_to_run_id when it is passed in metadata."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        run_store = MemoryRunStore()
        await run_store.put("r1", thread_id="t1", status="success")
        # follow_up_to_run_id is not a top-level parameter of MemoryRunStore.put,
        # so it travels inside the metadata dict instead.
        await run_store.put("r2", thread_id="t1", metadata={"follow_up_to_run_id": "r1"})
        stored = await run_store.get("r2")
        assert stored["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_human_message_has_follow_up_metadata(self):
        """A human_message event exposes follow_up_to_run_id through its metadata."""
        from deerflow.runtime.events.store.memory import MemoryRunEventStore

        events = MemoryRunEventStore()
        event_fields = {
            "thread_id": "t1",
            "run_id": "r2",
            "event_type": "human_message",
            "category": "message",
            "content": "Tell me more about that",
            "metadata": {"follow_up_to_run_id": "r1"},
        }
        await events.put(**event_fields)
        listed = await events.list_messages("t1")
        assert listed[0]["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_follow_up_auto_detection_logic(self):
        """Simulate auto-detection: the newest run is the follow-up target only if it succeeded."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        run_store = MemoryRunStore()

        async def detect_follow_up():
            # The test relies on list_by_thread returning the newest run first.
            latest = await run_store.list_by_thread("t1", limit=1)
            if latest and latest[0].get("status") == "success":
                return latest[0]["run_id"]
            return None

        await run_store.put("r1", thread_id="t1", status="success")
        await run_store.put("r2", thread_id="t1", status="error")
        # r2 (error) is the newest run, so nothing is detected.
        assert await detect_follow_up() is None

        # Once a newer successful run exists it becomes the follow-up target.
        await run_store.put("r3", thread_id="t1", status="success")
        assert await detect_follow_up() == "r3"
|
||||
@@ -3,14 +3,31 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from types import SimpleNamespace as _P2NS
|
||||
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
|
||||
|
||||
def _runtime_with_config(config):
    """Wrap ``config`` in a runtime-like namespace whose ``context`` is a DeerFlowContext.

    The context is allocated with ``__new__`` and filled through
    ``object.__setattr__``, so its constructor never runs and ``config`` may
    be any object (for example a ``MagicMock``). Only ``app_config``,
    ``thread_id`` and ``agent_name`` are populated.
    """
    ctx = _P2Ctx.__new__(_P2Ctx)
    populated = (
        ("app_config", config),
        ("thread_id", "test-thread"),
        ("agent_name", None),
    )
    for attr_name, attr_value in populated:
        object.__setattr__(ctx, attr_name, attr_value)
    return _P2NS(context=ctx)
|
||||
|
||||
|
||||
class TestWebSearchTool:
|
||||
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
|
||||
@patch("deerflow.community.firecrawl.tools.get_app_config")
|
||||
def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
|
||||
def test_search_uses_web_search_config(self, mock_firecrawl_cls):
|
||||
search_config = MagicMock()
|
||||
search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
|
||||
mock_get_app_config.return_value.get_tool_config.return_value = search_config
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.return_value = search_config
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.web = [
|
||||
@@ -20,7 +37,7 @@ class TestWebSearchTool:
|
||||
|
||||
from deerflow.community.firecrawl.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test query"})
|
||||
result = web_search_tool.func(query="test query", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
assert json.loads(result) == [
|
||||
{
|
||||
@@ -29,15 +46,14 @@ class TestWebSearchTool:
|
||||
"snippet": "Snippet",
|
||||
}
|
||||
]
|
||||
mock_get_app_config.return_value.get_tool_config.assert_called_with("web_search")
|
||||
fake_config.get_tool_config.assert_called_with("web_search")
|
||||
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-search-key")
|
||||
mock_firecrawl_cls.return_value.search.assert_called_once_with("test query", limit=7)
|
||||
|
||||
|
||||
class TestWebFetchTool:
|
||||
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
|
||||
@patch("deerflow.community.firecrawl.tools.get_app_config")
|
||||
def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
|
||||
def test_fetch_uses_web_fetch_config(self, mock_firecrawl_cls):
|
||||
fetch_config = MagicMock()
|
||||
fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}
|
||||
|
||||
@@ -46,7 +62,8 @@ class TestWebFetchTool:
|
||||
return fetch_config
|
||||
return None
|
||||
|
||||
mock_get_app_config.return_value.get_tool_config.side_effect = get_tool_config
|
||||
fake_config = MagicMock()
|
||||
fake_config.get_tool_config.side_effect = get_tool_config
|
||||
|
||||
mock_scrape_result = MagicMock()
|
||||
mock_scrape_result.markdown = "Fetched markdown"
|
||||
@@ -55,10 +72,10 @@ class TestWebFetchTool:
|
||||
|
||||
from deerflow.community.firecrawl.tools import web_fetch_tool
|
||||
|
||||
result = web_fetch_tool.invoke({"url": "https://example.com"})
|
||||
result = web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
|
||||
|
||||
assert result == "# Fetched Page\n\nFetched markdown"
|
||||
mock_get_app_config.return_value.get_tool_config.assert_any_call("web_fetch")
|
||||
fake_config.get_tool_config.assert_any_call("web_fetch")
|
||||
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-fetch-key")
|
||||
mock_firecrawl_cls.return_value.scrape.assert_called_once_with(
|
||||
"https://example.com",
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for the FastAPI get_config dependency.
|
||||
|
||||
Phase 2 step 1: introduces the new explicit-config primitive that
|
||||
resolves ``AppConfig`` from ``request.app.state.config``. After migration,
|
||||
it is the sole mechanism.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def test_get_config_returns_app_state_config():
    """get_config hands back the very AppConfig object stored on app.state.config."""
    expected = AppConfig(sandbox=SandboxConfig(use="test"))
    app = FastAPI()
    app.state.config = expected

    @app.get("/probe")
    def probe(c: AppConfig = Depends(get_config)):
        # Compare by identity: the dependency must return the exact object
        # placed on app.state, not a copy.
        return {"same_identity": c is expected, "log_level": c.log_level}

    response = TestClient(app).get("/probe")

    assert response.status_code == 200
    payload = response.json()
    assert payload["same_identity"] is True
    assert payload["log_level"] == "info"
|
||||
|
||||
|
||||
def test_get_config_reads_updated_app_state():
    """Swapping app.state.config (a config reload) is visible to later get_config calls."""
    app = FastAPI()
    before = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
    after = before.model_copy(update={"log_level": "debug"})
    app.state.config = before

    @app.get("/log-level")
    def log_level(c: AppConfig = Depends(get_config)):
        return {"level": c.log_level}

    http = TestClient(app)

    def reported_level():
        return http.get("/log-level").json()

    assert reported_level() == {"level": "info"}

    # Simulate a config reload (PUT /mcp/config, etc.)
    app.state.config = after
    assert reported_level() == {"level": "debug"}
|
||||
@@ -333,12 +333,14 @@ class TestGuardrailsConfig:
|
||||
assert config.provider.use == "deerflow.guardrails.builtin:AllowlistProvider"
|
||||
assert config.provider.config == {"denied_tools": ["bash"]}
|
||||
|
||||
def test_singleton_load_and_get(self):
|
||||
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict, reset_guardrails_config
|
||||
def test_guardrails_config_via_app_config(self):
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.guardrails_config import GuardrailProviderConfig, GuardrailsConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
try:
|
||||
load_guardrails_config_from_dict({"enabled": True, "provider": {"use": "test:Foo"}})
|
||||
config = get_guardrails_config()
|
||||
assert config.enabled is True
|
||||
finally:
|
||||
reset_guardrails_config()
|
||||
cfg = AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
guardrails=GuardrailsConfig(enabled=True, provider=GuardrailProviderConfig(use="test:Foo")),
|
||||
)
|
||||
config = cfg.guardrails
|
||||
assert config.enabled is True
|
||||
|
||||
@@ -6,6 +6,16 @@ from unittest.mock import MagicMock, patch
|
||||
from deerflow.community.infoquest import tools
|
||||
from deerflow.community.infoquest.infoquest_client import InfoQuestClient
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
class TestInfoQuestClient:
|
||||
def test_infoquest_client_initialization(self):
|
||||
@@ -130,7 +140,7 @@ class TestInfoQuestClient:
|
||||
mock_client.web_search.return_value = json.dumps([])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = tools.web_search_tool.run("test query")
|
||||
result = tools.web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == json.dumps([])
|
||||
mock_get_client.assert_called_once()
|
||||
@@ -143,14 +153,13 @@ class TestInfoQuestClient:
|
||||
mock_client.fetch.return_value = "<html><body>Test content</body></html>"
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = tools.web_fetch_tool.run("https://example.com")
|
||||
result = tools.web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
|
||||
assert result == "# Untitled\n\nTest content"
|
||||
mock_get_client.assert_called_once()
|
||||
mock_client.fetch.assert_called_once_with("https://example.com")
|
||||
|
||||
@patch("deerflow.community.infoquest.tools.get_app_config")
|
||||
def test_get_infoquest_client(self, mock_get_app_config):
|
||||
def test_get_infoquest_client(self):
|
||||
"""Test _get_infoquest_client function with config."""
|
||||
mock_config = MagicMock()
|
||||
# Add image_search config to the side_effect
|
||||
@@ -159,9 +168,8 @@ class TestInfoQuestClient:
|
||||
MagicMock(model_extra={"fetch_time": 10, "timeout": 30, "navigation_timeout": 60}), # web_fetch config
|
||||
MagicMock(model_extra={"image_search_time_range": 7, "image_size": "l"}), # image_search config
|
||||
]
|
||||
mock_get_app_config.return_value = mock_config
|
||||
|
||||
client = tools._get_infoquest_client()
|
||||
client = tools._get_infoquest_client(mock_config)
|
||||
|
||||
assert client.search_time_range == 24
|
||||
assert client.fetch_time == 10
|
||||
@@ -321,7 +329,7 @@ class TestImageSearch:
|
||||
mock_client.image_search.return_value = json.dumps([{"image_url": "https://example.com/image1.jpg"}])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = tools.image_search_tool.run({"query": "test query"})
|
||||
result = tools.image_search_tool.func(query="test query", runtime=_P2_RUNTIME)
|
||||
|
||||
# Check if result is a valid JSON string
|
||||
result_data = json.loads(result)
|
||||
@@ -340,7 +348,7 @@ class TestImageSearch:
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Pass all parameters as a dictionary (extra parameters will be ignored)
|
||||
tools.image_search_tool.run({"query": "sunset", "time_range": 30, "site": "unsplash.com", "image_size": "l"})
|
||||
tools.image_search_tool.func(query="sunset", runtime=_P2_RUNTIME)
|
||||
|
||||
mock_get_client.assert_called_once()
|
||||
# image_search_tool only passes query to client.image_search
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests for the POST /api/v1/auth/initialize endpoint.
|
||||
|
||||
Covers: first-boot admin creation, rejection when system already
|
||||
initialized, password strength validation,
|
||||
and public accessibility (no auth cookie required).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
|
||||
_TEST_SECRET = "test-secret-key-initialize-admin-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _setup_auth(tmp_path):
    """Give each test a fresh SQLite engine, auth config and empty gateway caches."""
    from app.gateway import deps
    from deerflow.persistence.engine import close_engine, init_engine

    def _drop_cached_deps():
        # Clear the provider/repository cached on the deps module.
        deps._cached_local_provider = None
        deps._cached_repo = None

    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
    db_url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db"
    asyncio.run(init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path)))
    _drop_cached_deps()
    try:
        yield
    finally:
        _drop_cached_deps()
        asyncio.run(close_engine())
|
||||
|
||||
|
||||
@pytest.fixture()
def client(_setup_auth):
    """Yield a TestClient for a freshly created gateway app, without running its lifespan."""
    from app.gateway.app import create_app
    from app.gateway.auth.config import AuthConfig, set_auth_config

    auth_config = AuthConfig(jwt_secret=_TEST_SECRET)
    set_auth_config(auth_config)
    gateway_app = create_app()
    # Deliberately NOT entered as a context manager: that would run the full
    # lifespan, which requires config.yaml. The auth endpoints work without
    # it because _setup_auth has already initialised the persistence engine.
    test_client = TestClient(gateway_app)
    yield test_client
|
||||
|
||||
|
||||
def _init_payload(**extra):
|
||||
"""Build a valid /initialize payload."""
|
||||
return {
|
||||
"email": "admin@example.com",
|
||||
"password": "Str0ng!Pass99",
|
||||
**extra,
|
||||
}
|
||||
|
||||
|
||||
# ── Happy path ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_creates_admin_and_sets_cookie(client):
    """With no admin present, POST /initialize answers 201 and sets the session cookie."""
    response = client.post("/api/v1/auth/initialize", json=_init_payload())
    assert response.status_code == 201

    created = response.json()
    assert created["email"] == "admin@example.com"
    assert created["system_role"] == "admin"

    assert "access_token" in response.cookies
|
||||
|
||||
|
||||
def test_initialize_needs_setup_false(client):
    """An admin created through /initialize reports needs_setup=False on /me."""
    client.post("/api/v1/auth/initialize", json=_init_payload())

    profile = client.get("/api/v1/auth/me")

    assert profile.status_code == 200
    assert profile.json()["needs_setup"] is False
|
||||
|
||||
|
||||
# ── Rejection when already initialized ───────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_rejected_when_admin_exists(client):
    """A second /initialize call, once an admin exists, answers 409 system_already_initialized."""
    endpoint = "/api/v1/auth/initialize"
    client.post(endpoint, json=_init_payload())

    second_payload = {**_init_payload(), "email": "other@example.com"}
    second = client.post(endpoint, json=second_payload)

    assert second.status_code == 409
    assert second.json()["detail"]["code"] == "system_already_initialized"
|
||||
|
||||
|
||||
def test_initialize_register_does_not_block_initialization(client):
    """A regular user created through /register must not block /initialize."""
    # Register a regular user first.
    registered = client.post(
        "/api/v1/auth/register",
        json={"email": "regular@example.com", "password": "Tr0ub4dor3a"},
    )
    # Without this check the test would also pass when registration fails,
    # and would then prove nothing about a pre-existing regular user.
    # NOTE(review): any 2xx is accepted because the exact success code of
    # /register is not visible here — tighten once confirmed.
    assert 200 <= registered.status_code < 300, registered.text

    # /initialize should still succeed (checks admin_count, not total user_count).
    resp = client.post("/api/v1/auth/initialize", json=_init_payload())
    assert resp.status_code == 201
    assert resp.json()["system_role"] == "admin"
|
||||
|
||||
|
||||
# ── Endpoint is public (no cookie required) ───────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_accessible_without_cookie(client):
    """/initialize succeeds when the request carries no access_token cookie."""
    request_kwargs = {"json": _init_payload(), "cookies": {}}
    response = client.post("/api/v1/auth/initialize", **request_kwargs)
    assert response.status_code == 201
|
||||
|
||||
|
||||
# ── Password validation ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_rejects_short_password(client):
    """A password shorter than 8 characters is rejected with 422."""
    weak_payload = {**_init_payload(), "password": "short"}
    response = client.post("/api/v1/auth/initialize", json=weak_payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
def test_initialize_rejects_common_password(client):
    """A well-known common password is rejected with 422."""
    common_payload = {**_init_payload(), "password": "password123"}
    response = client.post("/api/v1/auth/initialize", json=common_payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
# ── setup-status reflects initialization ─────────────────────────────────
|
||||
|
||||
|
||||
def test_setup_status_before_initialization(client):
    """Before /initialize is called, setup-status reports needs_setup=True."""
    status = client.get("/api/v1/auth/setup-status")

    assert status.status_code == 200
    assert status.json()["needs_setup"] is True
|
||||
|
||||
|
||||
def test_setup_status_after_initialization(client):
    """After a successful /initialize, setup-status reports needs_setup=False."""
    client.post("/api/v1/auth/initialize", json=_init_payload())

    status = client.get("/api/v1/auth/setup-status")

    assert status.status_code == 200
    assert status.json()["needs_setup"] is False
|
||||
|
||||
|
||||
def test_setup_status_false_when_only_regular_user_exists(client):
    """setup-status still returns needs_setup=True when only regular users exist (no admin)."""
    # NOTE(review): the function name says "false" but the expected value is
    # True; the name is kept so existing test selections keep working.
    registered = client.post(
        "/api/v1/auth/register",
        json={"email": "regular@example.com", "password": "Tr0ub4dor3a"},
    )
    # Make sure the regular user really exists; otherwise the assertion below
    # is indistinguishable from the empty-database case.
    # NOTE(review): any 2xx is accepted because the exact success code of
    # /register is not visible here — tighten once confirmed.
    assert 200 <= registered.status_code < 300, registered.text

    resp = client.get("/api/v1/auth/setup-status")
    assert resp.status_code == 200
    assert resp.json()["needs_setup"] is True
|
||||
@@ -6,7 +6,7 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
from deerflow.config.acp_config import ACPAgentConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
from deerflow.tools.builtins.invoke_acp_agent_tool import (
|
||||
_build_acp_mcp_servers,
|
||||
_build_mcp_servers,
|
||||
@@ -18,7 +18,6 @@ from deerflow.tools.tools import get_available_tools
|
||||
|
||||
|
||||
def test_build_mcp_servers_filters_disabled_and_maps_transports():
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
|
||||
fresh_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"]),
|
||||
@@ -40,11 +39,9 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
|
||||
}
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
|
||||
def test_build_acp_mcp_servers_formats_list_payload():
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
|
||||
fresh_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
|
||||
@@ -77,7 +74,6 @@ def test_build_acp_mcp_servers_formats_list_payload():
|
||||
]
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
|
||||
def test_build_permission_response_prefers_allow_once():
|
||||
@@ -152,8 +148,10 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
|
||||
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
||||
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
result = _get_work_dir("thread-abc-123")
|
||||
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
||||
assert result == str(expected)
|
||||
@@ -310,8 +308,10 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
||||
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
||||
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
@@ -665,31 +665,23 @@ async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch,
|
||||
|
||||
|
||||
def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(monkeypatch):
|
||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
|
||||
load_acp_config_from_dict(
|
||||
{
|
||||
"codex": {
|
||||
"command": "codex-acp",
|
||||
"args": [],
|
||||
"description": "Codex CLI",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
tools=[],
|
||||
models=[],
|
||||
tool_search=SimpleNamespace(enabled=False),
|
||||
acp_agents={
|
||||
"codex": ACPAgentConfig(
|
||||
command="codex-acp",
|
||||
args=[],
|
||||
description="Codex CLI",
|
||||
)
|
||||
},
|
||||
get_model_config=lambda name: None,
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: fake_config)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
|
||||
)
|
||||
|
||||
tools = get_available_tools(include_mcp=True, subagent_enabled=False)
|
||||
tools = get_available_tools(include_mcp=True, subagent_enabled=False, app_config=fake_config)
|
||||
assert "invoke_acp_agent" in [tool.name for tool in tools]
|
||||
|
||||
load_acp_config_from_dict({})
|
||||
|
||||
@@ -10,6 +10,16 @@ import deerflow.community.jina_ai.jina_client as jina_client_module
|
||||
from deerflow.community.jina_ai.jina_client import JinaClient
|
||||
from deerflow.community.jina_ai.tools import web_fetch_tool
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jina_client():
|
||||
@@ -176,9 +186,8 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_tool_config.return_value = None
|
||||
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
assert result.startswith("Error:")
|
||||
assert "429" in result
|
||||
|
||||
@@ -192,9 +201,8 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_tool_config.return_value = None
|
||||
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
|
||||
assert "Hello world" in result
|
||||
assert not result.startswith("Error:")
|
||||
|
||||
|
||||
@@ -0,0 +1,312 @@
|
||||
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
|
||||
|
||||
Validates that the LangGraph auth layer enforces the same rules as Gateway:
|
||||
cookie → JWT decode → DB lookup → token_version check → owner filter
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _setup_auth_config():
    """Install the test auth configuration before and after every test.

    A fresh ``AuthConfig`` built from ``_JWT_SECRET`` is applied on setup and
    applied again on teardown.
    """

    def _install_test_config():
        set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))

    _install_test_config()
    yield
    # Re-apply the test config once the test body has finished.
    _install_test_config()
|
||||
|
||||
|
||||
def _req(cookies=None, method="GET", headers=None):
|
||||
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
|
||||
|
||||
|
||||
def _user(user_id=None, token_version=0):
    """Create a ``User`` with fixed test credentials.

    A random UUID is used as the id when ``user_id`` is falsy.
    """
    resolved_id = user_id or uuid4()
    return User(
        email="test@example.com",
        password_hash="fakehash",
        system_role="user",
        id=resolved_id,
        token_version=token_version,
    )
|
||||
|
||||
|
||||
def _mock_provider(user=None):
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(return_value=user)
|
||||
return p
|
||||
|
||||
|
||||
# ── @auth.authenticate ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_cookie_raises_401():
    """A request without cookies is rejected with 401 "Not authenticated"."""
    request = _req()
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    failure = exc_info.value
    assert failure.status_code == 401
    assert "Not authenticated" in str(failure.detail)
|
||||
|
||||
|
||||
def test_invalid_jwt_raises_401():
    """A cookie holding a non-JWT string is rejected with 401 "Token error"."""
    request = _req({"access_token": "garbage"})
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    failure = exc_info.value
    assert failure.status_code == 401
    assert "Token error" in str(failure.detail)
|
||||
|
||||
|
||||
def test_expired_jwt_raises_401():
    """A token issued with a negative lifetime is rejected with 401."""
    already_expired = timedelta(seconds=-1)
    token = create_access_token("user-1", expires_delta=already_expired)
    request = _req({"access_token": token})
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_user_not_found_raises_401():
    """A valid token whose user lookup returns ``None`` yields 401 "User not found"."""
    token = create_access_token("ghost")
    provider_patch = patch(
        "app.gateway.langgraph_auth.get_local_provider",
        return_value=_mock_provider(None),
    )
    with provider_patch, pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(_req({"access_token": token})))
    failure = exc_info.value
    assert failure.status_code == 401
    assert "User not found" in str(failure.detail)
|
||||
|
||||
|
||||
def test_token_version_mismatch_raises_401():
    """A token with version 1 for a user at version 2 yields 401 mentioning "revoked"."""
    user = _user(token_version=2)
    stale_token = create_access_token(str(user.id), token_version=1)
    provider_patch = patch(
        "app.gateway.langgraph_auth.get_local_provider",
        return_value=_mock_provider(user),
    )
    with provider_patch, pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(_req({"access_token": stale_token})))
    failure = exc_info.value
    assert failure.status_code == 401
    assert "revoked" in str(failure.detail).lower()
|
||||
|
||||
|
||||
def test_valid_token_returns_user_id():
    """With matching version 0, ``authenticate`` returns the user's id as a string."""
    user = _user(token_version=0)
    expected_identity = str(user.id)
    token = create_access_token(expected_identity, token_version=0)
    request = _req({"access_token": token})
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        identity = asyncio.run(authenticate(request))
    assert identity == expected_identity
|
||||
|
||||
|
||||
def test_valid_token_matching_version():
    """A token and a user both at version 5 authenticate successfully."""
    user = _user(token_version=5)
    expected_identity = str(user.id)
    token = create_access_token(expected_identity, token_version=5)
    request = _req({"access_token": token})
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        identity = asyncio.run(authenticate(request))
    assert identity == expected_identity
|
||||
|
||||
|
||||
# ── @auth.authenticate edge cases ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_provider_exception_propagates():
    """An error raised by the provider's ``get_user`` surfaces to the caller."""
    token = create_access_token("user-1")
    failing_provider = AsyncMock()
    failing_provider.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
    provider_patch = patch(
        "app.gateway.langgraph_auth.get_local_provider",
        return_value=failing_provider,
    )
    with provider_patch, pytest.raises(RuntimeError, match="DB down"):
        asyncio.run(authenticate(_req({"access_token": token})))
|
||||
|
||||
|
||||
def test_jwt_missing_ver_defaults_to_zero():
    """A JWT without a 'ver' claim authenticates a user whose token_version is 0."""
    import jwt as pyjwt

    uid = str(uuid4())
    # Hand-built claims: deliberately no "ver" entry.
    claims = {"sub": uid, "exp": 9999999999, "iat": 1000000000}
    raw = pyjwt.encode(claims, _JWT_SECRET, algorithm="HS256")
    user = _user(user_id=uid, token_version=0)
    request = _req({"access_token": raw})
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        identity = asyncio.run(authenticate(request))
    assert identity == uid
|
||||
|
||||
|
||||
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
    """A JWT without a 'ver' claim is rejected (401) for a user at token_version 1."""
    import jwt as pyjwt

    uid = str(uuid4())
    # Hand-built claims: deliberately no "ver" entry.
    claims = {"sub": uid, "exp": 9999999999, "iat": 1000000000}
    raw = pyjwt.encode(claims, _JWT_SECRET, algorithm="HS256")
    user = _user(user_id=uid, token_version=1)
    provider_patch = patch(
        "app.gateway.langgraph_auth.get_local_provider",
        return_value=_mock_provider(user),
    )
    with provider_patch, pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(_req({"access_token": raw})))
    assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_wrong_secret_raises_401():
    """A token signed with a secret other than the configured one is rejected with 401."""
    import jwt as pyjwt

    claims = {"sub": "user-1", "exp": 9999999999, "ver": 0}
    foreign_secret = "wrong-secret-that-is-long-enough-32chars!"
    raw = pyjwt.encode(claims, foreign_secret, algorithm="HS256")
    request = _req({"access_token": raw})
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
# ── @auth.on (owner filter) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeUser:
|
||||
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
|
||||
|
||||
def __init__(self, identity: str):
|
||||
self.identity = identity
|
||||
self.is_authenticated = True
|
||||
self.display_name = identity
|
||||
|
||||
|
||||
def _make_ctx(user_id):
    """Build an ``AuthContext`` for a threads/create action performed by ``user_id``."""
    acting_user = _FakeUser(user_id)
    return Auth.types.AuthContext(
        resource="threads",
        action="create",
        user=acting_user,
        permissions=[],
    )
|
||||
|
||||
|
||||
def test_filter_injects_user_id():
    """The owner filter writes the caller's id into ``value["metadata"]["user_id"]``."""
    payload = {}
    ctx = _make_ctx("user-a")
    asyncio.run(add_owner_filter(ctx, payload))
    assert payload["metadata"]["user_id"] == "user-a"
|
||||
|
||||
|
||||
def test_filter_preserves_existing_metadata():
    """Existing metadata keys survive alongside the injected ``user_id``."""
    payload = {"metadata": {"title": "hello"}}
    ctx = _make_ctx("user-a")
    asyncio.run(add_owner_filter(ctx, payload))
    metadata = payload["metadata"]
    assert metadata["user_id"] == "user-a"
    assert metadata["title"] == "hello"
|
||||
|
||||
|
||||
def test_filter_returns_user_id_dict():
    """The owner filter returns ``{"user_id": <caller id>}``."""
    ctx = _make_ctx("user-x")
    returned_filter = asyncio.run(add_owner_filter(ctx, {}))
    assert returned_filter == {"user_id": "user-x"}
|
||||
|
||||
|
||||
def test_filter_read_write_consistency():
    """The id written into metadata equals the id in the returned filter."""
    payload = {}
    ctx = _make_ctx("user-1")
    returned_filter = asyncio.run(add_owner_filter(ctx, payload))
    written_id = payload["metadata"]["user_id"]
    assert written_id == returned_filter["user_id"]
|
||||
|
||||
|
||||
def test_different_users_different_filters():
    """Two distinct callers receive filters with distinct ``user_id`` values."""
    filter_for_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
    filter_for_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
    assert filter_for_a["user_id"] != filter_for_b["user_id"]
|
||||
|
||||
|
||||
def test_filter_overrides_conflicting_user_id():
    """A pre-set ``user_id`` in metadata is replaced by the authenticated caller's id."""
    payload = {"metadata": {"user_id": "attacker"}}
    ctx = _make_ctx("real-owner")
    asyncio.run(add_owner_filter(ctx, payload))
    assert payload["metadata"]["user_id"] == "real-owner"
|
||||
|
||||
|
||||
def test_filter_with_empty_metadata():
    """An explicitly empty metadata dict is filled in and the usual filter returned."""
    payload = {"metadata": {}}
    ctx = _make_ctx("user-z")
    returned_filter = asyncio.run(add_owner_filter(ctx, payload))
    assert payload["metadata"]["user_id"] == "user-z"
    assert returned_filter == {"user_id": "user-z"}
|
||||
|
||||
|
||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_shared_jwt_secret():
    """A token from ``create_access_token`` decodes via ``decode_token`` with sub and ver intact."""
    from app.gateway.auth.errors import TokenError

    token = create_access_token("user-1", token_version=3)
    decoded = decode_token(token)

    # The asserts imply a failed decode comes back as a TokenError value.
    assert not isinstance(decoded, TokenError)
    assert decoded.sub == "user-1"
    assert decoded.ver == 3
|
||||
|
||||
|
||||
def test_langgraph_json_has_auth_path():
    """``langgraph.json`` declares an ``auth`` entry whose path mentions ``langgraph_auth``.

    The file is looked up one directory above the directory containing this
    test module.
    """
    import json

    config_path = Path(__file__).parent.parent / "langgraph.json"
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    config = json.loads(config_path.read_text(encoding="utf-8"))
    assert "auth" in config
    assert "langgraph_auth" in config["auth"]["path"]
|
||||
|
||||
|
||||
def test_auth_handler_has_both_layers():
    """The module-level ``auth`` has an authenticate handler and exactly one global handler."""
    from app.gateway.langgraph_auth import auth

    authenticate_handler = auth._authenticate_handler
    global_handlers = auth._global_handlers
    assert authenticate_handler is not None
    assert len(global_handlers) == 1
|
||||
|
||||
|
||||
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_csrf_get_no_check():
    """GET requests are not subject to CSRF checks and go on to JWT validation."""
    request = _req(method="GET")
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    failure = exc_info.value
    # The rejection comes from the missing cookie, not from CSRF.
    assert failure.status_code == 401
    assert "Not authenticated" in str(failure.detail)
|
||||
|
||||
|
||||
def test_csrf_post_missing_token():
    """A POST carrying no CSRF token is rejected with 403 "CSRF token missing"."""
    request = _req(method="POST", cookies={"access_token": "some-jwt"})
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    failure = exc_info.value
    assert failure.status_code == 403
    assert "CSRF token missing" in str(failure.detail)
|
||||
|
||||
|
||||
def test_csrf_post_mismatched_token():
    """A POST whose CSRF cookie and header differ is rejected with 403."""
    request = _req(
        method="POST",
        cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
        headers={"x-csrf-token": "wrong-token"},
    )
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    failure = exc_info.value
    assert failure.status_code == 403
    assert "mismatch" in str(failure.detail)
|
||||
|
||||
|
||||
def test_csrf_post_matching_token_proceeds_to_jwt():
    """A POST with matching CSRF cookie and header clears CSRF, then fails JWT decoding."""
    request = _req(
        method="POST",
        cookies={"access_token": "garbage", "csrf_token": "same-token"},
        headers={"x-csrf-token": "same-token"},
    )
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    failure = exc_info.value
    # Got past CSRF; the rejection comes from decoding the JWT.
    assert failure.status_code == 401
    assert "Token error" in str(failure.detail)
|
||||
|
||||
|
||||
def test_csrf_put_requires_token():
    """A PUT without a CSRF token is rejected with 403."""
    request = _req(method="PUT", cookies={"access_token": "jwt"})
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
def test_csrf_delete_requires_token():
    """A DELETE without a CSRF token is rejected with 403."""
    request = _req(method="DELETE", cookies={"access_token": "jwt"})
    with pytest.raises(Auth.exceptions.HTTPException) as exc_info:
        asyncio.run(authenticate(request))
    assert exc_info.value.status_code == 403
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
@@ -33,7 +32,7 @@ def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||
def test_resolve_model_name_falls_back_to_default(caplog):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
@@ -41,16 +40,14 @@ def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
resolved = lead_agent_module._resolve_model_name("missing-model")
|
||||
resolved = lead_agent_module._resolve_model_name(app_config, "missing-model")
|
||||
|
||||
assert resolved == "default-model"
|
||||
assert "fallback to default model 'default-model'" in caplog.text
|
||||
|
||||
|
||||
def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||
def test_resolve_model_name_uses_default_when_none():
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
@@ -58,23 +55,19 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
resolved = lead_agent_module._resolve_model_name(None)
|
||||
resolved = lead_agent_module._resolve_model_name(app_config, None)
|
||||
|
||||
assert resolved == "default-model"
|
||||
|
||||
|
||||
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
|
||||
def test_resolve_model_name_raises_when_no_models_configured():
|
||||
app_config = _make_app_config([])
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="No chat models are configured",
|
||||
):
|
||||
lead_agent_module._resolve_model_name("missing-model")
|
||||
lead_agent_module._resolve_model_name(app_config, "missing-model")
|
||||
|
||||
|
||||
def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch):
|
||||
@@ -82,13 +75,12 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
|
||||
import deerflow.tools as tools_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda app_config, config, model_name, agent_name=None: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
@@ -105,7 +97,8 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
"is_plan_mode": False,
|
||||
"subagent_enabled": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
assert captured["name"] == "safe-model"
|
||||
@@ -113,74 +106,6 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
assert result["model"] is not None
|
||||
|
||||
|
||||
def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
_make_model("context-model", supports_thinking=True),
|
||||
]
|
||||
)
|
||||
|
||||
import deerflow.tools as tools_module
|
||||
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
|
||||
result = lead_agent_module.make_lead_agent(
|
||||
{
|
||||
"context": {
|
||||
"model_name": "context-model",
|
||||
"thinking_enabled": False,
|
||||
"reasoning_effort": "high",
|
||||
"is_plan_mode": True,
|
||||
"subagent_enabled": True,
|
||||
"max_concurrent_subagents": 7,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert captured == {
|
||||
"name": "context-model",
|
||||
"thinking_enabled": False,
|
||||
"reasoning_effort": "high",
|
||||
}
|
||||
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
|
||||
assert result["model"] is not None
|
||||
|
||||
|
||||
def test_make_lead_agent_rejects_invalid_bootstrap_agent_name(monkeypatch):
|
||||
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
lead_agent_module.make_lead_agent(
|
||||
{
|
||||
"configurable": {
|
||||
"model_name": "safe-model",
|
||||
"thinking_enabled": False,
|
||||
"is_plan_mode": False,
|
||||
"subagent_enabled": False,
|
||||
"is_bootstrap": True,
|
||||
"agent_name": "../../../tmp/evil",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
@@ -197,11 +122,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda _ac: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
middlewares = lead_agent_module._build_middlewares(app_config, {"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
@@ -209,73 +133,27 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
lead_agent_module,
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||
app_config = _make_app_config([_make_model("default", supports_thinking=False)])
|
||||
patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
fake_model = object()
|
||||
fake_model = MagicMock()
|
||||
fake_model.with_config.return_value = fake_model
|
||||
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
|
||||
captured["name"] = name
|
||||
captured["thinking_enabled"] = thinking_enabled
|
||||
captured["reasoning_effort"] = reasoning_effort
|
||||
return fake_model
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
middleware = lead_agent_module._create_summarization_middleware(patched)
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert middleware["model"] is fake_model
|
||||
|
||||
|
||||
def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
lead_agent_module,
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_middleware(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
|
||||
|
||||
lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
|
||||
|
||||
|
||||
def test_create_summarization_middleware_passes_skill_read_tool_names(monkeypatch):
|
||||
app_config = _make_app_config([_make_model("default-model", supports_thinking=False)])
|
||||
monkeypatch.setattr(
|
||||
lead_agent_module,
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True, skill_file_read_tool_names=["read_file", "cat"]),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_middleware(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
|
||||
|
||||
lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["skill_file_read_tool_names"] == ["read_file", "cat"]
|
||||
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
||||
|
||||
@@ -4,25 +4,23 @@ from types import SimpleNamespace
|
||||
import anyio
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts():
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
|
||||
assert prompt_module._build_custom_mounts_section() == ""
|
||||
assert prompt_module._build_custom_mounts_section(config) == ""
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
|
||||
def test_build_custom_mounts_section_lists_configured_mounts():
|
||||
mounts = [
|
||||
SimpleNamespace(container_path="/home/user/shared", read_only=False),
|
||||
SimpleNamespace(container_path="/mnt/reference", read_only=True),
|
||||
]
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
|
||||
section = prompt_module._build_custom_mounts_section()
|
||||
section = prompt_module._build_custom_mounts_section(config)
|
||||
|
||||
assert "**Custom Mounted Directories:**" in section
|
||||
assert "`/home/user/shared`" in section
|
||||
@@ -36,15 +34,15 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
|
||||
config = SimpleNamespace(
|
||||
sandbox=SimpleNamespace(mounts=mounts),
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
|
||||
|
||||
prompt = prompt_module.apply_prompt_template()
|
||||
prompt = prompt_module.apply_prompt_template(config)
|
||||
|
||||
assert "`/home/user/shared`" in prompt
|
||||
assert "Custom Mounted Directories" in prompt
|
||||
@@ -54,15 +52,15 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
|
||||
config = SimpleNamespace(
|
||||
sandbox=SimpleNamespace(mounts=[]),
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
|
||||
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
|
||||
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
|
||||
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
|
||||
|
||||
prompt = prompt_module.apply_prompt_template()
|
||||
prompt = prompt_module.apply_prompt_template(config)
|
||||
|
||||
assert "Treat `/mnt/user-data/workspace` as your default current working directory" in prompt
|
||||
assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt
|
||||
@@ -83,7 +81,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
)
|
||||
|
||||
state = {"skills": [make_skill("first-skill")]}
|
||||
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
|
||||
monkeypatch.setattr(prompt_module, "load_skills", lambda *a, **kwargs: list(state["skills"]))
|
||||
prompt_module._reset_skills_system_prompt_cache_state()
|
||||
|
||||
try:
|
||||
@@ -119,7 +117,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
def fake_load_skills(enabled_only=True):
|
||||
def fake_load_skills(*a, **kwargs):
|
||||
nonlocal active_loads, max_active_loads, call_count
|
||||
with lock:
|
||||
active_loads += 1
|
||||
@@ -156,7 +154,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
|
||||
def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
|
||||
event = threading.Event()
|
||||
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda: event)
|
||||
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda *a, **k: event)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
warmed = prompt_module.warm_enabled_skills_cache(timeout_seconds=0.01)
|
||||
|
||||
@@ -19,27 +19,40 @@ def _make_skill(name: str) -> Skill:
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_SKILLS_CONFIG = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
)
|
||||
|
||||
|
||||
def _evolution_enabled_config() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True),
|
||||
)
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch):
    """The section is empty when the filter names no enabled skill."""
    skills = [_make_skill("skill1"), _make_skill("skill2")]
    # The stub accepts any arguments so it works whether or not the caller
    # forwards the config to _get_enabled_skills.
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"non_existent_skill"})
    assert result == ""
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch):
    """The section is empty when the filter is an empty set."""
    skills = [_make_skill("skill1"), _make_skill("skill2")]
    # The stub accepts any arguments so it works whether or not the caller
    # forwards the config to _get_enabled_skills.
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=set())
    assert result == ""
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_returns_skills(monkeypatch):
    """Only the skills named in ``available_skills`` appear in the section."""
    skills = [_make_skill("skill1"), _make_skill("skill2")]
    # The stub accepts any arguments so it works whether or not the caller
    # forwards the config to _get_enabled_skills.
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"skill1"})
    assert "skill1" in result
    assert "skill2" not in result
    assert "[built-in]" in result
|
||||
@@ -47,56 +60,41 @@ def test_get_skills_prompt_section_returns_skills(monkeypatch):
|
||||
|
||||
def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch):
    """Passing ``available_skills=None`` applies no filter: every skill is listed."""
    skills = [_make_skill("skill1"), _make_skill("skill2")]
    # The stub accepts any arguments so it works whether or not the caller
    # forwards the config to _get_enabled_skills.
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=None)
    assert "skill1" in result
    assert "skill2" in result
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
    """With skill evolution enabled, the section carries the self-evolution rules."""
    skills = [_make_skill("skill1")]
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

    # The config is handed over explicitly; no module-level config accessor
    # is patched any more.
    result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
    assert "Skill Self-Evolution" in result
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
    """The self-evolution rules are emitted even when no skill is enabled."""
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: [])

    # The config is handed over explicitly; no module-level config accessor
    # is patched any more.
    result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
    assert "Skill Self-Evolution" in result
|
||||
|
||||
|
||||
def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeypatch):
    """Two configs differing only in ``skill_evolution.enabled`` yield different sections."""
    skills = [_make_skill("skill1")]
    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
    config = _evolution_enabled_config()

    enabled_result = get_skills_prompt_section(config, available_skills=None)
    assert "Skill Self-Evolution" in enabled_result

    # Build a separate config for the disabled case instead of flipping the
    # flag on the first one in place.
    disabled_config = SimpleNamespace(
        skills=SimpleNamespace(container_path="/mnt/skills"),
        skill_evolution=SimpleNamespace(enabled=False),
    )
    disabled_result = get_skills_prompt_section(disabled_config, available_skills=None)
    assert "Skill Self-Evolution" not in disabled_result
|
||||
|
||||
|
||||
@@ -106,8 +104,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
# Mock dependencies
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda app_config=None, x=None: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
@@ -118,11 +115,10 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = MockModelConfig()
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
captured_skills = []
|
||||
|
||||
def mock_apply_prompt_template(**kwargs):
|
||||
def mock_apply_prompt_template(_app_config, *args, **kwargs):
|
||||
captured_skills.append(kwargs.get("available_skills"))
|
||||
return "mock_prompt"
|
||||
|
||||
@@ -130,15 +126,15 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
|
||||
# Case 1: Empty skills list
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
|
||||
assert captured_skills[-1] == set()
|
||||
|
||||
# Case 2: None skills list
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
|
||||
assert captured_skills[-1] is None
|
||||
|
||||
# Case 3: Some skills list
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
|
||||
assert captured_skills[-1] == {"skill1"}
|
||||
|
||||
@@ -22,26 +22,26 @@ def _make_config(*, allow_host_bash: bool, sandbox_use: str = "deerflow.sandbox.
|
||||
|
||||
|
||||
def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
    """With ``allow_host_bash=False`` the bash tool is left out of the tool list."""
    app_config = _make_config(allow_host_bash=False)
    # Replace tool resolution with a stub that names the tool from its
    # ``use`` string, so no real tool object is imported.
    monkeypatch.setattr(
        "deerflow.tools.tools.resolve_variable",
        lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
    )

    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=app_config)]

    assert "bash" not in names
    assert "ls" in names
|
||||
|
||||
|
||||
def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
    """With ``allow_host_bash=True`` the bash tool stays in the tool list."""
    app_config = _make_config(allow_host_bash=True)
    # Replace tool resolution with a stub that names the tool from its
    # ``use`` string, so no real tool object is imported.
    monkeypatch.setattr(
        "deerflow.tools.tools.resolve_variable",
        lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
    )

    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=app_config)]

    assert "bash" in names
    assert "ls" in names
|
||||
@@ -52,13 +52,12 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
|
||||
allow_host_bash=False,
|
||||
extra_tools=[SimpleNamespace(name="shell", group="bash", use="deerflow.sandbox.tools:bash_tool")],
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.tools.resolve_variable",
|
||||
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
|
||||
)
|
||||
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=config)]
|
||||
|
||||
assert "bash" not in names
|
||||
assert "shell" not in names
|
||||
@@ -70,13 +69,12 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
|
||||
allow_host_bash=False,
|
||||
sandbox_use="deerflow.community.aio_sandbox:AioSandboxProvider",
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.tools.resolve_variable",
|
||||
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
|
||||
)
|
||||
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
|
||||
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=config)]
|
||||
|
||||
assert "bash" in names
|
||||
assert "ls" in names
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import errno
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -314,8 +313,7 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
|
||||
|
||||
@@ -336,8 +334,7 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
@@ -360,8 +357,7 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
@@ -476,7 +472,6 @@ class TestLocalSandboxProviderMounts:
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
provider = LocalSandboxProvider(app_config=config)
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
|
||||
|
||||
@@ -10,12 +10,22 @@ from deerflow.agents.middlewares.loop_detection_middleware import (
|
||||
LoopDetectionMiddleware,
|
||||
_hash_tool_calls,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
    """Build a ``DeerFlowContext`` for *thread_id* around a minimal ``AppConfig``.

    The config only sets ``sandbox=SandboxConfig(use="test")``; everything
    else is left to the models' defaults.
    """
    return DeerFlowContext(
        app_config=AppConfig(sandbox=SandboxConfig(use="test")),
        thread_id=thread_id,
    )
|
||||
|
||||
|
||||
def _make_runtime(thread_id="test-thread"):
    """Build a minimal Runtime mock whose ``context`` is a typed ``DeerFlowContext``."""
    runtime = MagicMock()
    # A typed context replaces the former ``{"thread_id": ...}`` dict.
    runtime.context = _make_context(thread_id)
    return runtime
|
||||
|
||||
|
||||
@@ -293,10 +303,10 @@ class TestLoopDetection:
|
||||
assert isinstance(mw._lock, type(mw._lock))
|
||||
|
||||
def test_fallback_thread_id_when_missing(self):
|
||||
"""When runtime context has no thread_id, should use 'default'."""
|
||||
"""When runtime context has empty thread_id, should use 'default'."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2)
|
||||
runtime = MagicMock()
|
||||
runtime.context = {}
|
||||
runtime.context = _make_context("")
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
@@ -3,23 +3,31 @@ import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _memory_config(**overrides: object) -> MemoryConfig:
    """Return a ``MemoryConfig`` with the given field overrides applied.

    The overrides go through the constructor rather than ``setattr`` on a
    built instance: per the refactor description the config models are
    frozen, so attribute assignment after construction is rejected.
    """
    return MemoryConfig(**overrides)
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
def _make_config(**memory_overrides) -> AppConfig:
    """Build a minimal ``AppConfig`` whose memory section takes *memory_overrides*."""
    memory = MemoryConfig(**memory_overrides)
    return AppConfig(sandbox=SandboxConfig(use="test"), memory=memory)
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
with patch.object(queue, "_reset_timer"):
|
||||
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
|
||||
|
||||
@@ -29,7 +37,7 @@ def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
|
||||
|
||||
|
||||
def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
@@ -50,16 +58,14 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
reinforcement_detected=False,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
with patch.object(queue, "_reset_timer"):
|
||||
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
|
||||
|
||||
@@ -69,7 +75,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No
|
||||
|
||||
|
||||
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
@@ -90,6 +96,7 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
agent_name="lead_agent",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
|
||||
"""Tests for user_id propagation through memory queue."""

from unittest.mock import MagicMock, patch

import pytest

from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.app_config import AppConfig
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig

# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _enable_memory(monkeypatch):
    """Build and return an ``AppConfig``-shaped mock with memory enabled.

    NOTE(review): nothing in this fixture installs the object anywhere
    (``monkeypatch`` is requested but unused), and the queue tests in this
    module pass ``_TEST_APP_CONFIG`` explicitly. As written the fixture has
    no effect on ``MemoryUpdateQueue.add()``; confirm whether it can be
    removed.
    """
    config = MagicMock(spec=AppConfig)
    config.memory = MemoryConfig(enabled=True)
    # Return the object so the fixture value is at least reachable instead
    # of being discarded.
    return config
|
||||
|
||||
|
||||
def test_conversation_context_has_user_id():
    """``ConversationContext`` keeps the ``user_id`` it was constructed with."""
    ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice")
    assert ctx.user_id == "alice"
|
||||
|
||||
|
||||
def test_conversation_context_user_id_default_none():
    """``user_id`` defaults to ``None`` when it is not supplied."""
    ctx = ConversationContext(thread_id="t1", messages=[])
    assert ctx.user_id is None
|
||||
|
||||
|
||||
def test_queue_add_stores_user_id():
    """``add()`` records the caller's ``user_id`` on the queued entry."""
    q = MemoryUpdateQueue(_TEST_APP_CONFIG)
    # _reset_timer is patched out so add() only enqueues the entry.
    with patch.object(q, "_reset_timer"):
        q.add(thread_id="t1", messages=["msg"], user_id="alice")
    assert len(q._queue) == 1
    assert q._queue[0].user_id == "alice"
    q.clear()
|
||||
|
||||
|
||||
def test_queue_process_passes_user_id_to_updater():
    """``_process_queue()`` forwards the queued ``user_id`` to ``update_memory``."""
    q = MemoryUpdateQueue(_TEST_APP_CONFIG)
    # _reset_timer is patched out so add() only enqueues the entry.
    with patch.object(q, "_reset_timer"):
        q.add(thread_id="t1", messages=["msg"], user_id="alice")

    mock_updater = MagicMock()
    mock_updater.update_memory.return_value = True
    # Swap the updater class for a factory that returns the mock, so the
    # call made during processing can be inspected.
    with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
        q._process_queue()

    mock_updater.update_memory.assert_called_once()
    call_kwargs = mock_updater.update_memory.call_args.kwargs
    assert call_kwargs["user_id"] == "alice"
|
||||
@@ -4,6 +4,18 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import memory
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
_TEST_APP_CONFIG = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
    """Build a memory-router app pre-populated with a minimal AppConfig.

    The config is stored on ``app.state.config`` before the router is
    included.
    """
    app = FastAPI()
    app.state.config = _TEST_APP_CONFIG
    app.include_router(memory.router)
    return app
|
||||
|
||||
|
||||
def _sample_memory(facts: list[dict] | None = None) -> dict:
|
||||
@@ -25,8 +37,7 @@ def _sample_memory(facts: list[dict] | None = None) -> dict:
|
||||
|
||||
|
||||
def test_export_memory_route_returns_current_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
exported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -49,8 +60,7 @@ def test_export_memory_route_returns_current_memory() -> None:
|
||||
|
||||
|
||||
def test_import_memory_route_returns_imported_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
imported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -73,8 +83,7 @@ def test_import_memory_route_returns_imported_memory() -> None:
|
||||
|
||||
|
||||
def test_export_memory_route_preserves_source_error() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
exported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -98,8 +107,7 @@ def test_export_memory_route_preserves_source_error() -> None:
|
||||
|
||||
|
||||
def test_import_memory_route_preserves_source_error() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
imported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -123,8 +131,7 @@ def test_import_memory_route_preserves_source_error() -> None:
|
||||
|
||||
|
||||
def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.clear_memory_data", return_value=_sample_memory()):
|
||||
with TestClient(app) as client:
|
||||
@@ -135,8 +142,7 @@ def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||
|
||||
|
||||
def test_create_memory_fact_route_returns_updated_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -166,8 +172,7 @@ def test_create_memory_fact_route_returns_updated_memory() -> None:
|
||||
|
||||
|
||||
def test_delete_memory_fact_route_returns_updated_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -190,8 +195,7 @@ def test_delete_memory_fact_route_returns_updated_memory() -> None:
|
||||
|
||||
|
||||
def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.delete_memory_fact", side_effect=KeyError("fact_missing")):
|
||||
with TestClient(app) as client:
|
||||
@@ -202,8 +206,7 @@ def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_returns_updated_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -233,8 +236,7 @@ def test_update_memory_fact_route_returns_updated_memory() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
updated_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -258,18 +260,18 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
update_fact.assert_called_once_with(
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
category=None,
|
||||
confidence=None,
|
||||
)
|
||||
assert update_fact.call_count == 1
|
||||
call_kwargs = update_fact.call_args.kwargs
|
||||
assert call_kwargs.get("fact_id") == "fact_edit"
|
||||
assert call_kwargs.get("content") == "User prefers spaces"
|
||||
assert call_kwargs.get("category") is None
|
||||
assert call_kwargs.get("confidence") is None
|
||||
assert "user_id" in call_kwargs
|
||||
assert response.json()["facts"] == updated_memory["facts"]
|
||||
|
||||
|
||||
def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=KeyError("fact_missing")):
|
||||
with TestClient(app) as client:
|
||||
@@ -287,8 +289,7 @@ def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
|
||||
|
||||
|
||||
def test_update_memory_fact_route_returns_specific_error_for_invalid_confidence() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
app = _make_app()
|
||||
|
||||
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=ValueError("confidence")):
|
||||
with TestClient(app) as client:
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
|
||||
"""Tests for memory storage providers."""

# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig

_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
|
||||
|
||||
import threading
|
||||
@@ -11,7 +23,13 @@ from deerflow.agents.memory.storage import (
|
||||
create_empty_memory,
|
||||
get_memory_storage,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _app_config(**memory_overrides) -> AppConfig:
    """Build a minimal ``AppConfig`` whose memory section takes *memory_overrides*."""
    memory = MemoryConfig(**memory_overrides)
    return AppConfig(sandbox=SandboxConfig(use="test"), memory=memory)
|
||||
|
||||
|
||||
class TestCreateEmptyMemory:
|
||||
@@ -53,10 +71,9 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
path = storage._get_memory_file_path(None)
|
||||
assert path == tmp_path / "memory.json"
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path(None)
|
||||
assert path == tmp_path / "memory.json"
|
||||
|
||||
def test_get_memory_file_path_agent(self, tmp_path):
|
||||
"""Should return per-agent memory file path when agent_name is provided."""
|
||||
@@ -67,14 +84,14 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
storage = FileMemoryStorage()
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
path = storage._get_memory_file_path("test-agent")
|
||||
assert path == tmp_path / "agents" / "test-agent" / "memory.json"
|
||||
|
||||
@pytest.mark.parametrize("invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"])
|
||||
def test_validate_agent_name_invalid(self, invalid_name):
|
||||
"""Should raise ValueError for invalid agent names that don't match the pattern."""
|
||||
storage = FileMemoryStorage()
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
with pytest.raises(ValueError, match="Invalid agent name|Agent name must be a non-empty string"):
|
||||
storage._validate_agent_name(invalid_name)
|
||||
|
||||
@@ -87,11 +104,10 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
memory = storage.load()
|
||||
assert isinstance(memory, dict)
|
||||
assert memory["version"] == "1.0"
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
memory = storage.load()
|
||||
assert isinstance(memory, dict)
|
||||
assert memory["version"] == "1.0"
|
||||
|
||||
def test_save_writes_to_file(self, tmp_path):
|
||||
"""Should save memory data to file."""
|
||||
@@ -103,12 +119,11 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
|
||||
result = storage.save(test_memory)
|
||||
assert result is True
|
||||
assert memory_file.exists()
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
|
||||
result = storage.save(test_memory)
|
||||
assert result is True
|
||||
assert memory_file.exists()
|
||||
|
||||
def test_save_does_not_mutate_caller_dict(self, tmp_path):
|
||||
"""save() must not mutate the caller's dict (lastUpdated side-effect)."""
|
||||
@@ -209,18 +224,17 @@ class TestFileMemoryStorage:
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
# First load
|
||||
memory1 = storage.load()
|
||||
assert memory1["facts"][0]["content"] == "initial fact"
|
||||
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
|
||||
# First load
|
||||
memory1 = storage.load()
|
||||
assert memory1["facts"][0]["content"] == "initial fact"
|
||||
|
||||
# Update file directly
|
||||
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
|
||||
# Update file directly
|
||||
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
|
||||
|
||||
# Reload should get updated data
|
||||
memory2 = storage.reload()
|
||||
assert memory2["facts"][0]["content"] == "updated fact"
|
||||
# Reload should get updated data
|
||||
memory2 = storage.reload()
|
||||
assert memory2["facts"][0]["content"] == "updated fact"
|
||||
|
||||
|
||||
class TestGetMemoryStorage:
|
||||
@@ -237,22 +251,19 @@ class TestGetMemoryStorage:
|
||||
|
||||
def test_returns_file_memory_storage_by_default(self):
|
||||
"""Should return FileMemoryStorage by default."""
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_falls_back_to_file_memory_storage_on_error(self):
|
||||
"""Should fall back to FileMemoryStorage if configured storage fails to load."""
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="non.existent.StorageClass")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_returns_singleton_instance(self):
|
||||
"""Should return the same instance on subsequent calls."""
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
storage1 = get_memory_storage()
|
||||
storage2 = get_memory_storage()
|
||||
assert storage1 is storage2
|
||||
storage1 = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
storage2 = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert storage1 is storage2
|
||||
|
||||
def test_get_memory_storage_thread_safety(self):
|
||||
"""Should safely initialize the singleton even with concurrent calls."""
|
||||
@@ -260,16 +271,15 @@ class TestGetMemoryStorage:
|
||||
|
||||
def get_storage():
|
||||
# get_memory_storage is called concurrently from multiple threads while
|
||||
# get_memory_config is patched once around thread creation. This verifies
|
||||
# the same MemoryConfig is passed explicitly to every call. This verifies
|
||||
# that the singleton initialization remains thread-safe.
|
||||
results.append(get_memory_storage())
|
||||
results.append(get_memory_storage(_TEST_MEMORY_CONFIG))
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
threads = [threading.Thread(target=get_storage) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
threads = [threading.Thread(target=get_storage) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# All results should be the exact same instance
|
||||
assert len(results) == 10
|
||||
@@ -278,13 +288,11 @@ class TestGetMemoryStorage:
|
||||
def test_get_memory_storage_invalid_class_fallback(self):
|
||||
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
|
||||
# Using a built-in function instead of a class
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="os.path.join")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_get_memory_storage_non_subclass_fallback(self):
|
||||
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
|
||||
# Using 'dict' as a class that is not a MemoryStorage subclass
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="builtins.dict")):
|
||||
storage = get_memory_storage()
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
"""Tests for per-user memory storage isolation."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _mock_app_config() -> AppConfig:
    """Return a minimal AppConfig whose memory section uses an empty storage_path."""
    sandbox_section = SandboxConfig(use="test")
    memory_section = MemoryConfig(storage_path="")
    return AppConfig(sandbox=sandbox_section, memory=memory_section)
|
||||
|
||||
|
||||
@pytest.fixture
def base_dir(tmp_path: Path) -> Path:
    """Expose pytest's per-test ``tmp_path`` under the name ``base_dir``."""
    root = tmp_path
    return root
|
||||
|
||||
|
||||
@pytest.fixture
def storage() -> FileMemoryStorage:
    """Provide a FileMemoryStorage built from the module-level test MemoryConfig."""
    file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
    return file_storage
|
||||
|
||||
|
||||
|
||||
|
||||
class TestUserIsolatedStorage:
    """Tests that FileMemoryStorage keeps memory separate per ``user_id``.

    Every test redirects ``deerflow.agents.memory.storage.get_paths`` to a
    ``Paths`` object rooted at the temporary ``base_dir`` so files are written
    under the pytest temp directory.
    """

    @staticmethod
    def _patched_paths(root: Path):
        """Return a patcher making the storage module's ``get_paths`` yield ``Paths(root)``."""
        # Imported lazily, as in the individual tests this helper replaces.
        from deerflow.config.paths import Paths

        return patch("deerflow.agents.memory.storage.get_paths", return_value=Paths(root))

    @staticmethod
    def _memory_with_summary(text: str):
        """Create an empty memory whose user work-context summary is ``text``."""
        memory = create_empty_memory()
        memory["user"]["workContext"]["summary"] = text
        return memory

    def test_save_and_load_per_user(self, storage: FileMemoryStorage, base_dir: Path):
        """Memories saved for two users are loaded back unchanged for each user."""
        with self._patched_paths(base_dir):
            storage.save(self._memory_with_summary("User A context"), user_id="alice")
            storage.save(self._memory_with_summary("User B context"), user_id="bob")

            alice_memory = storage.load(user_id="alice")
            bob_memory = storage.load(user_id="bob")

            assert alice_memory["user"]["workContext"]["summary"] == "User A context"
            assert bob_memory["user"]["workContext"]["summary"] == "User B context"

    def test_user_memory_file_location(self, base_dir: Path):
        """Saving with a user_id writes ``users/<user_id>/memory.json`` under the base dir."""
        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            file_storage.save(create_empty_memory(), user_id="alice")

            user_file = base_dir / "users" / "alice" / "memory.json"
            assert user_file.exists()

    def test_cache_isolated_per_user(self, base_dir: Path):
        """Saving a second user's memory does not change what the first user loads."""
        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            file_storage.save(self._memory_with_summary("A"), user_id="alice")
            file_storage.save(self._memory_with_summary("B"), user_id="bob")

            alice_memory = file_storage.load(user_id="alice")
            assert alice_memory["user"]["workContext"]["summary"] == "A"

    def test_no_user_id_uses_legacy_path(self, base_dir: Path):
        """Saving with ``user_id=None`` writes ``memory.json`` directly under the base dir."""
        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            file_storage.save(create_empty_memory(), user_id=None)

            legacy_file = base_dir / "memory.json"
            assert legacy_file.exists()

    def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
        """user_id=None (legacy) and user_id='alice' must use different files and caches."""
        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)

            file_storage.save(self._memory_with_summary("legacy"), user_id=None)
            file_storage.save(self._memory_with_summary("alice"), user_id="alice")

            legacy_loaded = file_storage.load(user_id=None)
            assert legacy_loaded["user"]["workContext"]["summary"] == "legacy"
            alice_loaded = file_storage.load(user_id="alice")
            assert alice_loaded["user"]["workContext"]["summary"] == "alice"

    def test_user_agent_memory_file_location(self, base_dir: Path):
        """Saving with both an agent name and a user_id writes under ``users/<user>/agents/<agent>``."""
        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            file_storage.save(self._memory_with_summary("agent scoped"), "test-agent", user_id="alice")

            agent_file = base_dir / "users" / "alice" / "agents" / "test-agent" / "memory.json"
            assert agent_file.exists()

    def test_cache_key_is_user_agent_tuple(self, base_dir: Path):
        """Cache keys must be (user_id, agent_name) tuples, not bare agent names."""
        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            file_storage.save(create_empty_memory(), user_id="alice")

            # After save, the cache is expected to hold a tuple key.
            expected_key = ("alice", None)
            assert expected_key in file_storage._memory_cache

    def test_reload_with_user_id(self, base_dir: Path):
        """reload() with user_id should force a re-read from the user-scoped file."""
        import json

        with self._patched_paths(base_dir):
            file_storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            file_storage.save(self._memory_with_summary("initial"), user_id="alice")

            # Load once to prime the cache.
            file_storage.load(user_id="alice")

            # Overwrite the user's file directly, bypassing the storage object.
            user_file = base_dir / "users" / "alice" / "memory.json"
            replacement = self._memory_with_summary("updated")
            user_file.write_text(json.dumps(replacement))

            # reload() should pick up the new file content.
            fresh = file_storage.reload(user_id="alice")
            assert fresh["user"]["workContext"]["summary"] == "updated"
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Owner isolation tests for MemoryThreadMetaStore.
|
||||
|
||||
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
|
||||
the in-memory LangGraph Store backend used when database.backend=memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||
|
||||
|
||||
def _as_user(user):
    """Return a context manager that sets ``user`` as the current user.

    Entering calls ``set_current_user(user)`` and yields ``user``; leaving
    hands the value returned by that call to ``reset_current_user``.
    Exceptions raised inside the ``with`` body are not suppressed.
    """

    class _UserScope:
        """Single-use scope holding the token returned by ``set_current_user``."""

        def __enter__(self):
            token = set_current_user(user)
            self._token = token
            return user

        def __exit__(self, *exc_details):
            reset_current_user(self._token)

    scope = _UserScope()
    return scope
|
||||
|
||||
|
||||
@pytest.fixture
def store():
    """Provide a MemoryThreadMetaStore wrapping a newly created InMemoryStore."""
    backing_store = InMemoryStore()
    return MemoryThreadMetaStore(backing_store)
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_search_isolation(store):
    """search() returns only threads owned by the current user."""
    ownership = (
        (USER_A, "t-alpha", "A's thread"),
        (USER_B, "t-beta", "B's thread"),
    )

    # Each user creates one thread under their own context.
    for owner, thread_id, title in ownership:
        with _as_user(owner):
            await store.create(thread_id, display_name=title)

    # Each user's search must then list exactly their own thread.
    for owner, thread_id, _title in ownership:
        with _as_user(owner):
            found = await store.search()
            found_ids = [row["thread_id"] for row in found]
            assert found_ids == [thread_id]
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_get_isolation(store):
    """get() returns None for threads owned by another user."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="A's thread")

    with _as_user(USER_B):
        foreign_view = await store.get("t-alpha")
        assert foreign_view is None

    with _as_user(USER_A):
        owner_view = await store.get("t-alpha")
        assert owner_view is not None
        assert owner_view["display_name"] == "A's thread"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_display_name_denied(store):
    """User B cannot rename User A's thread."""
    thread_id = "t-alpha"

    with _as_user(USER_A):
        await store.create(thread_id, display_name="original")

    # The rename attempt is not wrapped in pytest.raises: it is expected to
    # return normally, and only the stored value is checked afterwards.
    with _as_user(USER_B):
        await store.update_display_name(thread_id, "hacked")

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
        assert owner_view["display_name"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_status_denied(store):
    """User B cannot change status of User A's thread."""
    thread_id = "t-alpha"

    with _as_user(USER_A):
        await store.create(thread_id)

    # The status change is expected to return normally; the test only checks
    # that the owner still sees the status "idle" afterwards.
    with _as_user(USER_B):
        await store.update_status(thread_id, "error")

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
        assert owner_view["status"] == "idle"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_metadata_denied(store):
    """User B cannot modify metadata of User A's thread."""
    thread_id = "t-alpha"

    with _as_user(USER_A):
        await store.create(thread_id, metadata={"key": "original"})

    # The update attempt is expected to return normally; only the stored
    # metadata is checked afterwards.
    with _as_user(USER_B):
        await store.update_metadata(thread_id, {"key": "hacked"})

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
        stored_metadata = owner_view["metadata"]
        assert stored_metadata["key"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_delete_denied(store):
    """User B cannot delete User A's thread."""
    thread_id = "t-alpha"

    with _as_user(USER_A):
        await store.create(thread_id)

    # The delete attempt is expected to return normally; the thread must
    # still be visible to its owner afterwards.
    with _as_user(USER_B):
        await store.delete(thread_id)

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_no_context_raises(store):
    """search() without a user context raises RuntimeError."""
    expected_message = "no user context is set"
    with pytest.raises(RuntimeError, match=expected_message):
        await store.search()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(store):
    """user_id=None bypasses isolation (migration/CLI escape hatch).

    The search() and get() calls below run outside any ``_as_user`` block and
    pass ``user_id=None`` explicitly; they are expected to see both users'
    threads.
    """
    for owner, thread_id in ((USER_A, "t-alpha"), (USER_B, "t-beta")):
        with _as_user(owner):
            await store.create(thread_id)

    every_row = await store.search(user_id=None)
    seen_ids = {entry["thread_id"] for entry in every_row}
    assert seen_ids == {"t-alpha", "t-beta"}

    alpha_row = await store.get("t-alpha", user_id=None)
    assert alpha_row is not None
|
||||
@@ -1,22 +1,32 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||
from deerflow.agents.memory.updater import (
|
||||
MemoryUpdater,
|
||||
_extract_text,
|
||||
_run_async_update_sync,
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
import_memory_data,
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
|
||||
# --- Phase 2 config-refactor test helper ---
|
||||
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
|
||||
# minimal config once and reuse it across call sites.
|
||||
from deerflow.config.app_config import AppConfig as _TestAppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
|
||||
|
||||
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
|
||||
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
|
||||
# -------------------------------------------
|
||||
|
||||
def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
|
||||
return {
|
||||
"version": "1.0",
|
||||
@@ -35,15 +45,12 @@ def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, obje
|
||||
}
|
||||
|
||||
|
||||
def _memory_config(**overrides: object) -> MemoryConfig:
|
||||
config = MemoryConfig()
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
return config
|
||||
def _memory_config(**overrides: object) -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig().model_copy(update=overrides))
|
||||
|
||||
|
||||
def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -70,19 +77,14 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
|
||||
{"content": "User likes Python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert [fact["content"] for fact in result["facts"]] == ["User likes Python"]
|
||||
assert all(fact["id"] != "fact_remove" for fact in result["facts"])
|
||||
|
||||
|
||||
def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
@@ -91,12 +93,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
|
||||
{"content": "User works on DeerFlow", "category": "context", "confidence": 0.87},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
|
||||
|
||||
assert [fact["content"] for fact in result["facts"]] == [
|
||||
"User prefers dark mode",
|
||||
@@ -107,7 +104,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
|
||||
|
||||
|
||||
def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=2, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -135,12 +132,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
{"content": "User likes noisy logs", "category": "behavior", "confidence": 0.6},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
|
||||
|
||||
assert [fact["content"] for fact in result["facts"]] == [
|
||||
"User likes Python",
|
||||
@@ -151,7 +143,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
|
||||
|
||||
def test_apply_updates_preserves_source_error() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
@@ -163,19 +155,14 @@ def test_apply_updates_preserves_source_error() -> None:
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
|
||||
assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
|
||||
assert result["facts"][0]["category"] == "correction"
|
||||
|
||||
|
||||
def test_apply_updates_ignores_empty_source_error() -> None:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
@@ -187,19 +174,14 @@ def test_apply_updates_ignores_empty_source_error() -> None:
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
|
||||
assert "sourceError" not in result["facts"][0]
|
||||
|
||||
|
||||
def test_clear_memory_data_resets_all_sections() -> None:
|
||||
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
|
||||
result = clear_memory_data()
|
||||
result = clear_memory_data(_TEST_MEMORY_CONFIG)
|
||||
|
||||
assert result["version"] == "1.0"
|
||||
assert result["facts"] == []
|
||||
@@ -233,7 +215,7 @@ def test_delete_memory_fact_removes_only_matching_fact() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = delete_memory_fact("fact_delete")
|
||||
result = delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_delete")
|
||||
|
||||
assert [fact["id"] for fact in result["facts"]] == ["fact_keep"]
|
||||
|
||||
@@ -243,7 +225,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = create_memory_fact(
|
||||
result = create_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
content=" User prefers concise code reviews. ",
|
||||
category="preference",
|
||||
confidence=0.88,
|
||||
@@ -258,7 +240,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
|
||||
|
||||
def test_create_memory_fact_rejects_empty_content() -> None:
|
||||
try:
|
||||
create_memory_fact(content=" ")
|
||||
create_memory_fact(_TEST_MEMORY_CONFIG, content=" ")
|
||||
except ValueError as exc:
|
||||
assert exc.args == ("content",)
|
||||
else:
|
||||
@@ -268,7 +250,7 @@ def test_create_memory_fact_rejects_empty_content() -> None:
|
||||
def test_create_memory_fact_rejects_invalid_confidence() -> None:
|
||||
for confidence in (-0.1, 1.1, float("nan"), float("inf"), float("-inf")):
|
||||
try:
|
||||
create_memory_fact(content="User likes tests", confidence=confidence)
|
||||
create_memory_fact(_TEST_MEMORY_CONFIG, content="User likes tests", confidence=confidence)
|
||||
except ValueError as exc:
|
||||
assert exc.args == ("confidence",)
|
||||
else:
|
||||
@@ -278,7 +260,7 @@ def test_create_memory_fact_rejects_invalid_confidence() -> None:
|
||||
def test_delete_memory_fact_raises_for_unknown_id() -> None:
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
|
||||
try:
|
||||
delete_memory_fact("fact_missing")
|
||||
delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_missing")
|
||||
except KeyError as exc:
|
||||
assert exc.args == ("fact_missing",)
|
||||
else:
|
||||
@@ -303,10 +285,10 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
||||
mock_storage.load.return_value = imported_memory
|
||||
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
result = import_memory_data(imported_memory)
|
||||
result = import_memory_data(_TEST_MEMORY_CONFIG, imported_memory)
|
||||
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None)
|
||||
mock_storage.load.assert_called_once_with(None)
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
||||
mock_storage.load.assert_called_once_with(None, user_id=None)
|
||||
assert result == imported_memory
|
||||
|
||||
|
||||
@@ -336,7 +318,7 @@ def test_update_memory_fact_updates_only_matching_fact() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = update_memory_fact(
|
||||
result = update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
category="workflow",
|
||||
@@ -369,7 +351,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
|
||||
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||
):
|
||||
result = update_memory_fact(
|
||||
result = update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
)
|
||||
@@ -382,7 +364,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
|
||||
def test_update_memory_fact_raises_for_unknown_id() -> None:
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
|
||||
try:
|
||||
update_memory_fact(
|
||||
update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_missing",
|
||||
content="User prefers concise code reviews.",
|
||||
category="preference",
|
||||
@@ -414,7 +396,7 @@ def test_update_memory_fact_rejects_invalid_confidence() -> None:
|
||||
return_value=current_memory,
|
||||
):
|
||||
try:
|
||||
update_memory_fact(
|
||||
update_memory_fact(_TEST_MEMORY_CONFIG,
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
confidence=confidence,
|
||||
@@ -527,17 +509,15 @@ class TestUpdateMemoryStructuredResponse:
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = content
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_string_response_parses(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -551,17 +531,15 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_list_content_response_parses(self):
|
||||
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
list_content = [{"type": "text", "text": valid_json}]
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -576,38 +554,13 @@ class TestUpdateMemoryStructuredResponse:
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_async_update_memory_uses_ainvoke(self):
    """Check that ``aupdate_memory`` awaits ``ainvoke`` once with the memory_agent run name."""
    updater = MemoryUpdater()
    valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
    model = self._make_mock_model(valid_json)

    # Conversation fixture: one human turn and one AI turn without tool calls.
    human_turn = MagicMock()
    human_turn.type = "human"
    human_turn.content = "Hello"
    ai_turn = MagicMock()
    ai_turn.type = "ai"
    ai_turn.content = "Hi there"
    ai_turn.tool_calls = []

    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
    ):
        outcome = asyncio.run(updater.aupdate_memory([human_turn, ai_turn]))

    assert outcome is True
    model.ainvoke.assert_awaited_once()
    awaited_kwargs = model.ainvoke.await_args.kwargs
    assert awaited_kwargs["config"] == {"run_name": "memory_agent"}
|
||||
|
||||
def test_correction_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -622,17 +575,16 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
|
||||
def test_correction_hint_empty_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -647,95 +599,15 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
def test_sync_update_memory_wrapper_works_in_running_loop(self):
    """Call the sync ``update_memory`` from inside a running event loop and expect success."""
    updater = MemoryUpdater()
    valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
    model = self._make_mock_model(valid_json)

    human_turn = MagicMock()
    human_turn.type = "human"
    human_turn.content = "Hello from loop"
    ai_turn = MagicMock()
    ai_turn.type = "ai"
    ai_turn.content = "Hi"
    ai_turn.tool_calls = []

    async def invoke_sync_api_inside_loop():
        # Deliberately not awaited: the synchronous API is invoked while a loop is running.
        return updater.update_memory([human_turn, ai_turn])

    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
    ):
        outcome = asyncio.run(invoke_sync_api_inside_loop())

    assert outcome is True
    model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
    """``update_memory`` reports ``False`` when the executor's ``submit`` raises."""
    updater = MemoryUpdater()

    human_turn = MagicMock()
    human_turn.type = "human"
    human_turn.content = "Hello from loop"
    ai_turn = MagicMock()
    ai_turn.type = "ai"
    ai_turn.content = "Hi"
    ai_turn.tool_calls = []

    async def invoke_sync_api_inside_loop():
        return updater.update_memory([human_turn, ai_turn])

    # Make the executor hand-off fail for the duration of the call.
    with patch(
        "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
        side_effect=RuntimeError("executor down"),
    ):
        outcome = asyncio.run(invoke_sync_api_inside_loop())

    assert outcome is False
|
||||
|
||||
|
||||
class TestRunAsyncUpdateSync:
    """Behaviour of ``_run_async_update_sync`` when the executor hand-off fails."""

    def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
        """The awaitable is closed, never awaited, and the call returns ``False``."""

        class CloseableAwaitable:
            """Awaitable stand-in that records ``close()`` and fails the test if awaited."""

            def __init__(self):
                self.closed = False

            def __await__(self):
                pytest.fail("awaitable should not have been awaited")
                yield

            def close(self):
                self.closed = True

        pending = CloseableAwaitable()

        async def call_from_running_loop():
            return _run_async_update_sync(pending)

        with patch(
            "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
            side_effect=RuntimeError("executor down"),
        ):
            outcome = asyncio.run(call_from_running_loop())

        assert outcome is False
        assert pending.closed is True
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
"""Tests that fact deduplication is case-insensitive."""
|
||||
|
||||
def test_duplicate_fact_different_case_not_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -755,19 +627,14 @@ class TestFactDeduplicationCaseInsensitive:
|
||||
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
# Should still have only 1 fact (duplicate rejected)
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["content"] == "User prefers Python"
|
||||
|
||||
def test_unique_fact_different_case_and_content_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
@@ -786,12 +653,7 @@ class TestFactDeduplicationCaseInsensitive:
|
||||
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert len(result["facts"]) == 2
|
||||
|
||||
@@ -804,17 +666,16 @@ class TestReinforcementHint:
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -829,17 +690,16 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -854,17 +714,16 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(_TEST_APP_CONFIG)
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
@@ -879,56 +738,6 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
|
||||
class TestFinalizeCacheIsolation:
    """The memory dict handed out by ``get_memory_data`` must survive an update untouched."""

    def test_deepcopy_prevents_cache_corruption_on_save_failure(self):
        """Run an update whose save always fails and check the source dict is unchanged.

        The model is mocked to propose one new fact and the storage ``save``
        always returns ``False``. Afterwards the dict that was returned by the
        patched ``get_memory_data`` must still hold only its original fact.
        """
        updater = MemoryUpdater()
        original_memory = _make_memory(facts=[{"id": "fact_orig", "content": "original", "category": "context", "confidence": 0.9, "createdAt": "2024-01-01T00:00:00Z", "source": "t1"}])

        import json as _json

        llm_payload = {
            "user": {},
            "history": {},
            "newFacts": [{"content": "new fact", "category": "context", "confidence": 0.9}],
            "factsToRemove": [],
        }
        mock_response = MagicMock()
        mock_response.content = _json.dumps(llm_payload)
        mock_model = AsyncMock()
        mock_model.ainvoke = AsyncMock(return_value=mock_response)

        saved_objects: list[dict] = []

        def failing_save(m, a=None):
            # Record what the updater tried to persist, then report failure.
            saved_objects.append(m)
            return False

        save_mock = MagicMock(side_effect=failing_save)

        human_turn = MagicMock()
        human_turn.type = "human"
        human_turn.content = "hello"
        ai_turn = MagicMock()
        ai_turn.type = "ai"
        ai_turn.content = "world"
        ai_turn.tool_calls = []

        with (
            patch.object(updater, "_get_model", return_value=mock_model),
            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7)),
            patch("deerflow.agents.memory.updater.get_memory_data", return_value=original_memory),
            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=save_mock)),
        ):
            updater.update_memory([human_turn, ai_turn], thread_id="t1")

        # The dict returned by get_memory_data must not have gained the proposed fact.
        assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
        assert original_memory["facts"][0]["content"] == "original"
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
|
||||
"""Tests for user_id propagation in memory updater."""

from unittest.mock import MagicMock, patch

from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file

# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
# The module docstring and all imports sit above this block so the docstring
# is the first statement of the file and no import follows executable code.
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
|
||||
|
||||
|
||||
def test_get_memory_data_passes_user_id():
    """``get_memory_data`` must call ``storage.load`` with the given ``user_id``."""
    storage = MagicMock()
    storage.load.return_value = {"version": "1.0"}
    storage_factory = patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage)

    with storage_factory:
        get_memory_data(_TEST_MEMORY_CONFIG, user_id="alice")
        storage.load.assert_called_once_with(None, user_id="alice")
|
||||
|
||||
|
||||
def test_save_memory_passes_user_id():
    """``_save_memory_to_file`` must call ``storage.save`` with the given ``user_id``."""
    storage = MagicMock()
    storage.save.return_value = True
    storage_factory = patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage)

    with storage_factory:
        _save_memory_to_file(_TEST_MEMORY_CONFIG, {"version": "1.0"}, user_id="bob")
        storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
||||
|
||||
|
||||
def test_clear_memory_data_passes_user_id():
    """``clear_memory_data`` must pass ``user_id`` as a keyword to ``storage.save``."""
    storage = MagicMock()
    storage.save.return_value = True
    storage_factory = patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage)

    with storage_factory:
        clear_memory_data(_TEST_MEMORY_CONFIG, user_id="charlie")
        # Only the user_id keyword of the save call is checked.
        save_kwargs = storage.save.call_args.kwargs
        assert save_kwargs["user_id"] == "charlie"
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Tests for per-user data migration."""
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@pytest.fixture
def base_dir(tmp_path: Path) -> Path:
    """Expose pytest's ``tmp_path`` under the name ``base_dir``."""
    root = tmp_path
    return root
|
||||
|
||||
|
||||
@pytest.fixture
def paths(base_dir: Path) -> Paths:
    """Build a ``Paths`` object from the ``base_dir`` fixture."""
    layout = Paths(base_dir)
    return layout
|
||||
|
||||
|
||||
class TestMigrateThreadDirs:
    """Tests for ``migrate_thread_dirs`` run against a temporary base directory."""

    def test_moves_thread_to_user_dir(self, base_dir: Path, paths: Paths):
        """A thread with a mapped owner is moved under ``users/<owner>/threads``."""
        workspace = base_dir.joinpath("threads", "t1", "user-data", "workspace")
        workspace.mkdir(parents=True)
        workspace.joinpath("file.txt").write_text("hello")

        from scripts.migrate_user_isolation import migrate_thread_dirs

        migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})

        moved_file = base_dir.joinpath("users", "alice", "threads", "t1", "user-data", "workspace", "file.txt")
        assert moved_file.exists()
        assert moved_file.read_text() == "hello"
        assert not base_dir.joinpath("threads", "t1").exists()

    def test_unowned_thread_goes_to_default(self, base_dir: Path, paths: Paths):
        """A thread missing from the owner map lands under ``users/default``."""
        workspace = base_dir.joinpath("threads", "t2", "user-data", "workspace")
        workspace.mkdir(parents=True)

        from scripts.migrate_user_isolation import migrate_thread_dirs

        migrate_thread_dirs(paths, thread_owner_map={})

        fallback_home = base_dir.joinpath("users", "default", "threads", "t2")
        assert fallback_home.exists()

    def test_idempotent_skip_already_migrated(self, base_dir: Path, paths: Paths):
        """With only the per-user directory present, it still exists after the call."""
        migrated = base_dir.joinpath("users", "alice", "threads", "t1", "user-data", "workspace")
        migrated.mkdir(parents=True)

        from scripts.migrate_user_isolation import migrate_thread_dirs

        migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
        assert migrated.exists()

    def test_conflict_preserved(self, base_dir: Path, paths: Paths):
        """When both locations exist, the destination is kept and a conflicts dir appears."""
        legacy_ws = base_dir.joinpath("threads", "t1", "user-data", "workspace")
        legacy_ws.mkdir(parents=True)
        legacy_ws.joinpath("old.txt").write_text("old")

        dest_ws = base_dir.joinpath("users", "alice", "threads", "t1", "user-data", "workspace")
        dest_ws.mkdir(parents=True)
        dest_ws.joinpath("new.txt").write_text("new")

        from scripts.migrate_user_isolation import migrate_thread_dirs

        migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})

        assert dest_ws.joinpath("new.txt").read_text() == "new"
        conflict_dir = base_dir.joinpath("migration-conflicts", "t1")
        assert conflict_dir.exists()

    def test_cleans_up_empty_legacy_dir(self, base_dir: Path, paths: Paths):
        """After migration the top-level ``threads`` directory no longer exists."""
        legacy_data = base_dir.joinpath("threads", "t1", "user-data")
        legacy_data.mkdir(parents=True)

        from scripts.migrate_user_isolation import migrate_thread_dirs

        migrate_thread_dirs(paths, thread_owner_map={})

        assert not base_dir.joinpath("threads").exists()

    def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
        """``dry_run=True`` returns a one-entry report and leaves directories in place."""
        legacy_data = base_dir.joinpath("threads", "t1", "user-data")
        legacy_data.mkdir(parents=True)

        from scripts.migrate_user_isolation import migrate_thread_dirs

        report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)

        assert len(report) == 1
        assert base_dir.joinpath("threads", "t1").exists()  # not moved
        assert not base_dir.joinpath("users", "alice", "threads", "t1").exists()
|
||||
|
||||
|
||||
class TestMigrateMemory:
    """Tests for ``migrate_memory`` run against a temporary base directory."""

    def test_moves_global_memory(self, base_dir: Path, paths: Paths):
        """The top-level ``memory.json`` is relocated under ``users/default``."""
        legacy_file = base_dir.joinpath("memory.json")
        payload = json.dumps({"version": "1.0", "facts": []})
        legacy_file.write_text(payload)

        from scripts.migrate_user_isolation import migrate_memory

        migrate_memory(paths, user_id="default")

        per_user_file = base_dir.joinpath("users", "default", "memory.json")
        assert per_user_file.exists()
        assert not legacy_file.exists()

    def test_skips_if_destination_exists(self, base_dir: Path, paths: Paths):
        """An existing per-user file is kept and ``memory.legacy.json`` appears."""
        legacy_file = base_dir.joinpath("memory.json")
        legacy_file.write_text(json.dumps({"version": "old"}))

        per_user_file = base_dir.joinpath("users", "default", "memory.json")
        per_user_file.parent.mkdir(parents=True)
        per_user_file.write_text(json.dumps({"version": "new"}))

        from scripts.migrate_user_isolation import migrate_memory

        migrate_memory(paths, user_id="default")

        kept = json.loads(per_user_file.read_text())
        assert kept["version"] == "new"
        assert base_dir.joinpath("memory.legacy.json").exists()

    def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
        """Calling ``migrate_memory`` with no legacy file present must not raise."""
        from scripts.migrate_user_isolation import migrate_memory

        migrate_memory(paths, user_id="default")  # should not raise
|
||||
@@ -72,8 +72,7 @@ class FakeChatModel(BaseChatModel):
|
||||
|
||||
|
||||
def _patch_factory(monkeypatch, app_config: AppConfig, model_class=FakeChatModel):
    """Patch resolve_class and tracing for isolated unit tests.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        app_config: Kept so existing ``_patch_factory(monkeypatch, cfg)`` calls
            still work; it is not used here because the tests pass the config
            straight to ``create_chat_model(..., app_config=cfg)``.
        model_class: Class returned by the patched ``resolve_class``.
    """
    # No get_app_config patch: the config is an explicit argument now, and
    # monkeypatch.setattr raises by default if the target attribute is missing.
    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: model_class)
    monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
|
||||
|
||||
@@ -88,7 +87,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name=None)
|
||||
factory_module.create_chat_model(name=None, app_config=cfg)
|
||||
|
||||
# resolve_class is called — if we reach here without ValueError, the correct model was used
|
||||
assert FakeChatModel.captured_kwargs.get("model") == "alpha"
|
||||
@@ -96,11 +95,10 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
|
||||
|
||||
def test_raises_when_model_not_found(monkeypatch):
    """Requesting a model name absent from the config raises ``ValueError`` naming it."""
    cfg = _make_app_config([_make_model("only-model")])
    monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])

    # Exactly one call inside the context manager: a second statement after
    # the raising call would never execute.
    with pytest.raises(ValueError, match="ghost-model"):
        factory_module.create_chat_model(name="ghost-model", app_config=cfg)
|
||||
|
||||
|
||||
def test_appends_all_tracing_callbacks(monkeypatch):
|
||||
@@ -109,7 +107,7 @@ def test_appends_all_tracing_callbacks(monkeypatch):
|
||||
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: ["smith-callback", "langfuse-callback"])
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
model = factory_module.create_chat_model(name="alpha")
|
||||
model = factory_module.create_chat_model(name="alpha", app_config=cfg)
|
||||
|
||||
assert model.callbacks == ["smith-callback", "langfuse-callback"]
|
||||
|
||||
@@ -127,7 +125,7 @@ def test_thinking_enabled_raises_when_not_supported_but_when_thinking_enabled_is
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
with pytest.raises(ValueError, match="does not support thinking"):
|
||||
factory_module.create_chat_model(name="no-think", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="no-think", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
|
||||
def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(monkeypatch):
|
||||
@@ -138,7 +136,7 @@ def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
with pytest.raises(ValueError, match="does not support thinking"):
|
||||
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
|
||||
def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
|
||||
@@ -147,7 +145,7 @@ def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="thinker", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="thinker", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("temperature") == 1.0
|
||||
assert FakeChatModel.captured_kwargs.get("max_tokens") == 16000
|
||||
@@ -183,7 +181,7 @@ def test_thinking_disabled_openai_gateway_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
assert captured.get("reasoning_effort") == "minimal"
|
||||
@@ -216,7 +214,7 @@ def test_thinking_disabled_langchain_anthropic_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
assert "extra_body" not in captured
|
||||
@@ -238,7 +236,7 @@ def test_thinking_disabled_no_when_thinking_enabled_does_nothing(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="plain", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="plain", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert "extra_body" not in captured
|
||||
assert "thinking" not in captured
|
||||
@@ -278,7 +276,7 @@ def test_when_thinking_disabled_takes_precedence_over_hardcoded_disable(monkeypa
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
# User overrode the hardcoded "minimal" with "low"
|
||||
@@ -310,7 +308,7 @@ def test_when_thinking_disabled_not_used_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
# when_thinking_enabled should apply, NOT when_thinking_disabled
|
||||
assert captured.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
@@ -339,7 +337,7 @@ def test_when_thinking_disabled_without_when_thinking_enabled_still_applies(monk
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# when_thinking_disabled is now gated independently of has_thinking_settings
|
||||
assert captured.get("reasoning_effort") == "low"
|
||||
@@ -370,7 +368,7 @@ def test_when_thinking_disabled_excluded_from_model_dump(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
# when_thinking_disabled value must NOT appear as a raw key
|
||||
assert "when_thinking_disabled" not in captured
|
||||
@@ -394,7 +392,7 @@ def test_reasoning_effort_cleared_when_not_supported(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-effort", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="no-effort", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("reasoning_effort") is None
|
||||
|
||||
@@ -422,7 +420,7 @@ def test_reasoning_effort_preserved_when_supported(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="effort-model", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="effort-model", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# When supports_reasoning_effort=True, it should NOT be cleared to None
|
||||
# The disable path sets it to "minimal"; supports_reasoning_effort=True keeps it
|
||||
@@ -458,7 +456,7 @@ def test_thinking_shortcut_enables_thinking_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert captured.get("thinking") == thinking_settings
|
||||
|
||||
@@ -488,7 +486,7 @@ def test_thinking_shortcut_disables_thinking_when_thinking_disabled(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
assert "extra_body" not in captured
|
||||
@@ -520,7 +518,7 @@ def test_thinking_shortcut_merges_with_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="merge-model", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="merge-model", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
# Both the thinking shortcut and when_thinking_enabled settings should be applied
|
||||
assert captured.get("thinking") == thinking_settings
|
||||
@@ -552,7 +550,7 @@ def test_thinking_shortcut_not_leaked_into_model_when_disabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-leak", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="no-leak", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# The disable path should have set thinking to disabled (not the raw enabled shortcut)
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
@@ -590,7 +588,7 @@ def test_openai_compatible_provider_passes_base_url(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="minimax-m2.5")
|
||||
factory_module.create_chat_model(name="minimax-m2.5", app_config=cfg)
|
||||
|
||||
assert captured.get("model") == "MiniMax-M2.5"
|
||||
assert captured.get("base_url") == "https://api.minimax.io/v1"
|
||||
@@ -731,11 +729,11 @@ def test_openai_compatible_provider_multiple_models(monkeypatch):
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
# Create first model
|
||||
factory_module.create_chat_model(name="minimax-m2.5")
|
||||
factory_module.create_chat_model(name="minimax-m2.5", app_config=cfg)
|
||||
assert captured.get("model") == "MiniMax-M2.5"
|
||||
|
||||
# Create second model
|
||||
factory_module.create_chat_model(name="minimax-m2.5-highspeed")
|
||||
factory_module.create_chat_model(name="minimax-m2.5-highspeed", app_config=cfg)
|
||||
assert captured.get("model") == "MiniMax-M2.5-highspeed"
|
||||
|
||||
|
||||
@@ -763,7 +761,7 @@ def test_codex_provider_disables_reasoning_when_thinking_disabled(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "none"
|
||||
|
||||
@@ -783,7 +781,7 @@ def test_codex_provider_preserves_explicit_reasoning_effort(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high")
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high", app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "high"
|
||||
|
||||
@@ -803,7 +801,7 @@ def test_codex_provider_defaults_reasoning_effort_to_medium(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "medium"
|
||||
|
||||
@@ -824,7 +822,7 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=cfg)
|
||||
|
||||
assert "max_tokens" not in FakeChatModel.captured_kwargs
|
||||
|
||||
@@ -837,7 +835,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled=wte,
|
||||
)
|
||||
model.extra_body = {"top_k": 20}
|
||||
model = model.model_copy(update={"extra_body": {"top_k": 20}})
|
||||
cfg = _make_app_config([model])
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
@@ -850,7 +848,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
|
||||
assert captured.get("reasoning_effort") is None
|
||||
@@ -864,7 +862,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled=wte,
|
||||
)
|
||||
model.extra_body = {"top_k": 20}
|
||||
model = model.model_copy(update={"extra_body": {"top_k": 20}})
|
||||
cfg = _make_app_config([model])
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
@@ -877,7 +875,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
assert captured.get("extra_body") == {
|
||||
"top_k": 20,
|
||||
@@ -886,6 +884,85 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
assert captured.get("reasoning_effort") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stream_usage injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeWithStreamUsage(FakeChatModel):
    """Fake model that declares stream_usage in model_fields (like BaseChatOpenAI)."""

    # The stream_usage tests in this file use this class to stand in for a
    # model that declares the field; None is the "nothing configured" default.
    stream_usage: bool | None = None
|
||||
|
||||
|
||||
def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
    """A model class that declares ``stream_usage`` receives ``stream_usage=True``."""
    deepseek = _make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")
    cfg = _make_app_config([deepseek])
    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)

    # Every keyword argument the factory hands to the constructor lands here.
    seen_kwargs: dict = {}

    class CapturingModel(_FakeWithStreamUsage):
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    def _resolve(path, base):
        # Ignore the requested class path; always hand back the capturing fake.
        return CapturingModel

    monkeypatch.setattr(factory_module, "resolve_class", _resolve)

    factory_module.create_chat_model(name="deepseek", app_config=cfg)

    assert seen_kwargs.get("stream_usage") is True
|
||||
|
||||
|
||||
def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
    """A model class without a ``stream_usage`` field never receives that kwarg."""
    claude = _make_model("claude", use="langchain_anthropic:ChatAnthropic")
    cfg = _make_app_config([claude])
    _patch_factory(monkeypatch, cfg)

    # Every keyword argument the factory hands to the constructor lands here.
    seen_kwargs: dict = {}

    class CapturingModel(FakeChatModel):
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    def _resolve(path, base):
        # Ignore the requested class path; always hand back the capturing fake.
        return CapturingModel

    monkeypatch.setattr(factory_module, "resolve_class", _resolve)

    factory_module.create_chat_model(name="claude", app_config=cfg)

    assert "stream_usage" not in seen_kwargs
|
||||
|
||||
|
||||
def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
    """An explicit ``stream_usage=False`` in the model config must reach the model unchanged."""
    # stream_usage is passed as an extra field on ModelConfig (extra="allow").
    model_fields = {
        "name": "deepseek",
        "display_name": "deepseek",
        "description": None,
        "use": "langchain_deepseek:ChatDeepSeek",
        "model": "deepseek",
        "supports_thinking": False,
        "supports_vision": False,
        "stream_usage": False,
    }
    explicit_model = ModelConfig(**model_fields)
    cfg = _make_app_config([explicit_model])
    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)

    # Every keyword argument the factory hands to the constructor lands here.
    seen_kwargs: dict = {}

    class CapturingModel(_FakeWithStreamUsage):
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    def _resolve(path, base):
        # Ignore the requested class path; always hand back the capturing fake.
        return CapturingModel

    monkeypatch.setattr(factory_module, "resolve_class", _resolve)

    factory_module.create_chat_model(name="deepseek", app_config=cfg)

    assert seen_kwargs.get("stream_usage") is False
|
||||
|
||||
|
||||
def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
|
||||
model = ModelConfig(
|
||||
name="gpt-5-responses",
|
||||
@@ -911,7 +988,7 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="gpt-5-responses")
|
||||
factory_module.create_chat_model(name="gpt-5-responses", app_config=cfg)
|
||||
|
||||
assert captured.get("use_responses_api") is True
|
||||
assert captured.get("output_version") == "responses/v1"
|
||||
@@ -952,7 +1029,7 @@ def test_no_duplicate_kwarg_when_reasoning_effort_in_config_and_thinking_disable
|
||||
_patch_factory(monkeypatch, cfg, model_class=CapturingModel)
|
||||
|
||||
# Must not raise TypeError
|
||||
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False, app_config=cfg)
|
||||
|
||||
# kwargs (runtime) takes precedence: thinking-disabled path sets reasoning_effort=minimal
|
||||
assert captured.get("reasoning_effort") == "minimal"
|
||||
|
||||
@@ -0,0 +1,465 @@
|
||||
"""Cross-user isolation tests — non-negotiable safety gate.
|
||||
|
||||
Mirrors TC-API-17..20 from backend/docs/AUTH_TEST_PLAN.md. A failure
|
||||
here means users can see each other's data; PR must not merge.
|
||||
|
||||
Architecture note
|
||||
-----------------
|
||||
These tests bypass the HTTP layer and exercise the storage-layer
|
||||
owner filter directly by switching the ``user_context`` contextvar
|
||||
between two users. The safety property under test is:
|
||||
|
||||
After a repository write with user_id=A, a subsequent read with
|
||||
user_id=B must not return the row, and vice versa.
|
||||
|
||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||
that a request cookie reaches the ``set_current_user`` call. Together
|
||||
the two suites prove the full chain:
|
||||
|
||||
cookie → middleware → contextvar → repository → isolation
|
||||
|
||||
Every test in this file opts out of the autouse contextvar fixture
|
||||
(``@pytest.mark.no_auto_user``) so it can set the contextvar to the
|
||||
specific users it cares about.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
reset_current_user,
|
||||
set_current_user,
|
||||
)
|
||||
|
||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||
|
||||
|
||||
async def _make_engines(tmp_path):
    """Initialise the shared engine on a SQLite file inside *tmp_path*.

    The database file is ``isolation.db`` under *tmp_path*, and *tmp_path* is
    also passed as ``sqlite_dir``. Returns ``close_engine`` itself (not yet
    called); callers in this file await it in a ``finally`` block.
    """
    from deerflow.persistence.engine import close_engine, init_engine

    db_file = tmp_path / "isolation.db"
    db_url = "sqlite+aiosqlite:///" + str(db_file)
    await init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path))
    return close_engine
|
||||
|
||||
|
||||
def _as_user(user):
    """Return a ``with``-compatible object that makes *user* the current user.

    On entry it calls ``set_current_user(user)`` and keeps the returned token;
    on exit it passes that token to ``reset_current_user``. Exceptions raised
    inside the ``with`` body are not suppressed.
    """

    class _UserScope:
        def __enter__(self):
            token = set_current_user(user)
            self._token = token
            return user

        def __exit__(self, exc_type, exc_value, traceback):
            reset_current_user(self._token)
            return None

    scope = _UserScope()
    return scope
|
||||
|
||||
|
||||
# ── TC-API-17 — threads_meta isolation ────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_isolation(tmp_path):
    """Each user can read and search only the thread rows they created."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    close = await _make_engines(tmp_path)
    try:
        threads = ThreadMetaRepository(get_session_factory())

        # Seed one thread per user, A first and then B.
        seeds = (
            (USER_A, "t-alpha", "A's private thread"),
            (USER_B, "t-beta", "B's private thread"),
        )
        for owner, thread_id, title in seeds:
            with _as_user(owner):
                await threads.create(thread_id, display_name=title)

        # From A's side: own thread visible, B's thread invisible.
        with _as_user(USER_A):
            own = await threads.get("t-alpha")
            assert own is not None
            assert own["display_name"] == "A's private thread"

            # CRITICAL: a direct lookup of B's thread id must come back empty.
            leaked = await threads.get("t-beta")
            assert leaked is None, f"User A leaked User B's thread: {leaked}"

            # Unfiltered search is still scoped to A.
            found = await threads.search()
            assert [row["thread_id"] for row in found] == ["t-alpha"]

        # From B's side: the mirror image.
        with _as_user(USER_B):
            own = await threads.get("t-beta")
            assert own is not None
            assert own["display_name"] == "B's private thread"

            leaked = await threads.get("t-alpha")
            assert leaked is None, f"User B leaked User A's thread: {leaked}"

            found = await threads.search()
            assert [row["thread_id"] for row in found] == ["t-beta"]
    finally:
        await close()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_mutation_denied(tmp_path):
    """A rename or delete issued by User B leaves User A's thread untouched."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    close = await _make_engines(tmp_path)
    try:
        threads = ThreadMetaRepository(get_session_factory())

        with _as_user(USER_A):
            await threads.create("t-alpha", display_name="original")

        # B attempts a rename of A's thread; it must have no effect.
        with _as_user(USER_B):
            await threads.update_display_name("t-alpha", "hacked")

        # A still sees the original name.
        with _as_user(USER_A):
            after_rename = await threads.get("t-alpha")
            assert after_rename is not None
            assert after_rename["display_name"] == "original"

        # B attempts to delete A's thread; it must have no effect.
        with _as_user(USER_B):
            await threads.delete("t-alpha")

        # A can still fetch the thread.
        with _as_user(USER_A):
            after_delete = await threads.get("t-alpha")
            assert after_delete is not None
    finally:
        await close()
|
||||
|
||||
|
||||
# ── TC-API-18 — runs isolation ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_isolation(tmp_path):
    """Run rows written by one user are invisible to the other, by id and by thread."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository

    close = await _make_engines(tmp_path)
    try:
        runs = RunRepository(get_session_factory())

        # A owns two runs on t-alpha; B owns one run on t-beta.
        with _as_user(USER_A):
            for run_id in ("run-a1", "run-a2"):
                await runs.put(run_id, thread_id="t-alpha")

        with _as_user(USER_B):
            await runs.put("run-b1", thread_id="t-beta")

        # A's view.
        with _as_user(USER_A):
            own = await runs.get("run-a1")
            assert own is not None
            assert own["run_id"] == "run-a1"

            leaked = await runs.get("run-b1")
            assert leaked is None, "User A leaked User B's run"

            alpha_runs = await runs.list_by_thread("t-alpha")
            assert {row["run_id"] for row in alpha_runs} == {"run-a1", "run-a2"}

            # Listing B's thread as A yields nothing.
            beta_runs = await runs.list_by_thread("t-beta")
            assert beta_runs == []

        # B's view.
        with _as_user(USER_B):
            leaked = await runs.get("run-a1")
            assert leaked is None, "User B leaked User A's run"

            beta_runs = await runs.list_by_thread("t-beta")
            assert [row["run_id"] for row in beta_runs] == ["run-b1"]
    finally:
        await close()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_delete_denied(tmp_path):
    """A delete issued by User B does not remove User A's run."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository

    close = await _make_engines(tmp_path)
    try:
        runs = RunRepository(get_session_factory())

        with _as_user(USER_A):
            await runs.put("run-a1", thread_id="t-alpha")

        # B attempts to delete A's run; it must have no effect.
        with _as_user(USER_B):
            await runs.delete("run-a1")

        # A can still fetch the run.
        with _as_user(USER_A):
            survivor = await runs.get("run-a1")
            assert survivor is not None
    finally:
        await close()
|
||||
|
||||
|
||||
# ── TC-API-19 — run_events isolation (CRITICAL: content leak) ─────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_isolation(tmp_path):
    """Event content stored by one user never shows up in the other user's reads."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    close = await _make_engines(tmp_path)
    try:
        events = DbRunEventStore(get_session_factory())

        # A writes a question/answer pair on t-alpha.
        a_events = (
            ("human_message", "User A private question"),
            ("ai_message", "User A private answer"),
        )
        with _as_user(USER_A):
            for kind, text in a_events:
                await events.put(
                    thread_id="t-alpha",
                    run_id="run-a1",
                    event_type=kind,
                    category="message",
                    content=text,
                )

        # B writes a single question on t-beta.
        with _as_user(USER_B):
            await events.put(
                thread_id="t-beta",
                run_id="run-b1",
                event_type="human_message",
                category="message",
                content="User B private question",
            )

        # A's view (CRITICAL): only A's content, nothing from B.
        with _as_user(USER_A):
            messages = await events.list_messages("t-alpha")
            texts = [message["content"] for message in messages]
            assert "User A private question" in texts
            assert "User A private answer" in texts
            assert "User B private question" not in texts

            # Guessing B's thread id must not expose anything.
            leaked = await events.list_messages("t-beta")
            assert leaked == [], f"User A leaked User B's messages: {leaked}"

            leaked_events = await events.list_events("t-beta", "run-b1")
            assert leaked_events == [], "User A leaked User B's events"

            # The message count for B's thread is zero from A's side too.
            foreign_count = await events.count_messages("t-beta")
            assert foreign_count == 0

        # B's view: only B's content, nothing from A.
        with _as_user(USER_B):
            messages = await events.list_messages("t-beta")
            texts = [message["content"] for message in messages]
            assert "User B private question" in texts
            assert "User A private question" not in texts
            assert "User A private answer" not in texts

            foreign_count = await events.count_messages("t-alpha")
            assert foreign_count == 0
    finally:
        await close()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_delete_denied(tmp_path):
    """``delete_by_thread`` called as User B removes none of User A's events."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    close = await _make_engines(tmp_path)
    try:
        events = DbRunEventStore(get_session_factory())

        with _as_user(USER_A):
            await events.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="human_message",
                category="message",
                content="hello",
            )

        # B attempts to wipe A's thread; the reported removal count must be zero.
        with _as_user(USER_B):
            removed = await events.delete_by_thread("t-alpha")
            assert removed == 0, f"User B deleted {removed} of User A's events"

        # A's single message is still counted.
        with _as_user(USER_A):
            remaining = await events.count_messages("t-alpha")
            assert remaining == 1
    finally:
        await close()
|
||||
|
||||
|
||||
# ── TC-API-20 — feedback isolation ────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_isolation(tmp_path):
    """Feedback rows are readable only by the user who created them."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    close = await _make_engines(tmp_path)
    try:
        feedback = FeedbackRepository(get_session_factory())

        # A leaves a positive rating.
        with _as_user(USER_A):
            from_a = await feedback.create(
                run_id="run-a1",
                thread_id="t-alpha",
                rating=1,
                comment="A liked this",
            )

        # B leaves a negative rating.
        with _as_user(USER_B):
            from_b = await feedback.create(
                run_id="run-b1",
                thread_id="t-beta",
                rating=-1,
                comment="B disliked this",
            )

        # A's view.
        with _as_user(USER_A):
            own = await feedback.get(from_a["feedback_id"])
            assert own is not None
            assert own["comment"] == "A liked this"

            # CRITICAL: B's feedback id must not resolve for A.
            leaked = await feedback.get(from_b["feedback_id"])
            assert leaked is None, "User A leaked User B's feedback"

            # Listing B's run as A yields nothing.
            foreign = await feedback.list_by_run("t-beta", "run-b1")
            assert foreign == []

        # B's view.
        with _as_user(USER_B):
            leaked = await feedback.get(from_a["feedback_id"])
            assert leaked is None, "User B leaked User A's feedback"

            own_rows = await feedback.list_by_run("t-beta", "run-b1")
            assert len(own_rows) == 1
            assert own_rows[0]["comment"] == "B disliked this"
    finally:
        await close()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_delete_denied(tmp_path):
    """``delete`` called as User B returns False and leaves User A's feedback in place."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    close = await _make_engines(tmp_path)
    try:
        feedback = FeedbackRepository(get_session_factory())

        with _as_user(USER_A):
            created = await feedback.create(run_id="run-a1", thread_id="t-alpha", rating=1)

        # B attempts the delete; the repository must report False.
        with _as_user(USER_B):
            was_deleted = await feedback.delete(created["feedback_id"])
            assert was_deleted is False, "User B deleted User A's feedback"

        # A can still fetch the row.
        with _as_user(USER_A):
            survivor = await feedback.get(created["feedback_id"])
            assert survivor is not None
    finally:
        await close()
|
||||
|
||||
|
||||
# ── Regression: AUTO sentinel without contextvar must raise ───────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_repository_without_context_raises(tmp_path):
    """A repository read with no user bound must raise ``RuntimeError``."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    close = await _make_engines(tmp_path)
    try:
        threads = ThreadMetaRepository(get_session_factory())
        # No _as_user() wrapper here: the test relies on the no_auto_user
        # marker to leave the contextvar unset.
        with pytest.raises(RuntimeError, match="no user context is set"):
            await threads.get("anything")
    finally:
        await close()
|
||||
|
||||
|
||||
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(tmp_path):
    """Passing ``user_id=None`` explicitly returns rows regardless of owner."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    close = await _make_engines(tmp_path)
    try:
        threads = ThreadMetaRepository(get_session_factory())

        # One thread per user, A first and then B.
        for owner, thread_id in ((USER_A, "t-alpha"), (USER_B, "t-beta")):
            with _as_user(owner):
                await threads.create(thread_id)

        # Migration-style read: no user bound, explicit None on the call.
        everything = await threads.search(user_id=None)
        seen_ids = {row["thread_id"] for row in everything}
        assert seen_ids == {"t-alpha", "t-beta"}

        # A direct get with None skips the owner filter as well.
        alpha_row = await threads.get("t-alpha", user_id=None)
        assert alpha_row is not None
        beta_row = await threads.get("t-beta", user_id=None)
        assert beta_row is not None
    finally:
        await close()
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Tests for user-scoped path resolution in Paths."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@pytest.fixture
def paths(tmp_path: Path) -> Paths:
    """Provide a ``Paths`` instance constructed from pytest's per-test ``tmp_path``."""
    return Paths(tmp_path)
|
||||
|
||||
|
||||
class TestValidateUserId:
    """``Paths.user_dir`` accepts a plain id and rejects unsafe or empty ones."""

    def test_valid_user_id(self, paths: Paths):
        resolved = paths.user_dir("u-abc-123")
        expected = paths.base_dir / "users" / "u-abc-123"
        assert resolved == expected

    def test_rejects_path_traversal(self, paths: Paths):
        # A parent-directory component must be refused.
        with pytest.raises(ValueError, match="Invalid user_id"):
            paths.user_dir("../escape")

    def test_rejects_slash(self, paths: Paths):
        # A path separator inside the id must be refused.
        with pytest.raises(ValueError, match="Invalid user_id"):
            paths.user_dir("foo/bar")

    def test_rejects_empty(self, paths: Paths):
        # The empty string must be refused.
        with pytest.raises(ValueError, match="Invalid user_id"):
            paths.user_dir("")
|
||||
|
||||
|
||||
class TestUserDir:
    """``Paths.user_dir`` resolves to ``<base_dir>/users/<user_id>``."""

    def test_user_dir(self, paths: Paths):
        resolved = paths.user_dir("alice")
        expected = paths.base_dir / "users" / "alice"
        assert resolved == expected
|
||||
|
||||
|
||||
class TestUserMemoryFile:
    """``Paths.user_memory_file`` resolves to ``memory.json`` inside the user's directory."""

    def test_user_memory_file(self, paths: Paths):
        resolved = paths.user_memory_file("bob")
        expected = paths.base_dir / "users" / "bob" / "memory.json"
        assert resolved == expected
|
||||
|
||||
|
||||
class TestUserAgentMemoryFile:
    """``Paths.user_agent_memory_file`` nests agent memory under the user's directory."""

    def test_user_agent_memory_file(self, paths: Paths):
        agent_root = paths.base_dir / "users" / "bob" / "agents" / "myagent"
        assert paths.user_agent_memory_file("bob", "myagent") == agent_root / "memory.json"

    def test_user_agent_memory_file_lowercases_name(self, paths: Paths):
        # A mixed-case agent name must resolve to the lower-cased directory.
        agent_root = paths.base_dir / "users" / "bob" / "agents" / "myagent"
        assert paths.user_agent_memory_file("bob", "MyAgent") == agent_root / "memory.json"
|
||||
|
||||
|
||||
class TestUserThreadDir:
    """``Paths.thread_dir`` is user-scoped with ``user_id`` and legacy without it."""

    def test_user_thread_dir(self, paths: Paths):
        resolved = paths.thread_dir("t1", user_id="u1")
        assert resolved == paths.base_dir / "users" / "u1" / "threads" / "t1"

    def test_thread_dir_no_user_id_falls_back_to_legacy(self, paths: Paths):
        # Without user_id the path has no "users/<id>" segment.
        resolved = paths.thread_dir("t1")
        assert resolved == paths.base_dir / "threads" / "t1"
|
||||
|
||||
|
||||
class TestUserSandboxDirs:
    """Sandbox and ACP directory helpers resolve beneath the user-scoped thread directory."""

    def test_sandbox_work_dir(self, paths: Paths):
        thread_root = paths.base_dir / "users" / "u1" / "threads" / "t1"
        assert paths.sandbox_work_dir("t1", user_id="u1") == thread_root / "user-data" / "workspace"

    def test_sandbox_uploads_dir(self, paths: Paths):
        thread_root = paths.base_dir / "users" / "u1" / "threads" / "t1"
        assert paths.sandbox_uploads_dir("t1", user_id="u1") == thread_root / "user-data" / "uploads"

    def test_sandbox_outputs_dir(self, paths: Paths):
        thread_root = paths.base_dir / "users" / "u1" / "threads" / "t1"
        assert paths.sandbox_outputs_dir("t1", user_id="u1") == thread_root / "user-data" / "outputs"

    def test_sandbox_user_data_dir(self, paths: Paths):
        thread_root = paths.base_dir / "users" / "u1" / "threads" / "t1"
        assert paths.sandbox_user_data_dir("t1", user_id="u1") == thread_root / "user-data"

    def test_acp_workspace_dir(self, paths: Paths):
        thread_root = paths.base_dir / "users" / "u1" / "threads" / "t1"
        assert paths.acp_workspace_dir("t1", user_id="u1") == thread_root / "acp-workspace"

    def test_legacy_sandbox_work_dir(self, paths: Paths):
        # Without user_id the legacy layout (no "users/<id>" segment) applies.
        legacy_root = paths.base_dir / "threads" / "t1"
        assert paths.sandbox_work_dir("t1") == legacy_root / "user-data" / "workspace"
|
||||
|
||||
|
||||
class TestHostPathsWithUserId:
    """Host-side path helpers include the expected path segments."""

    def test_host_thread_dir_with_user_id(self, paths: Paths):
        """The user-scoped host thread dir mentions the user and thread segments."""
        host_dir = paths.host_thread_dir("t1", user_id="u1")
        for segment in ("users", "u1", "threads", "t1"):
            assert segment in host_dir

    def test_host_thread_dir_legacy(self, paths: Paths):
        """The legacy host thread dir has no ``users`` segment."""
        host_dir = paths.host_thread_dir("t1")
        for segment in ("threads", "t1"):
            assert segment in host_dir
        assert "users" not in host_dir

    def test_host_sandbox_user_data_dir_with_user_id(self, paths: Paths):
        """The user-scoped host user-data dir mentions ``users`` and ``user-data``."""
        host_dir = paths.host_sandbox_user_data_dir("t1", user_id="u1")
        for segment in ("users", "user-data"):
            assert segment in host_dir

    def test_host_sandbox_work_dir_with_user_id(self, paths: Paths):
        """The user-scoped host work dir mentions ``workspace``."""
        host_dir = paths.host_sandbox_work_dir("t1", user_id="u1")
        assert "workspace" in host_dir

    def test_host_sandbox_uploads_dir_with_user_id(self, paths: Paths):
        """The user-scoped host uploads dir mentions ``uploads``."""
        host_dir = paths.host_sandbox_uploads_dir("t1", user_id="u1")
        assert "uploads" in host_dir

    def test_host_sandbox_outputs_dir_with_user_id(self, paths: Paths):
        """The user-scoped host outputs dir mentions ``outputs``."""
        host_dir = paths.host_sandbox_outputs_dir("t1", user_id="u1")
        assert "outputs" in host_dir

    def test_host_acp_workspace_dir_with_user_id(self, paths: Paths):
        """The user-scoped host ACP workspace dir mentions ``acp-workspace``."""
        host_dir = paths.host_acp_workspace_dir("t1", user_id="u1")
        assert "acp-workspace" in host_dir
|
||||
|
||||
|
||||
class TestEnsureAndDeleteWithUserId:
    """Creating and deleting thread directories, user-scoped and legacy."""

    def test_ensure_thread_dirs_creates_user_scoped(self, paths: Paths):
        """``ensure_thread_dirs`` creates every sandbox directory for the user."""
        paths.ensure_thread_dirs("t1", user_id="u1")
        created = (
            paths.sandbox_work_dir("t1", user_id="u1"),
            paths.sandbox_uploads_dir("t1", user_id="u1"),
            paths.sandbox_outputs_dir("t1", user_id="u1"),
            paths.acp_workspace_dir("t1", user_id="u1"),
        )
        for directory in created:
            assert directory.is_dir()

    def test_delete_thread_dir_removes_user_scoped(self, paths: Paths):
        """``delete_thread_dir`` removes a previously created user-scoped thread dir."""
        paths.ensure_thread_dirs("t1", user_id="u1")
        assert paths.thread_dir("t1", user_id="u1").exists()

        paths.delete_thread_dir("t1", user_id="u1")
        assert not paths.thread_dir("t1", user_id="u1").exists()

    def test_delete_thread_dir_idempotent(self, paths: Paths):
        """Deleting a thread dir that was never created must not raise."""
        paths.delete_thread_dir("nonexistent", user_id="u1")  # should not raise

    def test_ensure_thread_dirs_legacy_still_works(self, paths: Paths):
        """``ensure_thread_dirs`` without a ``user_id`` creates the legacy workspace."""
        paths.ensure_thread_dirs("t1")
        legacy_workspace = paths.sandbox_work_dir("t1")
        assert legacy_workspace.is_dir()

    def test_user_scoped_and_legacy_are_independent(self, paths: Paths):
        """Deleting the user-scoped thread dir leaves the legacy one in place."""
        paths.ensure_thread_dirs("t1", user_id="u1")
        paths.ensure_thread_dirs("t1")

        # Both trees exist side by side for the same thread id.
        assert paths.thread_dir("t1", user_id="u1").exists()
        assert paths.thread_dir("t1").exists()

        # Removing the user-scoped tree must not touch the legacy tree.
        paths.delete_thread_dir("t1", user_id="u1")
        assert not paths.thread_dir("t1", user_id="u1").exists()
        assert paths.thread_dir("t1").exists()
|
||||
|
||||
|
||||
class TestResolveVirtualPathWithUserId:
    """Virtual ``/mnt/user-data`` paths resolve below the matching user-data dir."""

    def test_resolve_virtual_path_with_user_id(self, paths: Paths):
        """With a ``user_id`` the result starts with the user-scoped user-data dir."""
        paths.ensure_thread_dirs("t1", user_id="u1")
        resolved = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt", user_id="u1")
        base_prefix = str(paths.sandbox_user_data_dir("t1", user_id="u1").resolve())
        assert str(resolved).startswith(base_prefix)

    def test_resolve_virtual_path_legacy(self, paths: Paths):
        """Without a ``user_id`` the result starts with the legacy user-data dir."""
        paths.ensure_thread_dirs("t1")
        resolved = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt")
        base_prefix = str(paths.sandbox_user_data_dir("t1").resolve())
        assert str(resolved).startswith(base_prefix)
|
||||
@@ -0,0 +1,233 @@
|
||||
"""Tests for the persistence layer scaffolding.
|
||||
|
||||
Tests:
|
||||
1. DatabaseConfig property derivation (paths, URLs)
|
||||
2. MemoryRunStore CRUD + user_id filtering
|
||||
3. Base.to_dict() via inspect mixin
|
||||
4. Engine init/close lifecycle (memory + SQLite)
|
||||
5. Postgres missing-dep error message
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
# -- DatabaseConfig --
|
||||
|
||||
|
||||
class TestDatabaseConfig:
    """Defaults and derived path/URL properties of ``DatabaseConfig``."""

    def test_defaults(self):
        """A default config uses the memory backend with a pool size of 5."""
        config = DatabaseConfig()
        assert (config.backend, config.pool_size) == ("memory", 5)

    def test_sqlite_paths_unified(self):
        """All SQLite path properties point at one ``deerflow.db`` in ``sqlite_dir``."""
        config = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
        db_path = config.sqlite_path
        assert db_path.endswith("deerflow.db")
        assert "mydata" in db_path
        # Backward-compatible aliases point to the same file
        assert config.checkpointer_sqlite_path == config.sqlite_path
        assert config.app_sqlite_path == config.sqlite_path

    def test_app_sqlalchemy_url_sqlite(self):
        """The SQLite URL uses the ``sqlite+aiosqlite`` scheme and names the db file."""
        config = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
        sqlite_url = config.app_sqlalchemy_url
        assert sqlite_url.startswith("sqlite+aiosqlite:///")
        assert "deerflow.db" in sqlite_url

    def test_app_sqlalchemy_url_postgres(self):
        """A plain ``postgresql://`` URL is exposed with the ``asyncpg`` driver."""
        config = DatabaseConfig(backend="postgres", postgres_url="postgresql://u:p@h:5432/db")
        pg_url = config.app_sqlalchemy_url
        assert pg_url.startswith("postgresql+asyncpg://")
        assert "u:p@h:5432/db" in pg_url

    def test_app_sqlalchemy_url_postgres_already_asyncpg(self):
        """A URL that already names ``asyncpg`` does not get the driver added twice."""
        config = DatabaseConfig(backend="postgres", postgres_url="postgresql+asyncpg://u:p@h:5432/db")
        pg_url = config.app_sqlalchemy_url
        assert pg_url.count("asyncpg") == 1

    def test_memory_has_no_url(self):
        """Reading the URL on the memory backend raises ``ValueError``."""
        config = DatabaseConfig(backend="memory")
        with pytest.raises(ValueError, match="No SQLAlchemy URL"):
            _ = config.app_sqlalchemy_url
|
||||
|
||||
|
||||
# -- MemoryRunStore --
|
||||
|
||||
|
||||
class TestMemoryRunStore:
    """CRUD, filtering and ordering behaviour of ``MemoryRunStore``."""

    @pytest.fixture
    def store(self):
        """Provide a new, empty ``MemoryRunStore``."""
        run_store = MemoryRunStore()
        return run_store

    @pytest.mark.anyio
    async def test_put_and_get(self, store):
        """A stored run can be read back with its id and status."""
        await store.put("r1", thread_id="t1", status="pending")
        fetched = await store.get("r1")
        assert fetched is not None
        assert (fetched["run_id"], fetched["status"]) == ("r1", "pending")

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, store):
        """Looking up an unknown run id yields ``None``."""
        missing = await store.get("nope")
        assert missing is None

    @pytest.mark.anyio
    async def test_update_status(self, store):
        """``update_status`` changes the stored status."""
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "running")
        fetched = await store.get("r1")
        assert fetched["status"] == "running"

    @pytest.mark.anyio
    async def test_update_status_with_error(self, store):
        """``update_status`` records the error text alongside the status."""
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "error", error="boom")
        fetched = await store.get("r1")
        assert fetched["status"] == "error"
        assert fetched["error"] == "boom"

    @pytest.mark.anyio
    async def test_list_by_thread(self, store):
        """Only runs of the requested thread are listed."""
        for run_id, thread_id in (("r1", "t1"), ("r2", "t1"), ("r3", "t2")):
            await store.put(run_id, thread_id=thread_id)
        listed = await store.list_by_thread("t1")
        assert len(listed) == 2
        assert all(row["thread_id"] == "t1" for row in listed)

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, store):
        """Passing ``user_id`` restricts the listing to that user's runs."""
        await store.put("r1", thread_id="t1", user_id="alice")
        await store.put("r2", thread_id="t1", user_id="bob")
        listed = await store.list_by_thread("t1", user_id="alice")
        assert len(listed) == 1
        assert listed[0]["user_id"] == "alice"

    @pytest.mark.anyio
    async def test_owner_none_returns_all(self, store):
        """``user_id=None`` applies no owner filter."""
        await store.put("r1", thread_id="t1", user_id="alice")
        await store.put("r2", thread_id="t1", user_id="bob")
        listed = await store.list_by_thread("t1", user_id=None)
        assert len(listed) == 2

    @pytest.mark.anyio
    async def test_delete(self, store):
        """A deleted run can no longer be fetched."""
        await store.put("r1", thread_id="t1")
        await store.delete("r1")
        remaining = await store.get("r1")
        assert remaining is None

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, store):
        """Deleting an unknown run id must not raise."""
        await store.delete("nope")  # should not raise

    @pytest.mark.anyio
    async def test_list_pending(self, store):
        """Only runs whose status is ``pending`` are returned."""
        seeded = (("r1", "t1", "pending"), ("r2", "t1", "running"), ("r3", "t2", "pending"))
        for run_id, thread_id, status in seeded:
            await store.put(run_id, thread_id=thread_id, status=status)
        queued = await store.list_pending()
        assert len(queued) == 2
        assert all(row["status"] == "pending" for row in queued)

    @pytest.mark.anyio
    async def test_list_pending_respects_before(self, store):
        """``before`` excludes pending runs created after the cutoff."""
        past = "2020-01-01T00:00:00+00:00"
        future = "2099-01-01T00:00:00+00:00"
        await store.put("r1", thread_id="t1", status="pending", created_at=past)
        await store.put("r2", thread_id="t1", status="pending", created_at=future)
        cutoff = datetime.now(UTC).isoformat()
        queued = await store.list_pending(before=cutoff)
        assert len(queued) == 1
        assert queued[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_list_pending_fifo_order(self, store):
        """The run with the earlier ``created_at`` is listed first, regardless of insert order."""
        await store.put("r2", thread_id="t1", status="pending", created_at="2024-01-02T00:00:00+00:00")
        await store.put("r1", thread_id="t1", status="pending", created_at="2024-01-01T00:00:00+00:00")
        queued = await store.list_pending()
        oldest = queued[0]
        assert oldest["run_id"] == "r1"
|
||||
|
||||
|
||||
# -- Base.to_dict mixin --
|
||||
|
||||
|
||||
class TestBaseToDictMixin:
    """``Base.to_dict()`` / ``repr()`` behaviour, checked on a throwaway model."""

    @pytest.mark.anyio
    async def test_to_dict_and_exclude(self, tmp_path):
        """Create a temp SQLite DB with a minimal model, verify to_dict."""
        from sqlalchemy import String
        from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
        from sqlalchemy.orm import Mapped, mapped_column

        from deerflow.persistence.base import Base

        class _Tmp(Base):
            __tablename__ = "_tmp_test"
            id: Mapped[str] = mapped_column(String(64), primary_key=True)
            name: Mapped[str] = mapped_column(String(128))

        engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}")
        # Dispose the engine even when an assertion fails so the connection
        # pool (and the SQLite file handle) is not leaked into later tests.
        try:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)

            sf = async_sessionmaker(engine, expire_on_commit=False)
            async with sf() as session:
                session.add(_Tmp(id="1", name="hello"))
                await session.commit()
                obj = await session.get(_Tmp, "1")

            # Fail with a clear message instead of an AttributeError on None.
            assert obj is not None
            assert obj.to_dict() == {"id": "1", "name": "hello"}
            assert obj.to_dict(exclude={"name"}) == {"id": "1"}
            assert "_Tmp" in repr(obj)
        finally:
            await engine.dispose()
|
||||
|
||||
|
||||
# -- Engine lifecycle --
|
||||
|
||||
|
||||
class TestEngineLifecycle:
    """Init/close lifecycle of the persistence engine for each backend."""

    @pytest.mark.anyio
    async def test_memory_is_noop(self):
        """The memory backend creates no session factory."""
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        await init_engine("memory")
        # Always close, even if the assertion fails, so engine state does not
        # carry over into other tests.
        try:
            assert get_session_factory() is None
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_sqlite_creates_engine(self, tmp_path):
        """The SQLite backend exposes a usable session factory until closed."""
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        # Always close, even if an assertion fails, so the engine is not
        # left initialised for subsequent tests.
        try:
            sf = get_session_factory()
            assert sf is not None
            async with sf() as session:
                assert session is not None
        finally:
            await close_engine()
        assert get_session_factory() is None

    @pytest.mark.anyio
    async def test_postgres_without_asyncpg_gives_actionable_error(self):
        """If asyncpg is not installed, error message tells user what to do."""
        from deerflow.persistence.engine import init_engine

        try:
            import asyncpg  # noqa: F401
        except ImportError:
            # asyncpg is not installed — this is the expected state for this test.
            # We proceed to verify that init_engine raises an actionable ImportError.
            pass
        else:
            pytest.skip("asyncpg is installed -- cannot test missing-dep path")

        with pytest.raises(ImportError, match="uv sync --extra postgres"):
            await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
|
||||
@@ -3,14 +3,24 @@
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
# Load the tool module by dotted name; tests monkeypatch attributes (e.g. "get_paths") on this module object.
present_file_tool_module = importlib.import_module("deerflow.tools.builtins.present_file_tool")
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
    """Build a ``DeerFlowContext`` for *thread_id* around a minimal test ``AppConfig``."""
    sandbox = SandboxConfig(use="test")
    app_config = AppConfig(sandbox=sandbox)
    return DeerFlowContext(app_config=app_config, thread_id=thread_id)
|
||||
|
||||
|
||||
def _make_runtime(outputs_path: str) -> SimpleNamespace:
    """Build a stand-in runtime object for calling the tool function directly.

    Args:
        outputs_path: Value stored under ``state["thread_data"]["outputs_path"]``.

    Returns:
        A ``SimpleNamespace`` with ``state``, an empty ``config`` and a typed
        ``context`` for thread ``"thread-1"``.
    """
    # ``context`` must be passed exactly once: a repeated keyword argument is a
    # SyntaxError. Only the typed DeerFlowContext is kept.
    return SimpleNamespace(
        state={"thread_data": {"outputs_path": outputs_path}},
        config={},
        context=_make_context("thread-1"),
    )
|
||||
|
||||
|
||||
@@ -39,7 +49,7 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
present_file_tool_module,
|
||||
"get_paths",
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
|
||||
)
|
||||
|
||||
result = present_file_tool_module.present_file_tool.func(
|
||||
@@ -51,34 +61,6 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||
|
||||
|
||||
def test_present_files_uses_config_thread_id_when_context_missing(tmp_path, monkeypatch):
    """Present a file when the thread id is only available via ``config["configurable"]``.

    NOTE(review): the surrounding diff hunk suggests this test is being removed,
    and it passes an untyped ``context={}`` — confirm it is still meant to exist
    before relying on it.
    """
    outputs_dir = tmp_path / "threads" / "thread-from-config" / "user-data" / "outputs"
    outputs_dir.mkdir(parents=True)
    artifact_path = outputs_dir / "summary.json"
    artifact_path.write_text("{}")

    # Accept the optional ``user_id`` keyword so the stub has the same
    # signature as the other ``resolve_virtual_path`` stub in this file.
    monkeypatch.setattr(
        present_file_tool_module,
        "get_paths",
        lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
    )

    runtime = SimpleNamespace(
        state={"thread_data": {"outputs_path": str(outputs_dir)}},
        context={},
        config={"configurable": {"thread_id": "thread-from-config"}},
    )

    result = present_file_tool_module.present_file_tool.func(
        runtime=runtime,
        filepaths=["/mnt/user-data/outputs/summary.json"],
        tool_call_id="tc-config",
    )

    assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
    assert result.update["messages"][0].content == "Successfully presented files"
|
||||
|
||||
|
||||
def test_present_files_rejects_paths_outside_outputs(tmp_path):
|
||||
outputs_dir = tmp_path / "threads" / "thread-1" / "user-data" / "outputs"
|
||||
workspace_dir = tmp_path / "threads" / "thread-1" / "user-data" / "workspace"
|
||||
|
||||
@@ -0,0 +1,500 @@
|
||||
"""Tests for RunEventStore contract across all backends.
|
||||
|
||||
Uses a helper to create the store for each backend type.
|
||||
Memory tests run directly; DB and JSONL tests create stores inside each test.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
def store():
    """Provide a new, empty ``MemoryRunEventStore``."""
    event_store = MemoryRunEventStore()
    return event_store
|
||||
|
||||
|
||||
# -- Basic write and query --
|
||||
|
||||
|
||||
class TestPutAndSeq:
    """``put`` returns the stored record and assigns per-thread ``seq`` numbers."""

    @pytest.mark.anyio
    async def test_put_returns_dict_with_seq(self, store):
        """The returned record echoes the inputs and carries ``seq`` and ``created_at``."""
        record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hello")
        assert "seq" in record
        expected_fields = {
            "seq": 1,
            "thread_id": "t1",
            "run_id": "r1",
            "event_type": "human_message",
            "category": "message",
            "content": "hello",
        }
        for field, value in expected_fields.items():
            assert record[field] == value
        assert "created_at" in record

    @pytest.mark.anyio
    async def test_seq_strictly_increasing_same_thread(self, store):
        """Consecutive writes to one thread get seq 1, 2, 3."""
        first = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        second = await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
        third = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        assert [first["seq"], second["seq"], third["seq"]] == [1, 2, 3]

    @pytest.mark.anyio
    async def test_seq_independent_across_threads(self, store):
        """Each thread starts its own sequence at 1."""
        in_t1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        in_t2 = await store.put(thread_id="t2", run_id="r2", event_type="human_message", category="message")
        assert in_t1["seq"] == 1
        assert in_t2["seq"] == 1

    @pytest.mark.anyio
    async def test_put_respects_provided_created_at(self, store):
        """A caller-supplied ``created_at`` is stored unchanged."""
        timestamp = "2024-06-01T12:00:00+00:00"
        record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", created_at=timestamp)
        assert record["created_at"] == timestamp

    @pytest.mark.anyio
    async def test_put_metadata_preserved(self, store):
        """Caller-supplied metadata is returned equal to what was passed in."""
        metadata = {"model": "gpt-4", "tokens": 100}
        record = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", metadata=metadata)
        assert record["metadata"] == metadata
|
||||
|
||||
|
||||
# -- list_messages --
|
||||
|
||||
|
||||
class TestListMessages:
    """Filtering, ordering and pagination of ``list_messages``."""

    @pytest.mark.anyio
    async def test_only_returns_message_category(self, store):
        """Trace and lifecycle events are excluded from the message listing."""
        for event_type, category in (("human_message", "message"), ("llm_end", "trace"), ("run_start", "lifecycle")):
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
        listed = await store.list_messages("t1")
        assert len(listed) == 1
        assert listed[0]["category"] == "message"

    @pytest.mark.anyio
    async def test_ascending_seq_order(self, store):
        """Messages come back in ascending ``seq`` order."""
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="first")
        await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="second")
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="third")
        listed = await store.list_messages("t1")
        seq_values = [message["seq"] for message in listed]
        assert seq_values == sorted(seq_values)

    @pytest.mark.anyio
    async def test_before_seq_pagination(self, store):
        """``before_seq=6, limit=3`` returns the three messages just below seq 6."""
        for index in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(index))
        page = await store.list_messages("t1", before_seq=6, limit=3)
        assert len(page) == 3
        assert [message["seq"] for message in page] == [3, 4, 5]

    @pytest.mark.anyio
    async def test_after_seq_pagination(self, store):
        """``after_seq=7, limit=3`` returns the three messages just above seq 7."""
        for index in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(index))
        page = await store.list_messages("t1", after_seq=7, limit=3)
        assert len(page) == 3
        assert [message["seq"] for message in page] == [8, 9, 10]

    @pytest.mark.anyio
    async def test_limit_restricts_count(self, store):
        """``limit`` caps the number of returned messages."""
        for _ in range(20):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        page = await store.list_messages("t1", limit=5)
        assert len(page) == 5

    @pytest.mark.anyio
    async def test_cross_run_unified_ordering(self, store):
        """Messages from different runs of one thread share a single seq order."""
        for run_id in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message")
            await store.put(thread_id="t1", run_id=run_id, event_type="ai_message", category="message")
        listed = await store.list_messages("t1")
        assert [message["seq"] for message in listed] == [1, 2, 3, 4]
        assert listed[0]["run_id"] == "r1"
        assert listed[2]["run_id"] == "r2"

    @pytest.mark.anyio
    async def test_default_returns_latest(self, store):
        """With only ``limit`` given, the newest messages are returned."""
        for _ in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        page = await store.list_messages("t1", limit=3)
        assert [message["seq"] for message in page] == [8, 9, 10]
|
||||
|
||||
|
||||
# -- list_events --
|
||||
|
||||
|
||||
class TestListEvents:
    """``list_events`` returns a run's events, optionally filtered by type."""

    @pytest.mark.anyio
    async def test_returns_all_categories_for_run(self, store):
        """Message, trace and lifecycle events of the run are all returned."""
        for event_type, category in (("human_message", "message"), ("llm_end", "trace"), ("run_start", "lifecycle")):
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
        listed = await store.list_events("t1", "r1")
        assert len(listed) == 3

    @pytest.mark.anyio
    async def test_event_types_filter(self, store):
        """``event_types`` restricts the result to the named event types."""
        for event_type in ("llm_start", "llm_end", "tool_start"):
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category="trace")
        listed = await store.list_events("t1", "r1", event_types=["llm_end"])
        assert len(listed) == 1
        assert listed[0]["event_type"] == "llm_end"

    @pytest.mark.anyio
    async def test_only_returns_specified_run(self, store):
        """Events belonging to other runs of the same thread are excluded."""
        for run_id in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run_id, event_type="llm_end", category="trace")
        listed = await store.list_events("t1", "r1")
        assert len(listed) == 1
        assert listed[0]["run_id"] == "r1"
|
||||
|
||||
|
||||
# -- list_messages_by_run --
|
||||
|
||||
|
||||
class TestListMessagesByRun:
    """``list_messages_by_run`` filters by both run and message category."""

    @pytest.mark.anyio
    async def test_only_messages_for_specified_run(self, store):
        """Only message-category events of the requested run are returned."""
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        listed = await store.list_messages_by_run("t1", "r1")
        assert len(listed) == 1
        only = listed[0]
        assert only["run_id"] == "r1"
        assert only["category"] == "message"
|
||||
|
||||
|
||||
# -- count_messages --
|
||||
|
||||
|
||||
class TestCountMessages:
    """``count_messages`` counts message-category events only."""

    @pytest.mark.anyio
    async def test_counts_only_message_category(self, store):
        """A trace event does not contribute to the message count."""
        for event_type, category in (("human_message", "message"), ("ai_message", "message"), ("llm_end", "trace")):
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
        total = await store.count_messages("t1")
        assert total == 2
|
||||
|
||||
|
||||
# -- put_batch --
|
||||
|
||||
|
||||
class TestPutBatch:
    """``put_batch`` stores several events and assigns each a ``seq``."""

    @pytest.mark.anyio
    async def test_batch_assigns_seq(self, store):
        """Every record returned from a batch write carries a ``seq``."""
        batch = [
            {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message", "content": "a"},
            {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "b"},
            {"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace"},
        ]
        stored = await store.put_batch(batch)
        assert len(stored) == 3
        assert all("seq" in record for record in stored)

    @pytest.mark.anyio
    async def test_batch_seq_strictly_increasing(self, store):
        """The first two batch records get seq 1 and 2, in input order."""
        batch = [
            {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message"},
            {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message"},
        ]
        stored = await store.put_batch(batch)
        first, second = stored[0], stored[1]
        assert first["seq"] == 1
        assert second["seq"] == 2
|
||||
|
||||
|
||||
# -- delete --
|
||||
|
||||
|
||||
class TestDelete:
    """``delete_by_thread`` / ``delete_by_run`` remove events and report the count."""

    @pytest.mark.anyio
    async def test_delete_by_thread(self, store):
        """Deleting a thread removes all of its events across runs and categories."""
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
        removed = await store.delete_by_thread("t1")
        assert removed == 3
        assert await store.list_messages("t1") == []
        assert await store.count_messages("t1") == 0

    @pytest.mark.anyio
    async def test_delete_by_run(self, store):
        """Deleting one run leaves the thread's other runs untouched."""
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
        removed = await store.delete_by_run("t1", "r2")
        assert removed == 2
        survivors = await store.list_messages("t1")
        assert len(survivors) == 1
        assert survivors[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_delete_nonexistent_thread_returns_zero(self, store):
        """Deleting an unknown thread reports zero removed events."""
        removed = await store.delete_by_thread("nope")
        assert removed == 0

    @pytest.mark.anyio
    async def test_delete_nonexistent_run_returns_zero(self, store):
        """Deleting an unknown run of an existing thread reports zero."""
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        removed = await store.delete_by_run("t1", "nope")
        assert removed == 0

    @pytest.mark.anyio
    async def test_delete_nonexistent_thread_for_run_returns_zero(self, store):
        """Deleting a run of an unknown thread reports zero."""
        removed = await store.delete_by_run("nope", "r1")
        assert removed == 0
|
||||
|
||||
|
||||
# -- Edge cases --
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Queries against threads and runs that have no stored events."""

    @pytest.mark.anyio
    async def test_empty_thread_list_messages(self, store):
        """Listing messages of an unknown thread yields an empty list."""
        listed = await store.list_messages("empty")
        assert listed == []

    @pytest.mark.anyio
    async def test_empty_run_list_events(self, store):
        """Listing events of an unknown thread/run yields an empty list."""
        listed = await store.list_events("empty", "r1")
        assert listed == []

    @pytest.mark.anyio
    async def test_empty_thread_count_messages(self, store):
        """Counting messages of an unknown thread yields zero."""
        total = await store.count_messages("empty")
        assert total == 0
|
||||
|
||||
|
||||
# -- DB-specific tests --
|
||||
|
||||
|
||||
class TestDbRunEventStore:
    """Tests for DbRunEventStore with temp SQLite."""

    @staticmethod
    async def _init_sqlite_engine(tmp_path):
        """Initialise the persistence engine on a SQLite file inside *tmp_path*."""
        from deerflow.persistence.engine import init_engine

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))

    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        """Two puts get seq 1 and 2 and are visible via list and count."""
        from deerflow.persistence.engine import close_engine, get_session_factory
        from deerflow.runtime.events.store.db import DbRunEventStore

        await self._init_sqlite_engine(tmp_path)
        # Close the engine even when an assertion fails so it is not left
        # initialised for later tests.
        try:
            s = DbRunEventStore(get_session_factory())

            r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
            assert r["seq"] == 1
            r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
            assert r2["seq"] == 2

            messages = await s.list_messages("t1")
            assert len(messages) == 2

            count = await s.count_messages("t1")
            assert count == 2
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_trace_content_truncation(self, tmp_path):
        """Trace content is cut to ``max_trace_content``; message content is not."""
        from deerflow.persistence.engine import close_engine, get_session_factory
        from deerflow.runtime.events.store.db import DbRunEventStore

        await self._init_sqlite_engine(tmp_path)
        try:
            s = DbRunEventStore(get_session_factory(), max_trace_content=100)

            long = "x" * 200
            r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
            assert len(r["content"]) == 100
            assert r["metadata"].get("content_truncated") is True

            # message content NOT truncated
            m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
            assert len(m["content"]) == 200
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_pagination(self, tmp_path):
        """``before_seq``, ``after_seq`` and the default window return the expected seqs."""
        from deerflow.persistence.engine import close_engine, get_session_factory
        from deerflow.runtime.events.store.db import DbRunEventStore

        await self._init_sqlite_engine(tmp_path)
        try:
            s = DbRunEventStore(get_session_factory())

            for i in range(10):
                await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))

            # before_seq
            msgs = await s.list_messages("t1", before_seq=6, limit=3)
            assert [m["seq"] for m in msgs] == [3, 4, 5]

            # after_seq
            msgs = await s.list_messages("t1", after_seq=7, limit=3)
            assert [m["seq"] for m in msgs] == [8, 9, 10]

            # default (latest)
            msgs = await s.list_messages("t1", limit=3)
            assert [m["seq"] for m in msgs] == [8, 9, 10]
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        """``delete_by_run`` then ``delete_by_thread`` each remove one message."""
        from deerflow.persistence.engine import close_engine, get_session_factory
        from deerflow.runtime.events.store.db import DbRunEventStore

        await self._init_sqlite_engine(tmp_path)
        try:
            s = DbRunEventStore(get_session_factory())

            await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
            await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
            c = await s.delete_by_run("t1", "r2")
            assert c == 1
            assert await s.count_messages("t1") == 1

            c = await s.delete_by_thread("t1")
            assert c == 1
            assert await s.count_messages("t1") == 0
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_put_batch_seq_continuity(self, tmp_path):
        """Batch write produces continuous seq values with no gaps."""
        from deerflow.persistence.engine import close_engine, get_session_factory
        from deerflow.runtime.events.store.db import DbRunEventStore

        await self._init_sqlite_engine(tmp_path)
        try:
            s = DbRunEventStore(get_session_factory())

            events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"} for _ in range(50)]
            results = await s.put_batch(events)
            seqs = [r["seq"] for r in results]
            assert seqs == list(range(1, 51))
        finally:
            await close_engine()
|
||||
|
||||
|
||||
# -- Factory tests --
|
||||
|
||||
|
||||
class TestMakeRunEventStore:
    """Tests for the make_run_event_store factory function.

    Store classes are compared by ``type(store).__name__`` so the tests do not
    have to import each backend class.
    """

    @pytest.mark.anyio
    async def test_memory_backend_default(self):
        """Passing ``None`` as config yields the in-memory store."""
        from deerflow.runtime.events.store import make_run_event_store

        store = make_run_event_store(None)
        assert type(store).__name__ == "MemoryRunEventStore"

    @pytest.mark.anyio
    async def test_memory_backend_explicit(self):
        """``backend="memory"`` yields the in-memory store."""
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "memory"
        store = make_run_event_store(config)
        assert type(store).__name__ == "MemoryRunEventStore"

    @pytest.mark.anyio
    async def test_db_backend_with_engine(self, tmp_path):
        """``backend="db"`` with an initialised SQLite engine yields the DB store."""
        from unittest.mock import MagicMock

        from deerflow.persistence.engine import close_engine, init_engine
        from deerflow.runtime.events.store import make_run_event_store

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        try:
            config = MagicMock()
            config.backend = "db"
            config.max_trace_content = 10240
            store = make_run_event_store(config)
            assert type(store).__name__ == "DbRunEventStore"
        finally:
            # Close the engine even when the assertion fails so it does not
            # leak into the following tests.
            await close_engine()

    @pytest.mark.anyio
    async def test_db_backend_no_engine_falls_back(self):
        """db backend without engine falls back to memory."""
        from unittest.mock import MagicMock

        from deerflow.persistence.engine import close_engine, init_engine
        from deerflow.runtime.events.store import make_run_event_store

        await init_engine("memory")  # no engine created
        try:
            config = MagicMock()
            config.backend = "db"
            store = make_run_event_store(config)
            assert type(store).__name__ == "MemoryRunEventStore"
        finally:
            # Reset engine state regardless of the assertion outcome.
            await close_engine()

    @pytest.mark.anyio
    async def test_jsonl_backend(self):
        """``backend="jsonl"`` yields the JSONL store."""
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "jsonl"
        store = make_run_event_store(config)
        assert type(store).__name__ == "JsonlRunEventStore"

    @pytest.mark.anyio
    async def test_unknown_backend_raises(self):
        """An unrecognised backend name raises ``ValueError`` mentioning "Unknown"."""
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "redis"
        with pytest.raises(ValueError, match="Unknown"):
            make_run_event_store(config)
|
||||
|
||||
|
||||
# -- JSONL-specific tests --
|
||||
|
||||
|
||||
class TestJsonlRunEventStore:
    """Behaviour specific to the JSONL-file-backed run event store."""

    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        """A single put is assigned seq 1 and shows up in list_messages."""
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
        assert record["seq"] == 1

        listed = await store.list_messages("t1")
        assert len(listed) == 1

    @pytest.mark.anyio
    async def test_file_at_correct_path(self, tmp_path):
        """Events land in ``<base>/threads/<thread_id>/runs/<run_id>.jsonl``."""
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        root = tmp_path / "jsonl"
        store = JsonlRunEventStore(base_dir=root)
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")

        expected_file = root / "threads" / "t1" / "runs" / "r1.jsonl"
        assert expected_file.exists()

    @pytest.mark.anyio
    async def test_cross_run_messages(self, tmp_path):
        """Messages from different runs of one thread are listed together, seq-ordered."""
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        for run_id in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message")

        listed = await store.list_messages("t1")
        assert len(listed) == 2
        assert [entry["seq"] for entry in listed] == [1, 2]

    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        """Deleting one run removes its file and leaves the other run's message."""
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        root = tmp_path / "jsonl"
        store = JsonlRunEventStore(base_dir=root)
        for run_id in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message")

        removed = await store.delete_by_run("t1", "r2")
        assert removed == 1

        deleted_file = root / "threads" / "t1" / "runs" / "r2.jsonl"
        assert not deleted_file.exists()
        assert await store.count_messages("t1") == 1
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
def base_store():
    """Provide a fresh, empty in-memory run event store for each test."""
    fresh_store = MemoryRunEventStore()
    return fresh_store
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_by_run_default_returns_all(base_store):
    """Without pagination arguments, every message of the requested run is returned."""
    store = base_store
    alternating_types = ("human_message", "ai_message")

    for idx in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type=alternating_types[idx % 2],
            category="message",
            content=f"msg-a-{idx}",
        )
    for idx in range(3):
        await store.put(
            thread_id="t1",
            run_id="run-b",
            event_type="human_message",
            category="message",
            content=f"msg-b-{idx}",
        )
    # One trace-category event on run-a; the assertions below expect it to be
    # absent from the message listing.
    await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")

    listed = await store.list_messages_by_run("t1", "run-a")

    assert len(listed) == 7
    for entry in listed:
        assert entry["category"] == "message"
    for entry in listed:
        assert entry["run_id"] == "run-a"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_by_run_with_limit(base_store):
    """``limit`` caps the number of rows, and the rows come back in ascending seq order."""
    store = base_store
    alternating_types = ("human_message", "ai_message")

    for idx in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type=alternating_types[idx % 2],
            category="message",
            content=f"msg-a-{idx}",
        )

    page = await store.list_messages_by_run("t1", "run-a", limit=3)

    assert len(page) == 3
    seq_values = [entry["seq"] for entry in page]
    assert seq_values == sorted(seq_values)
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_by_run_after_seq(base_store):
    """``after_seq`` returns only messages whose seq is strictly greater than the cursor."""
    store = base_store
    alternating_types = ("human_message", "ai_message")

    for idx in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type=alternating_types[idx % 2],
            category="message",
            content=f"msg-a-{idx}",
        )

    everything = await store.list_messages_by_run("t1", "run-a")
    # Cursor on the third of seven messages, so four messages follow it.
    cursor_seq = everything[2]["seq"]

    page = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)

    for entry in page:
        assert entry["seq"] > cursor_seq
    assert len(page) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_by_run_before_seq(base_store):
    """``before_seq`` returns only messages whose seq is strictly less than the cursor."""
    store = base_store
    alternating_types = ("human_message", "ai_message")

    for idx in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type=alternating_types[idx % 2],
            category="message",
            content=f"msg-a-{idx}",
        )

    everything = await store.list_messages_by_run("t1", "run-a")
    # Cursor on the fifth of seven messages, so four messages precede it.
    cursor_seq = everything[4]["seq"]

    page = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)

    for entry in page:
        assert entry["seq"] < cursor_seq
    assert len(page) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_by_run_does_not_include_other_run(base_store):
    """Listing one run never returns messages written under a sibling run of the same thread."""
    store = base_store

    # Seven messages for run-a first, then three for run-b, all in thread t1.
    for run_id, tag, count in (("run-a", "a", 7), ("run-b", "b", 3)):
        for idx in range(count):
            await store.put(
                thread_id="t1",
                run_id=run_id,
                event_type="human_message",
                category="message",
                content=f"msg-{tag}-{idx}",
            )

    listed = await store.list_messages_by_run("t1", "run-b")

    assert len(listed) == 3
    for entry in listed:
        assert entry["run_id"] == "run-b"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_by_run_empty_run(base_store):
    """Asking for a run that was never written yields an empty list."""
    listed = await base_store.list_messages_by_run("t1", "nonexistent")
    assert listed == []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,196 @@
|
||||
"""Tests for RunRepository (SQLAlchemy-backed RunStore).
|
||||
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
    """Initialise a SQLite engine under ``tmp_path`` and return a RunRepository bound to it."""
    from deerflow.persistence.engine import get_session_factory, init_engine

    db_file = tmp_path / "test.db"
    await init_engine("sqlite", url=f"sqlite+aiosqlite:///{db_file}", sqlite_dir=str(tmp_path))
    session_factory = get_session_factory()
    return RunRepository(session_factory)
|
||||
|
||||
|
||||
async def _cleanup():
    """Close the persistence engine that ``_make_repo`` initialised."""
    from deerflow.persistence import engine as _engine

    await _engine.close_engine()
|
||||
|
||||
|
||||
class TestRunRepository:
    """CRUD tests for RunRepository against a temporary SQLite database.

    Every test builds its repository with ``_make_repo`` and releases the
    engine with ``_cleanup`` in a ``finally`` block, so a failing assertion
    cannot leave the engine open for the tests that follow.
    """

    @pytest.mark.anyio
    async def test_put_and_get(self, tmp_path):
        """A stored run can be read back with its id, thread and status."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", status="pending")
            row = await repo.get("r1")
            assert row is not None
            assert row["run_id"] == "r1"
            assert row["thread_id"] == "t1"
            assert row["status"] == "pending"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, tmp_path):
        """Looking up an unknown run id returns ``None``."""
        repo = await _make_repo(tmp_path)
        try:
            assert await repo.get("nope") is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_status(self, tmp_path):
        """update_status changes the stored status."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.update_status("r1", "running")
            row = await repo.get("r1")
            assert row["status"] == "running"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_status_with_error(self, tmp_path):
        """update_status stores the error text alongside the status."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.update_status("r1", "error", error="boom")
            row = await repo.get("r1")
            assert row["status"] == "error"
            assert row["error"] == "boom"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        """list_by_thread returns only runs of the requested thread."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.put("r2", thread_id="t1")
            await repo.put("r3", thread_id="t2")
            rows = await repo.list_by_thread("t1")
            assert len(rows) == 2
            assert all(r["thread_id"] == "t1" for r in rows)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, tmp_path):
        """Passing ``user_id`` restricts the listing to that user's runs."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", user_id="alice")
            await repo.put("r2", thread_id="t1", user_id="bob")
            rows = await repo.list_by_thread("t1", user_id="alice")
            assert len(rows) == 1
            assert rows[0]["user_id"] == "alice"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        """A deleted run can no longer be fetched."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.delete("r1")
            assert await repo.get("r1") is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, tmp_path):
        """Deleting an unknown run id does not raise."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.delete("nope")  # should not raise
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_pending(self, tmp_path):
        """list_pending returns pending runs across threads and nothing else."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", status="pending")
            await repo.put("r2", thread_id="t1", status="running")
            await repo.put("r3", thread_id="t2", status="pending")
            pending = await repo.list_pending()
            assert len(pending) == 2
            assert all(r["status"] == "pending" for r in pending)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_run_completion(self, tmp_path):
        """update_run_completion persists status, token counters and message summaries."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", status="running")
            await repo.update_run_completion(
                "r1",
                status="success",
                total_input_tokens=100,
                total_output_tokens=50,
                total_tokens=150,
                llm_call_count=2,
                lead_agent_tokens=120,
                subagent_tokens=20,
                middleware_tokens=10,
                message_count=3,
                last_ai_message="The answer is 42",
                first_human_message="What is the meaning?",
            )
            row = await repo.get("r1")
            assert row["status"] == "success"
            assert row["total_tokens"] == 150
            assert row["llm_call_count"] == 2
            assert row["lead_agent_tokens"] == 120
            assert row["message_count"] == 3
            assert row["last_ai_message"] == "The answer is 42"
            assert row["first_human_message"] == "What is the meaning?"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_metadata_preserved(self, tmp_path):
        """A metadata dict round-trips unchanged."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", metadata={"key": "value"})
            row = await repo.get("r1")
            assert row["metadata"] == {"key": "value"}
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_kwargs_with_non_serializable(self, tmp_path):
        """kwargs containing non-JSON-serializable objects should be safely handled."""
        repo = await _make_repo(tmp_path)
        try:

            class Dummy:
                pass

            await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()})
            row = await repo.get("r1")
            # Only the key's presence is checked; the stored form of the value
            # is left to the repository.
            assert "obj" in row["kwargs"]
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_run_completion_preserves_existing_fields(self, tmp_path):
        """update_run_completion does not overwrite thread_id or assistant_id."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", assistant_id="agent1", status="running")
            await repo.update_run_completion("r1", status="success", total_tokens=100)
            row = await repo.get("r1")
            assert row["thread_id"] == "t1"
            assert row["assistant_id"] == "agent1"
            assert row["total_tokens"] == 100
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_ordered_desc(self, tmp_path):
        """list_by_thread returns newest first."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", created_at="2024-01-01T00:00:00+00:00")
            await repo.put("r2", thread_id="t1", created_at="2024-01-02T00:00:00+00:00")
            rows = await repo.list_by_thread("t1")
            assert rows[0]["run_id"] == "r2"
            assert rows[1]["run_id"] == "r1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_limit(self, tmp_path):
        """``limit`` caps the number of rows returned by list_by_thread."""
        repo = await _make_repo(tmp_path)
        try:
            for i in range(5):
                await repo.put(f"r{i}", thread_id="t1")
            rows = await repo.list_by_thread("t1", limit=2)
            assert len(rows) == 2
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_owner_none_returns_all(self, tmp_path):
        """``user_id=None`` disables the owner filter and returns every run of the thread."""
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", user_id="alice")
            await repo.put("r2", thread_id="t1", user_id="bob")
            rows = await repo.list_by_thread("t1", user_id=None)
            assert len(rows) == 2
        finally:
            await _cleanup()
|
||||
@@ -0,0 +1,243 @@
|
||||
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(run_store=None, event_store=None, feedback_repo=None):
    """Create a FastAPI test app with stub auth, the runs router, and optional mocked state.

    Each argument that is not ``None`` is stored on ``app.state`` under
    ``run_store``, ``run_event_store`` and ``feedback_repo`` respectively;
    ``None`` arguments leave the corresponding attribute unset.
    """
    app = make_authed_test_app()
    app.include_router(runs.router)

    state_overrides = {
        "run_store": run_store,
        "run_event_store": event_store,
        "feedback_repo": feedback_repo,
    }
    for attr_name, dependency in state_overrides.items():
        if dependency is None:
            continue
        setattr(app.state, attr_name, dependency)

    return app
|
||||
|
||||
|
||||
def _make_run_store(run_record: dict | None):
|
||||
"""Return an AsyncMock run store whose get() returns run_record."""
|
||||
store = MagicMock()
|
||||
store.get = AsyncMock(return_value=run_record)
|
||||
return store
|
||||
|
||||
|
||||
def _make_event_store(rows: list[dict]):
|
||||
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
|
||||
store = MagicMock()
|
||||
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||
return store
|
||||
|
||||
|
||||
def _make_message(seq: int) -> dict:
|
||||
return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_messages_returns_envelope():
    """GET /api/runs/{run_id}/messages answers with a ``{data: [...], has_more: bool}`` envelope."""
    stored_rows = [_make_message(seq) for seq in range(1, 4)]
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-1", "thread_id": "thread-1"}),
        event_store=_make_event_store(stored_rows),
    )

    with TestClient(app) as client:
        response = client.get("/api/runs/run-1/messages")

    assert response.status_code == 200
    payload = response.json()
    for key in ("data", "has_more"):
        assert key in payload
    assert payload["has_more"] is False
    assert len(payload["data"]) == 3
|
||||
|
||||
|
||||
def test_run_messages_404_when_run_not_found():
    """A run id unknown to the run store produces a 404 whose detail names the run."""
    missing_run_store = _make_run_store(None)
    empty_event_store = _make_event_store([])
    app = _make_app(run_store=missing_run_store, event_store=empty_event_store)

    with TestClient(app) as client:
        response = client.get("/api/runs/missing-run/messages")

    assert response.status_code == 404
    detail = response.json()["detail"]
    assert "missing-run" in detail
|
||||
|
||||
|
||||
def test_run_messages_has_more_true_when_extra_row_returned():
    """When the event store hands back limit+1 rows, has_more is True and data is trimmed."""
    default_limit = 50
    # One row more than the default page size.
    stored_rows = [_make_message(seq) for seq in range(1, default_limit + 2)]
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-2", "thread_id": "thread-2"}),
        event_store=_make_event_store(stored_rows),
    )

    with TestClient(app) as client:
        response = client.get("/api/runs/run-2/messages")

    assert response.status_code == 200
    payload = response.json()
    assert payload["has_more"] is True
    assert len(payload["data"]) == 50  # trimmed to limit
|
||||
|
||||
|
||||
def test_run_messages_passes_after_seq_to_event_store():
    """The ``after_seq`` query parameter reaches event_store.list_messages_by_run unchanged."""
    event_store = _make_event_store([_make_message(10)])
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-3", "thread_id": "thread-3"}),
        event_store=event_store,
    )

    with TestClient(app) as client:
        response = client.get("/api/runs/run-3/messages?after_seq=5")

    assert response.status_code == 200
    # The route asks for one row more than the default limit of 50.
    expected_kwargs = {"limit": 51, "before_seq": None, "after_seq": 5}
    event_store.list_messages_by_run.assert_awaited_once_with("thread-3", "run-3", **expected_kwargs)
|
||||
|
||||
|
||||
def test_run_messages_respects_custom_limit():
    """A custom ``limit`` query parameter is forwarded to the event store as limit+1.

    Only the forwarding of ``limit=10`` is exercised here; no upper bound on
    ``limit`` is tested.  The store returns 5 rows, fewer than the requested
    page size, so the response must carry all of them with ``has_more`` False.
    """
    rows = [_make_message(i) for i in range(1, 6)]
    run_record = {"run_id": "run-4", "thread_id": "thread-4"}
    event_store = _make_event_store(rows)
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=event_store,
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-4/messages?limit=10")
    assert response.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-4", "run-4",
        limit=11,  # 10 + 1
        before_seq=None,
        after_seq=None,
    )
    # The envelope must reflect the short page, not just the forwarded arguments.
    body = response.json()
    assert body["has_more"] is False
    assert len(body["data"]) == 5
|
||||
|
||||
|
||||
def test_run_messages_passes_before_seq_to_event_store():
    """The ``before_seq`` query parameter reaches event_store.list_messages_by_run unchanged."""
    event_store = _make_event_store([_make_message(3)])
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-5", "thread_id": "thread-5"}),
        event_store=event_store,
    )

    with TestClient(app) as client:
        response = client.get("/api/runs/run-5/messages?before_seq=10")

    assert response.status_code == 200
    expected_kwargs = {"limit": 51, "before_seq": 10, "after_seq": None}
    event_store.list_messages_by_run.assert_awaited_once_with("thread-5", "run-5", **expected_kwargs)
|
||||
|
||||
|
||||
def test_run_messages_empty_data():
    """A run with no stored messages yields an empty ``data`` list and ``has_more`` False."""
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-6", "thread_id": "thread-6"}),
        event_store=_make_event_store([]),
    )

    with TestClient(app) as client:
        response = client.get("/api/runs/run-6/messages")

    assert response.status_code == 200
    payload = response.json()
    assert payload["data"] == []
    assert payload["has_more"] is False
|
||||
|
||||
|
||||
def _make_feedback_repo(rows: list[dict]):
|
||||
"""Return an AsyncMock feedback repo whose list_by_run() returns rows."""
|
||||
repo = MagicMock()
|
||||
repo.list_by_run = AsyncMock(return_value=rows)
|
||||
return repo
|
||||
|
||||
|
||||
def _make_feedback(run_id: str, idx: int) -> dict:
|
||||
return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestRunFeedback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunFeedback:
    """Tests for GET /api/runs/{run_id}/feedback."""

    def test_returns_list_of_feedback_dicts(self):
        """The endpoint returns the repo's feedback rows as a JSON list."""
        feedback_rows = [_make_feedback("run-fb-1", idx) for idx in range(3)]
        app = _make_app(
            run_store=_make_run_store({"run_id": "run-fb-1", "thread_id": "thread-fb-1"}),
            feedback_repo=_make_feedback_repo(feedback_rows),
        )

        with TestClient(app) as client:
            response = client.get("/api/runs/run-fb-1/feedback")

        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, list)
        assert len(payload) == 3

    def test_404_when_run_not_found(self):
        """A run id unknown to the run store produces a 404 whose detail names the run."""
        app = _make_app(
            run_store=_make_run_store(None),
            feedback_repo=_make_feedback_repo([]),
        )

        with TestClient(app) as client:
            response = client.get("/api/runs/missing-run/feedback")

        assert response.status_code == 404
        detail = response.json()["detail"]
        assert "missing-run" in detail

    def test_empty_list_when_no_feedback(self):
        """A run without feedback yields an empty JSON list."""
        app = _make_app(
            run_store=_make_run_store({"run_id": "run-fb-2", "thread_id": "thread-fb-2"}),
            feedback_repo=_make_feedback_repo([]),
        )

        with TestClient(app) as client:
            response = client.get("/api/runs/run-fb-2/feedback")

        assert response.status_code == 200
        assert response.json() == []

    def test_503_when_feedback_repo_not_configured(self):
        """The endpoint answers 503 when ``app.state.feedback_repo`` is None (no DB configured)."""
        app = _make_app(run_store=_make_run_store({"run_id": "run-fb-3", "thread_id": "thread-fb-3"}))
        # _make_app skips None arguments, so set the attribute to None by hand
        # to simulate a missing database.
        app.state.feedback_repo = None

        with TestClient(app) as client:
            response = client.get("/api/runs/run-fb-3/feedback")

        assert response.status_code == 503
|
||||
@@ -14,6 +14,10 @@ def _make_runtime(tmp_path):
|
||||
workspace.mkdir()
|
||||
uploads.mkdir()
|
||||
outputs.mkdir()
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
return SimpleNamespace(
|
||||
state={
|
||||
"sandbox": {"sandbox_id": "local"},
|
||||
@@ -23,7 +27,10 @@ def _make_runtime(tmp_path):
|
||||
"outputs_path": str(outputs),
|
||||
},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
context=DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id="thread-1",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -103,8 +110,6 @@ def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
|
||||
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
|
||||
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
@@ -324,10 +329,6 @@ def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -
|
||||
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.get_app_config",
|
||||
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
|
||||
)
|
||||
|
||||
result = glob_tool.func(
|
||||
runtime=runtime,
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.sandbox.tools import (
|
||||
VIRTUAL_PATH_PREFIX,
|
||||
_apply_cwd_prefix,
|
||||
@@ -34,6 +35,53 @@ _THREAD_DATA = {
|
||||
}
|
||||
|
||||
|
||||
def _make_app_config(
|
||||
*,
|
||||
skills_container_path: str = "/mnt/skills",
|
||||
skills_host_path: str | None = None,
|
||||
mounts=None,
|
||||
mcp_servers=None,
|
||||
tool_config_map=None,
|
||||
) -> SimpleNamespace:
|
||||
"""Build a lightweight AppConfig stand-in used by tests.
|
||||
|
||||
Only the attributes accessed by the helpers under test are populated;
|
||||
everything else is omitted to keep the fake minimal and explicit.
|
||||
"""
|
||||
skills_path = Path(skills_host_path) if skills_host_path is not None else None
|
||||
skills_cfg = SimpleNamespace(
|
||||
container_path=skills_container_path,
|
||||
get_skills_path=lambda: skills_path if skills_path is not None else Path("/nonexistent-skills-root-12345"),
|
||||
)
|
||||
sandbox_cfg = SimpleNamespace(mounts=list(mounts) if mounts else [], bash_output_max_chars=20000)
|
||||
extensions_cfg = SimpleNamespace(mcp_servers=dict(mcp_servers) if mcp_servers else {})
|
||||
tool_config_map = dict(tool_config_map or {})
|
||||
return SimpleNamespace(
|
||||
skills=skills_cfg,
|
||||
sandbox=sandbox_cfg,
|
||||
extensions=extensions_cfg,
|
||||
get_tool_config=lambda name: tool_config_map.get(name),
|
||||
)
|
||||
|
||||
|
||||
# Shared stand-in config built with all defaults, for tests that need no custom settings.
_DEFAULT_APP_CONFIG = _make_app_config()
|
||||
|
||||
|
||||
def _make_ctx(thread_id: str = "thread-1", *, app_config=_DEFAULT_APP_CONFIG, sandbox_key: str | None = None):
    """Build a ``DeerFlowContext`` for tests, optionally carrying an extra ``sandbox_key``.

    The result is a genuine ``DeerFlowContext`` instance built from
    ``app_config`` and ``thread_id``.  When ``sandbox_key`` is given, it is
    attached to that instance with ``object.__setattr__``; no subclass is
    created.

    Args:
        thread_id: Thread identifier stored on the context.
        app_config: Config object stored on the context; defaults to the shared
            module-level stand-in.
        sandbox_key: Optional extra attribute to attach; omitted when ``None``.
    """
    from deerflow.config.deer_flow_context import DeerFlowContext as _DFC

    ctx = _DFC(app_config=app_config, thread_id=thread_id)
    if sandbox_key is not None:
        # object.__setattr__ sets the attribute without going through the
        # class's own __setattr__ (presumably needed because the context is a
        # frozen dataclass; verify against DeerFlowContext).
        object.__setattr__(ctx, "sandbox_key", sandbox_key)
    return ctx
|
||||
|
||||
|
||||
# ---------- replace_virtual_path ----------
|
||||
|
||||
|
||||
@@ -85,7 +133,7 @@ def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing
|
||||
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
|
||||
"""Trailing slash on a virtual path inside a command must be preserved."""
|
||||
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
|
||||
|
||||
|
||||
@@ -94,7 +142,7 @@ def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
|
||||
|
||||
def test_mask_local_paths_in_output_hides_host_paths() -> None:
|
||||
output = "Created: /tmp/deer-flow/threads/t1/user-data/workspace/result.txt"
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA)
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
assert "/tmp/deer-flow/threads/t1/user-data" not in masked
|
||||
assert "/mnt/user-data/workspace/result.txt" in masked
|
||||
@@ -107,7 +155,7 @@ def test_mask_local_paths_in_output_hides_skills_host_paths() -> None:
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
output = "Reading: /home/user/deer-flow/skills/public/bootstrap/SKILL.md"
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA)
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
assert "/home/user/deer-flow/skills" not in masked
|
||||
assert "/mnt/skills/public/bootstrap/SKILL.md" in masked
|
||||
@@ -143,12 +191,12 @@ def test_reject_path_traversal_allows_normal_paths() -> None:
|
||||
|
||||
def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
|
||||
with pytest.raises(PermissionError, match="configured mount paths"):
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
|
||||
@@ -158,42 +206,41 @@ def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -
|
||||
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
|
||||
]
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
|
||||
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA)
|
||||
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_user_data_paths() -> None:
|
||||
# Should not raise — user-data paths are always allowed
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_user_data_write() -> None:
|
||||
# read_only=False (default) should still work for user-data paths
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=False)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_traversal_in_user_data() -> None:
|
||||
"""Path traversal via .. in user-data paths must be rejected."""
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_traversal_in_skills() -> None:
|
||||
"""Path traversal via .. in skills paths must be rejected."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_none_thread_data() -> None:
|
||||
@@ -201,7 +248,7 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
|
||||
from deerflow.sandbox.exceptions import SandboxRuntimeError
|
||||
|
||||
with pytest.raises(SandboxRuntimeError):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None)
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
# ---------- _resolve_skills_path ----------
|
||||
@@ -209,32 +256,26 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
|
||||
|
||||
def test_resolve_skills_path_resolves_correctly() -> None:
|
||||
"""Skills virtual path should resolve to host path."""
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
|
||||
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
|
||||
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
|
||||
# Force get_skills_path().exists() to be True without touching the FS
|
||||
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
|
||||
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", cfg)
|
||||
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
|
||||
|
||||
|
||||
def test_resolve_skills_path_resolves_root() -> None:
|
||||
"""Skills container root should resolve to host skills directory."""
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
resolved = _resolve_skills_path("/mnt/skills")
|
||||
assert resolved == "/home/user/deer-flow/skills"
|
||||
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
|
||||
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
|
||||
resolved = _resolve_skills_path("/mnt/skills", cfg)
|
||||
assert resolved == "/home/user/deer-flow/skills"
|
||||
|
||||
|
||||
def test_resolve_skills_path_raises_when_not_configured() -> None:
|
||||
"""Should raise FileNotFoundError when skills directory is not available."""
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=None),
|
||||
):
|
||||
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
|
||||
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
|
||||
# Default app config has no host path configured → _get_skills_host_path returns None
|
||||
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
|
||||
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
# ---------- _resolve_and_validate_user_data_path ----------
|
||||
@@ -249,7 +290,7 @@ def test_resolve_and_validate_user_data_path_resolves_correctly(tmp_path: Path)
|
||||
"uploads_path": str(tmp_path / "uploads"),
|
||||
"outputs_path": str(tmp_path / "outputs"),
|
||||
}
|
||||
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data)
|
||||
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data, _DEFAULT_APP_CONFIG)
|
||||
assert resolved == str(workspace / "hello.txt")
|
||||
|
||||
|
||||
@@ -264,7 +305,7 @@ def test_resolve_and_validate_user_data_path_blocks_traversal(tmp_path: Path) ->
|
||||
}
|
||||
# This path resolves outside the allowed roots
|
||||
with pytest.raises(PermissionError):
|
||||
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data)
|
||||
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
# ---------- replace_virtual_paths_in_command ----------
|
||||
@@ -277,7 +318,7 @@ def test_replace_virtual_paths_in_command_replaces_skills_paths() -> None:
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
|
||||
):
|
||||
cmd = "cat /mnt/skills/public/bootstrap/SKILL.md"
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/mnt/skills" not in result
|
||||
assert "/home/user/deer-flow/skills/public/bootstrap/SKILL.md" in result
|
||||
|
||||
@@ -289,7 +330,7 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/skills"),
|
||||
):
|
||||
cmd = "cat /mnt/skills/public/SKILL.md > /mnt/user-data/workspace/out.txt"
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/mnt/skills" not in result
|
||||
assert "/mnt/user-data" not in result
|
||||
assert "/home/user/skills/public/SKILL.md" in result
|
||||
@@ -301,30 +342,27 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
|
||||
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
|
||||
"""HTTP URLs must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
|
||||
validate_local_bash_command_paths(
|
||||
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> None:
|
||||
@@ -332,8 +370,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> No
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/user-data/workspace/../../etc/passwd",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
|
||||
@@ -342,21 +379,20 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/skills/../../etc/passwd",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) -> None:
|
||||
runtime = SimpleNamespace(
|
||||
state={"sandbox": {"sandbox_id": "local"}, "thread_data": _THREAD_DATA.copy()},
|
||||
context={"thread_id": "thread-1"},
|
||||
context=_make_ctx("thread-1"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.ensure_sandbox_initialized",
|
||||
lambda runtime: SimpleNamespace(execute_command=lambda command: pytest.fail("host bash should not execute")),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda: False)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
result = bash_tool.func(
|
||||
runtime=runtime,
|
||||
@@ -371,33 +407,32 @@ def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) ->
|
||||
|
||||
|
||||
def test_is_skills_path_recognises_default_prefix() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
assert _is_skills_path("/mnt/skills") is True
|
||||
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md") is True
|
||||
assert _is_skills_path("/mnt/skills-extra/foo") is False
|
||||
assert _is_skills_path("/mnt/user-data/workspace") is False
|
||||
assert _is_skills_path("/mnt/skills", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_skills_path("/mnt/skills-extra/foo", _DEFAULT_APP_CONFIG) is False
|
||||
assert _is_skills_path("/mnt/user-data/workspace", _DEFAULT_APP_CONFIG) is False
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_skills_read_only() -> None:
|
||||
"""read_file / ls should be able to access /mnt/skills paths."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
# Should not raise
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
read_only=True,
|
||||
)
|
||||
# Should not raise — default app config uses /mnt/skills as container path
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_skills_write() -> None:
|
||||
"""write_file / str_replace must NOT write to skills paths."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
read_only=False,
|
||||
)
|
||||
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=False,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_skills_path() -> None:
|
||||
@@ -405,8 +440,7 @@ def test_validate_local_bash_command_paths_allows_skills_path() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_urls() -> None:
|
||||
@@ -414,40 +448,35 @@ def test_validate_local_bash_command_paths_allows_urls() -> None:
|
||||
# HTTPS URLs
|
||||
validate_local_bash_command_paths(
|
||||
"curl -X POST https://example.com/api/v1/risk/check",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# HTTP URLs
|
||||
validate_local_bash_command_paths(
|
||||
"curl http://localhost:8080/health",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# URLs with query strings
|
||||
validate_local_bash_command_paths(
|
||||
"curl https://api.example.com/v2/search?q=test",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# FTP URLs
|
||||
validate_local_bash_command_paths(
|
||||
"curl ftp://ftp.example.com/pub/file.tar.gz",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
# URL mixed with valid virtual path
|
||||
validate_local_bash_command_paths(
|
||||
"curl https://example.com/data -o /mnt/user-data/workspace/data.json",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_file_urls() -> None:
|
||||
"""file:// URLs should be treated as unsafe and blocked."""
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_file_urls_case_insensitive() -> None:
|
||||
"""file:// URL detection should be case-insensitive."""
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -> None:
|
||||
@@ -455,35 +484,36 @@ def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths(
|
||||
"curl file:///etc/passwd -o /mnt/user-data/workspace/out.txt",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_still_blocks_other_paths() -> None:
|
||||
"""Paths outside virtual and system prefixes must still be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_skills_custom_container_path() -> None:
|
||||
"""Skills with a custom container_path in config should also work."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/custom/skills"):
|
||||
# Should not raise
|
||||
custom_config = _make_app_config(skills_container_path="/custom/skills")
|
||||
# Should not raise
|
||||
validate_local_tool_path(
|
||||
"/custom/skills/public/my-skill/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
custom_config,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
# The default /mnt/skills should not match since container path is /custom/skills
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path(
|
||||
"/custom/skills/public/my-skill/SKILL.md",
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
custom_config,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
# The default /mnt/skills should not match since container path is /custom/skills
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
validate_local_tool_path(
|
||||
"/mnt/skills/public/bootstrap/SKILL.md",
|
||||
_THREAD_DATA,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------- ACP workspace path tests ----------
|
||||
|
||||
@@ -500,6 +530,7 @@ def test_validate_local_tool_path_allows_acp_workspace_read_only() -> None:
|
||||
validate_local_tool_path(
|
||||
"/mnt/acp-workspace/hello_world.py",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=True,
|
||||
)
|
||||
|
||||
@@ -510,6 +541,7 @@ def test_validate_local_tool_path_blocks_acp_workspace_write() -> None:
|
||||
validate_local_tool_path(
|
||||
"/mnt/acp-workspace/hello_world.py",
|
||||
_THREAD_DATA,
|
||||
_DEFAULT_APP_CONFIG,
|
||||
read_only=False,
|
||||
)
|
||||
|
||||
@@ -518,8 +550,7 @@ def test_validate_local_bash_command_paths_allows_acp_workspace() -> None:
|
||||
"""bash commands referencing /mnt/acp-workspace should be allowed."""
|
||||
validate_local_bash_command_paths(
|
||||
"cp /mnt/acp-workspace/hello_world.py /mnt/user-data/outputs/hello_world.py",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -> None:
|
||||
@@ -527,8 +558,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths(
|
||||
"cat /mnt/acp-workspace/../../etc/passwd",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_resolve_acp_workspace_path_resolves_correctly(tmp_path: Path) -> None:
|
||||
@@ -570,7 +600,7 @@ def test_replace_virtual_paths_in_command_replaces_acp_workspace() -> None:
|
||||
acp_host = "/home/user/.deer-flow/acp-workspace"
|
||||
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
|
||||
cmd = "cp /mnt/acp-workspace/hello.py /mnt/user-data/outputs/hello.py"
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
assert "/mnt/acp-workspace" not in result
|
||||
assert f"{acp_host}/hello.py" in result
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/outputs/hello.py" in result
|
||||
@@ -581,7 +611,7 @@ def test_mask_local_paths_in_output_hides_acp_workspace_host_paths() -> None:
|
||||
acp_host = "/home/user/.deer-flow/acp-workspace"
|
||||
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
|
||||
output = f"Copied: {acp_host}/hello.py"
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA)
|
||||
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
assert acp_host not in masked
|
||||
assert "/mnt/acp-workspace/hello.py" in masked
|
||||
@@ -617,39 +647,37 @@ def test_apply_cwd_prefix_quotes_path_with_spaces() -> None:
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None:
|
||||
"""Bash commands referencing MCP filesystem server paths should be allowed."""
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
mock_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"filesystem": McpServerConfig(
|
||||
enabled=True,
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
|
||||
)
|
||||
}
|
||||
)
|
||||
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=mock_config):
|
||||
# Should not raise - MCP filesystem paths are allowed
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
|
||||
|
||||
# Path traversal should still be blocked
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
|
||||
|
||||
# Disabled servers should not expose paths
|
||||
disabled_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"filesystem": McpServerConfig(
|
||||
enabled=False,
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
|
||||
)
|
||||
}
|
||||
def _mcp_app_config(enabled: bool) -> AppConfig:
|
||||
return AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
extensions=ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"filesystem": McpServerConfig(
|
||||
enabled=enabled,
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
|
||||
|
||||
enabled_cfg = _mcp_app_config(True)
|
||||
# Should not raise - MCP filesystem paths are allowed
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, enabled_cfg)
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA, enabled_cfg)
|
||||
|
||||
# Path traversal should still be blocked
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA, enabled_cfg)
|
||||
|
||||
# Disabled servers should not expose paths
|
||||
disabled_cfg = _mcp_app_config(False)
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, disabled_cfg)
|
||||
|
||||
|
||||
# ---------- Custom mount path tests ----------
|
||||
@@ -667,12 +695,12 @@ def _mock_custom_mounts():
|
||||
|
||||
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
assert _is_custom_mount_path("/mnt/code-read") is True
|
||||
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
|
||||
assert _is_custom_mount_path("/mnt/data") is True
|
||||
assert _is_custom_mount_path("/mnt/data/file.txt") is True
|
||||
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
|
||||
assert _is_custom_mount_path("/mnt/other") is False
|
||||
assert _is_custom_mount_path("/mnt/code-read", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/code-read/src/main.py", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/data", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/data/file.txt", _DEFAULT_APP_CONFIG) is True
|
||||
assert _is_custom_mount_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG) is False
|
||||
assert _is_custom_mount_path("/mnt/other", _DEFAULT_APP_CONFIG) is False
|
||||
|
||||
|
||||
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
@@ -683,7 +711,7 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
|
||||
]
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
mount = _get_custom_mount_for_path("/mnt/code/file.py")
|
||||
mount = _get_custom_mount_for_path("/mnt/code/file.py", _DEFAULT_APP_CONFIG)
|
||||
assert mount is not None
|
||||
assert mount.container_path == "/mnt/code"
|
||||
|
||||
@@ -691,90 +719,72 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
|
||||
"""read_file / ls should be able to access custom mount paths."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
|
||||
"""write_file / str_replace must NOT write to read-only custom mounts."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
|
||||
"""write_file / str_replace should succeed on writable custom mounts."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
|
||||
"""Path traversal via .. in custom mount paths must be rejected."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
|
||||
"""bash commands referencing custom mount paths should be allowed."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
|
||||
"""Bash commands with traversal in custom mount paths should be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
|
||||
"""Paths not matching any custom mount should still be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
|
||||
"""_get_custom_mounts should cache after first successful load."""
|
||||
# Clear any existing cache
|
||||
if hasattr(_get_custom_mounts, "_cached"):
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
# Use real directories so host_path.exists() filtering passes
|
||||
def test_get_custom_mounts_reads_from_app_config(tmp_path) -> None:
|
||||
"""_get_custom_mounts should read directly from the supplied AppConfig."""
|
||||
dir_a = tmp_path / "code-read"
|
||||
dir_a.mkdir()
|
||||
dir_b = tmp_path / "data"
|
||||
dir_b.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
mounts = [
|
||||
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
|
||||
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
|
||||
]
|
||||
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
|
||||
mock_config = SimpleNamespace(sandbox=mock_sandbox)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=mock_config):
|
||||
result = _get_custom_mounts()
|
||||
assert len(result) == 2
|
||||
|
||||
# After caching, should return cached value even without mock
|
||||
assert hasattr(_get_custom_mounts, "_cached")
|
||||
assert len(_get_custom_mounts()) == 2
|
||||
|
||||
# Cleanup
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
cfg = _make_app_config(mounts=mounts)
|
||||
result = _get_custom_mounts(cfg)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
|
||||
def test_get_custom_mounts_filters_nonexistent_host_path(tmp_path) -> None:
|
||||
"""_get_custom_mounts should only return mounts whose host_path exists."""
|
||||
if hasattr(_get_custom_mounts, "_cached"):
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
existing_dir = tmp_path / "existing"
|
||||
existing_dir.mkdir()
|
||||
@@ -783,22 +793,16 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
|
||||
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
|
||||
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
|
||||
]
|
||||
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
|
||||
mock_config = SimpleNamespace(sandbox=mock_sandbox)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=mock_config):
|
||||
result = _get_custom_mounts()
|
||||
assert len(result) == 1
|
||||
assert result[0].container_path == "/mnt/existing"
|
||||
|
||||
# Cleanup
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
cfg = _make_app_config(mounts=mounts)
|
||||
result = _get_custom_mounts(cfg)
|
||||
assert len(result) == 1
|
||||
assert result[0].container_path == "/mnt/existing"
|
||||
|
||||
|
||||
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
|
||||
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
|
||||
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG)
|
||||
assert mount is None
|
||||
|
||||
|
||||
@@ -829,8 +833,8 @@ def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) ->
|
||||
|
||||
sandbox = SharedSandbox()
|
||||
runtimes = [
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
]
|
||||
failures: list[BaseException] = []
|
||||
|
||||
@@ -905,14 +909,14 @@ def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_pat
|
||||
"sandbox-b": IsolatedSandbox("sandbox-b", shared_state),
|
||||
}
|
||||
runtimes = [
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1", sandbox_key="sandbox-a"), config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-2", sandbox_key="sandbox-b"), config={}),
|
||||
]
|
||||
failures: list[BaseException] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.ensure_sandbox_initialized",
|
||||
lambda runtime: sandboxes[runtime.context["sandbox_key"]],
|
||||
lambda runtime: sandboxes[runtime.context.sandbox_key],
|
||||
)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
@@ -972,8 +976,8 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
|
||||
|
||||
sandbox = SharedSandbox()
|
||||
runtimes = [
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
|
||||
]
|
||||
failures: list[BaseException] = []
|
||||
|
||||
|
||||
@@ -29,10 +29,9 @@ async def test_scan_skill_content_passes_run_name_to_model(monkeypatch):
|
||||
@pytest.mark.anyio
|
||||
async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
|
||||
config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None))
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
|
||||
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
|
||||
result = await scan_skill_content(config, "---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
|
||||
|
||||
assert result.decision == "block"
|
||||
assert "manual review required" in result.reason
|
||||
|
||||
@@ -4,9 +4,20 @@ from types import SimpleNamespace
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
|
||||
|
||||
|
||||
def _make_context(thread_id: str, app_config: object | None = None) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=app_config if app_config is not None else AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _skill_content(name: str, description: str = "Demo skill") -> str:
|
||||
return f"---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"
|
||||
|
||||
@@ -23,18 +34,15 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
|
||||
|
||||
result = anyio.run(
|
||||
skill_manage_module.skill_manage_tool.coroutine,
|
||||
@@ -67,17 +75,14 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
|
||||
content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
|
||||
|
||||
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
|
||||
@@ -107,10 +112,8 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
|
||||
runtime = SimpleNamespace(context={}, config={"configurable": {}})
|
||||
runtime = SimpleNamespace(context=_make_context("", config), config={"configurable": {}})
|
||||
|
||||
with pytest.raises(ValueError, match="built-in skill"):
|
||||
anyio.run(
|
||||
@@ -131,17 +134,15 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-sync"}, config={"configurable": {"thread_id": "thread-sync"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-sync", config), config={"configurable": {"thread_id": "thread-sync"}})
|
||||
result = skill_manage_module.skill_manage_tool.func(
|
||||
runtime=runtime,
|
||||
action="create",
|
||||
@@ -159,17 +160,14 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
|
||||
|
||||
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
|
||||
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
|
||||
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
|
||||
|
||||
with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):
|
||||
|
||||
@@ -7,6 +7,9 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import skills as skills_router
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.skills.manager import get_skill_history_file
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
@@ -44,17 +47,16 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -94,14 +96,12 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
get_skill_history_file("demo-skill").write_text(
|
||||
get_skill_history_file("demo-skill", config).write_text(
|
||||
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
@@ -114,6 +114,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -136,17 +137,16 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
|
||||
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
|
||||
refresh_calls = []
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
@@ -238,23 +238,25 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
|
||||
enabled_state = {"value": True}
|
||||
refresh_calls = []
|
||||
|
||||
def _load_skills(*, enabled_only: bool):
|
||||
def _load_skills(*a, enabled_only: bool = False, **k):
|
||||
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
|
||||
if enabled_only and not skill.enabled:
|
||||
return []
|
||||
return [skill]
|
||||
|
||||
async def _refresh():
|
||||
async def _refresh(*a, **k):
|
||||
refresh_calls.append("refresh")
|
||||
enabled_state["value"] = False
|
||||
|
||||
_app_cfg = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
|
||||
monkeypatch.setattr("app.gateway.routers.skills.get_extensions_config", lambda: SimpleNamespace(mcp_servers={}, skills={}))
|
||||
monkeypatch.setattr("app.gateway.routers.skills.reload_extensions_config", lambda: None)
|
||||
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
|
||||
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = _app_cfg
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path:
|
||||
_write_skill(skills_root / "public" / "parent" / "child-skill", "child-skill", "Child skill")
|
||||
_write_skill(skills_root / "custom" / "team" / "helper", "team-helper", "Team helper")
|
||||
|
||||
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
|
||||
skills = load_skills(skills_path=skills_root, enabled_only=False)
|
||||
by_name = {skill.name: skill for skill in skills}
|
||||
|
||||
assert {"root-skill", "child-skill", "team-helper"} <= set(by_name)
|
||||
@@ -57,7 +57,7 @@ def test_load_skills_skips_hidden_directories(tmp_path: Path):
|
||||
"Hidden skill",
|
||||
)
|
||||
|
||||
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
|
||||
skills = load_skills(skills_path=skills_root, enabled_only=False)
|
||||
names = {skill.name for skill in skills}
|
||||
|
||||
assert "ok-skill" in names
|
||||
@@ -69,7 +69,7 @@ def test_load_skills_prefers_custom_over_public_with_same_name(tmp_path: Path):
|
||||
_write_skill(skills_root / "public" / "shared-skill", "shared-skill", "Public version")
|
||||
_write_skill(skills_root / "custom" / "shared-skill", "shared-skill", "Custom version")
|
||||
|
||||
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
|
||||
skills = load_skills(skills_path=skills_root, enabled_only=False)
|
||||
shared = next(skill for skill in skills if skill.name == "shared-skill")
|
||||
|
||||
assert shared.category == "custom"
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -331,6 +332,9 @@ async def test_concurrent_tasks_end_sentinel():
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_make_stream_bridge_defaults():
|
||||
"""make_stream_bridge() with no config yields a MemoryStreamBridge."""
|
||||
async with make_stream_bridge() as bridge:
|
||||
"""make_stream_bridge with a config lacking stream_bridge yields a MemoryStreamBridge."""
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
async with make_stream_bridge(config) as bridge:
|
||||
assert isinstance(bridge, MemoryStreamBridge)
|
||||
|
||||
@@ -21,6 +21,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
_TEST_APP_CONFIG = MagicMock(name="TestAppConfig")
|
||||
|
||||
# Module names that need to be mocked to break circular imports
|
||||
_MOCKED_MODULE_NAMES = [
|
||||
"deerflow.agents",
|
||||
@@ -203,7 +205,7 @@ class TestAsyncExecutionPath:
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
trace_id="test-trace",
|
||||
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -232,7 +234,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -259,7 +261,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -285,7 +287,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -306,7 +308,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -327,7 +329,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -348,7 +350,7 @@ class TestAsyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -384,7 +386,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -419,7 +421,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -456,7 +458,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -477,7 +479,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_aexecute") as mock_aexecute:
|
||||
@@ -511,7 +513,7 @@ class TestSyncExecutionPath:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -565,7 +567,7 @@ class TestAsyncToolSupport:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -602,7 +604,7 @@ class TestAsyncToolSupport:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -648,7 +650,7 @@ class TestThreadSafety:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id=f"thread-{task_id}",
|
||||
thread_id=f"thread-{task_id}", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -858,7 +860,7 @@ class TestCooperativeCancellation:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -898,7 +900,7 @@ class TestCooperativeCancellation:
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||
@@ -977,7 +979,7 @@ class TestCooperativeCancellation:
|
||||
config=short_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
trace_id="test-trace",
|
||||
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
|
||||
)
|
||||
|
||||
# Wrap _scheduler_pool.submit so we know when run_task finishes
|
||||
|
||||
@@ -1,29 +1,35 @@
|
||||
"""Tests for subagent availability and prompt exposure under local bash hardening."""
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.subagents import registry as registry_module
|
||||
|
||||
|
||||
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: False)
|
||||
def _config() -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
|
||||
names = registry_module.get_available_subagent_names()
|
||||
|
||||
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
names = registry_module.get_available_subagent_names(_config())
|
||||
|
||||
assert names == ["general-purpose"]
|
||||
|
||||
|
||||
def test_get_available_subagent_names_keeps_bash_when_allowed(monkeypatch) -> None:
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: True)
|
||||
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: True)
|
||||
|
||||
names = registry_module.get_available_subagent_names()
|
||||
names = registry_module.get_available_subagent_names(_config())
|
||||
|
||||
assert names == ["general-purpose", "bash"]
|
||||
|
||||
|
||||
def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch) -> None:
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose"])
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
|
||||
|
||||
section = prompt_module._build_subagent_section(3)
|
||||
section = prompt_module._build_subagent_section(3, _config())
|
||||
|
||||
# When bash is not available, it should not appear at all (aligned with Codex:
|
||||
# unavailable roles are omitted, not listed as disabled)
|
||||
@@ -34,9 +40,9 @@ def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch
|
||||
|
||||
|
||||
def test_build_subagent_section_includes_bash_when_available(monkeypatch) -> None:
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose", "bash"])
|
||||
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose", "bash"])
|
||||
|
||||
section = prompt_module._build_subagent_section(3)
|
||||
section = prompt_module._build_subagent_section(3, _config())
|
||||
|
||||
assert "For command execution (git, build, test, deploy operations)" in section
|
||||
assert 'bash("npm test")' in section
|
||||
|
||||
@@ -1,596 +0,0 @@
|
||||
"""Tests for subagent per-agent skill configuration and custom subagent types.
|
||||
|
||||
Covers:
|
||||
- SubagentConfig.skills field
|
||||
- SubagentOverrideConfig.skills field
|
||||
- CustomSubagentConfig model validation
|
||||
- SubagentsAppConfig.custom_agents and get_skills_for()
|
||||
- Registry: custom agent lookup, skills override, merged available names
|
||||
- Skills filter passthrough in task_tool config assembly
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.subagents_config import (
|
||||
CustomSubagentConfig,
|
||||
SubagentOverrideConfig,
|
||||
SubagentsAppConfig,
|
||||
get_subagents_app_config,
|
||||
load_subagents_config_from_dict,
|
||||
)
|
||||
from deerflow.subagents.config import SubagentConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reset_subagents_config(**kwargs) -> None:
|
||||
"""Reset global subagents config to a known state."""
|
||||
load_subagents_config_from_dict(kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentConfig.skills field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentConfigSkills:
    """Checks on the ``skills`` field of ``SubagentConfig``."""

    def test_default_skills_is_none(self):
        """Leaving ``skills`` out yields ``None``."""
        required = {"name": "test", "description": "test", "system_prompt": "test"}
        subagent = SubagentConfig(**required)
        assert subagent.skills is None

    def test_skills_whitelist(self):
        """An explicit list of skill names is stored as given."""
        required = {"name": "test", "description": "test", "system_prompt": "test"}
        subagent = SubagentConfig(**required, skills=["data-analysis", "visualization"])
        assert subagent.skills == ["data-analysis", "visualization"]

    def test_skills_empty_list_means_no_skills(self):
        """An empty list stays an empty list (it is not turned into ``None``)."""
        required = {"name": "test", "description": "test", "system_prompt": "test"}
        subagent = SubagentConfig(**required, skills=[])
        assert subagent.skills == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentOverrideConfig.skills field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentOverrideConfigSkills:
    """Checks on the ``skills`` field of ``SubagentOverrideConfig``."""

    def test_default_skills_is_none(self):
        """An override built with no arguments has ``skills`` set to ``None``."""
        assert SubagentOverrideConfig().skills is None

    def test_skills_whitelist(self):
        """A list of skill names is kept in the order supplied."""
        entry = SubagentOverrideConfig(skills=["web-search", "data-analysis"])
        assert entry.skills == ["web-search", "data-analysis"]

    def test_skills_empty_list(self):
        """An empty skill list is preserved as ``[]``."""
        entry = SubagentOverrideConfig(skills=[])
        assert entry.skills == []

    def test_skills_coexists_with_other_fields(self):
        """``skills`` can be set together with ``timeout_seconds`` and ``model``."""
        fields = {"timeout_seconds": 300, "model": "gpt-5", "skills": ["my-skill"]}
        entry = SubagentOverrideConfig(**fields)
        assert entry.timeout_seconds == 300
        assert entry.model == "gpt-5"
        assert entry.skills == ["my-skill"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CustomSubagentConfig model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCustomSubagentConfig:
    """Field defaults and validation of the ``CustomSubagentConfig`` model."""

    def test_minimal_valid(self):
        """With only description and system prompt given, the other fields take their defaults."""
        agent = CustomSubagentConfig(description="A test agent", system_prompt="You are a test agent.")
        assert agent.description == "A test agent"
        assert agent.system_prompt == "You are a test agent."
        assert agent.tools is None
        assert agent.disallowed_tools == ["task", "ask_clarification", "present_files"]
        assert agent.skills is None
        assert agent.model == "inherit"
        assert agent.max_turns == 50
        assert agent.timeout_seconds == 900

    def test_full_configuration(self):
        """Every explicitly supplied field is readable back from the model."""
        fields = {
            "description": "Data analysis specialist",
            "system_prompt": "You are a data analysis subagent.",
            "tools": ["bash", "read_file", "write_file"],
            "disallowed_tools": ["task"],
            "skills": ["data-analysis", "visualization"],
            "model": "qwen3:32b",
            "max_turns": 80,
            "timeout_seconds": 600,
        }
        agent = CustomSubagentConfig(**fields)
        assert agent.tools == ["bash", "read_file", "write_file"]
        assert agent.skills == ["data-analysis", "visualization"]
        assert agent.model == "qwen3:32b"
        assert agent.max_turns == 80
        assert agent.timeout_seconds == 600

    def test_skills_empty_list_no_skills(self):
        """An empty skill list is stored as ``[]``."""
        agent = CustomSubagentConfig(description="test", system_prompt="test", skills=[])
        assert agent.skills == []

    def test_rejects_zero_max_turns(self):
        """``max_turns=0`` is refused with a ``ValueError``."""
        invalid = {"description": "test", "system_prompt": "test", "max_turns": 0}
        with pytest.raises(ValueError):
            CustomSubagentConfig(**invalid)

    def test_rejects_zero_timeout(self):
        """``timeout_seconds=0`` is refused with a ``ValueError``."""
        invalid = {"description": "test", "system_prompt": "test", "timeout_seconds": 0}
        with pytest.raises(ValueError):
            CustomSubagentConfig(**invalid)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig.custom_agents and get_skills_for()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentsAppConfigCustomAgents:
    """Checks on the ``custom_agents`` mapping of ``SubagentsAppConfig``."""

    def test_default_custom_agents_empty(self):
        """A default-constructed config has no custom agents."""
        assert SubagentsAppConfig().custom_agents == {}

    def test_custom_agents_loaded(self):
        """A single custom agent is stored under its key with its skills."""
        analysis = CustomSubagentConfig(
            description="Analysis agent",
            system_prompt="You analyze data.",
            skills=["data-analysis"],
        )
        app = SubagentsAppConfig(custom_agents={"analysis": analysis})
        assert "analysis" in app.custom_agents
        assert app.custom_agents["analysis"].skills == ["data-analysis"]

    def test_multiple_custom_agents(self):
        """Two custom agents under different keys are both kept."""
        analysis = CustomSubagentConfig(
            description="Analysis",
            system_prompt="analyze",
            skills=["data-analysis"],
        )
        researcher = CustomSubagentConfig(
            description="Research",
            system_prompt="research",
            skills=["web-search"],
        )
        app = SubagentsAppConfig(custom_agents={"analysis": analysis, "researcher": researcher})
        assert len(app.custom_agents) == 2
|
||||
|
||||
|
||||
class TestGetSkillsFor:
    """Behaviour of ``SubagentsAppConfig.get_skills_for`` for per-agent overrides."""

    def test_returns_none_when_no_override(self):
        """With no overrides configured, every agent name resolves to ``None``."""
        app = SubagentsAppConfig()
        for agent_name in ("general-purpose", "unknown"):
            assert app.get_skills_for(agent_name) is None

    def test_returns_skills_whitelist(self):
        """A configured skill list is returned for the agent it belongs to."""
        overrides = {"general-purpose": SubagentOverrideConfig(skills=["web-search", "coding"])}
        app = SubagentsAppConfig(agents=overrides)
        assert app.get_skills_for("general-purpose") == ["web-search", "coding"]

    def test_returns_empty_list_for_no_skills(self):
        """An empty skill list in the override comes back as ``[]``."""
        overrides = {"bash": SubagentOverrideConfig(skills=[])}
        app = SubagentsAppConfig(agents=overrides)
        assert app.get_skills_for("bash") == []

    def test_returns_none_for_unrelated_agent(self):
        """An override on one agent does not affect the lookup for another."""
        overrides = {"bash": SubagentOverrideConfig(skills=["web-search"])}
        app = SubagentsAppConfig(agents=overrides)
        assert app.get_skills_for("general-purpose") is None

    def test_returns_none_when_skills_not_set(self):
        """An override that sets only ``timeout_seconds`` yields ``None`` for skills."""
        overrides = {"bash": SubagentOverrideConfig(timeout_seconds=300)}
        app = SubagentsAppConfig(agents=overrides)
        assert app.get_skills_for("bash") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_subagents_config_from_dict with skills and custom_agents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadSubagentsConfigWithSkills:
    """Loading skills overrides and custom agents via ``load_subagents_config_from_dict``."""

    def teardown_method(self):
        """Reload the config with no settings after each test."""
        _reset_subagents_config()

    def test_load_with_skills_override(self):
        """A per-agent skill list in the dict is visible through ``get_skills_for``."""
        payload = {
            "timeout_seconds": 900,
            "agents": {"general-purpose": {"skills": ["web-search", "data-analysis"]}},
        }
        load_subagents_config_from_dict(payload)
        loaded = get_subagents_app_config()
        assert loaded.get_skills_for("general-purpose") == ["web-search", "data-analysis"]

    def test_load_with_empty_skills(self):
        """An empty per-agent skill list is loaded as ``[]``."""
        payload = {
            "timeout_seconds": 900,
            "agents": {"bash": {"skills": []}},
        }
        load_subagents_config_from_dict(payload)
        loaded = get_subagents_app_config()
        assert loaded.get_skills_for("bash") == []

    def test_load_with_custom_agents(self):
        """A custom agent definition is loaded with its skills, tools and limits."""
        analysis_spec = {
            "description": "Data analysis specialist",
            "system_prompt": "You are a data analysis subagent.",
            "skills": ["data-analysis", "visualization"],
            "tools": ["bash", "read_file"],
            "max_turns": 80,
            "timeout_seconds": 600,
        }
        payload = {"timeout_seconds": 900, "custom_agents": {"analysis": analysis_spec}}
        load_subagents_config_from_dict(payload)
        loaded = get_subagents_app_config()
        assert "analysis" in loaded.custom_agents
        analysis = loaded.custom_agents["analysis"]
        assert analysis.skills == ["data-analysis", "visualization"]
        assert analysis.tools == ["bash", "read_file"]
        assert analysis.max_turns == 80
        assert analysis.timeout_seconds == 600

    def test_load_with_both_overrides_and_custom(self):
        """Overrides and custom agents supplied in one dict are both loaded."""
        payload = {
            "timeout_seconds": 900,
            "agents": {"general-purpose": {"skills": ["web-search"]}},
            "custom_agents": {
                "analysis": {
                    "description": "Analysis",
                    "system_prompt": "Analyze.",
                    "skills": ["data-analysis"],
                },
            },
        }
        load_subagents_config_from_dict(payload)
        loaded = get_subagents_app_config()
        assert loaded.get_skills_for("general-purpose") == ["web-search"]
        assert loaded.custom_agents["analysis"].skills == ["data-analysis"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: custom agent lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryCustomAgentLookup:
    """Lookup of custom agents through ``registry.get_subagent_config``."""

    def teardown_method(self):
        """Reload the config with no settings after each test."""
        _reset_subagents_config()

    def test_custom_agent_found(self):
        """A configured custom agent is returned with its own settings."""
        from deerflow.subagents.registry import get_subagent_config

        analysis_spec = {
            "description": "Data analysis specialist",
            "system_prompt": "You are a data analysis subagent.",
            "skills": ["data-analysis"],
            "tools": ["bash", "read_file"],
            "max_turns": 80,
            "timeout_seconds": 600,
        }
        load_subagents_config_from_dict({"custom_agents": {"analysis": analysis_spec}})
        resolved = get_subagent_config("analysis")
        assert resolved is not None
        assert resolved.name == "analysis"
        assert resolved.skills == ["data-analysis"]
        assert resolved.tools == ["bash", "read_file"]
        assert resolved.max_turns == 80
        assert resolved.timeout_seconds == 600
        assert resolved.model == "inherit"

    def test_custom_agent_not_found(self):
        """An unknown name resolves to ``None`` after the config is reset."""
        from deerflow.subagents.registry import get_subagent_config

        _reset_subagents_config()
        resolved = get_subagent_config("nonexistent")
        assert resolved is None

    def test_builtin_takes_priority_over_custom(self):
        """If a custom agent has the same name as a builtin, builtin wins."""
        from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
        from deerflow.subagents.registry import get_subagent_config

        shadowing_spec = {
            "description": "Custom override attempt",
            "system_prompt": "Should not be used",
        }
        load_subagents_config_from_dict({"custom_agents": {"general-purpose": shadowing_spec}})
        resolved = get_subagent_config("general-purpose")
        # The description must be the builtin one, not the custom attempt.
        builtin_description = BUILTIN_SUBAGENTS["general-purpose"].description
        assert resolved.description == builtin_description

    def test_custom_agent_with_override(self):
        """Per-agent overrides also apply to custom agents."""
        from deerflow.subagents.registry import get_subagent_config

        payload = {
            "custom_agents": {
                "analysis": {
                    "description": "Analysis",
                    "system_prompt": "Analyze.",
                    "timeout_seconds": 600,
                },
            },
            "agents": {
                "analysis": {"timeout_seconds": 300, "skills": ["overridden-skill"]},
            },
        }
        load_subagents_config_from_dict(payload)
        resolved = get_subagent_config("analysis")
        assert resolved is not None
        # Both values come from the "agents" override, not the custom definition.
        assert resolved.timeout_seconds == 300
        assert resolved.skills == ["overridden-skill"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: skills override on builtin agents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistrySkillsOverride:
    """Skill overrides applied to builtin agents by ``registry.get_subagent_config``."""

    def teardown_method(self):
        """Reload the config with no settings after each test."""
        _reset_subagents_config()

    def test_skills_override_applied_to_builtin(self):
        """A skill list configured for a builtin agent shows up on the resolved config."""
        from deerflow.subagents.registry import get_subagent_config

        payload = {"agents": {"general-purpose": {"skills": ["web-search", "data-analysis"]}}}
        load_subagents_config_from_dict(payload)
        resolved = get_subagent_config("general-purpose")
        assert resolved.skills == ["web-search", "data-analysis"]

    def test_empty_skills_override(self):
        """An empty skill list override resolves to ``[]``."""
        from deerflow.subagents.registry import get_subagent_config

        payload = {"agents": {"bash": {"skills": []}}}
        load_subagents_config_from_dict(payload)
        resolved = get_subagent_config("bash")
        assert resolved.skills == []

    def test_no_skills_override_keeps_default(self):
        """Without an override the builtin keeps ``skills`` as ``None``."""
        from deerflow.subagents.registry import get_subagent_config

        _reset_subagents_config()
        resolved = get_subagent_config("general-purpose")
        # Default: inherit all
        assert resolved.skills is None

    def test_skills_override_does_not_mutate_builtin(self):
        """Resolving with an override leaves the ``BUILTIN_SUBAGENTS`` entry untouched."""
        from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
        from deerflow.subagents.registry import get_subagent_config

        payload = {"agents": {"general-purpose": {"skills": ["web-search"]}}}
        load_subagents_config_from_dict(payload)
        # The result is irrelevant; the call is made to see whether it alters the builtin.
        get_subagent_config("general-purpose")
        builtin = BUILTIN_SUBAGENTS["general-purpose"]
        assert builtin.skills is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: get_available_subagent_names merges custom types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryAvailableNames:
    """Names reported by ``registry.get_subagent_names`` for builtin and custom agents."""

    def teardown_method(self):
        """Reload the config with no settings after each test."""
        _reset_subagents_config()

    def test_includes_builtin_names(self):
        """Both builtin names are listed after the config is reset."""
        from deerflow.subagents.registry import get_subagent_names

        _reset_subagents_config()
        listed = get_subagent_names()
        for builtin_name in ("general-purpose", "bash"):
            assert builtin_name in listed

    def test_includes_custom_names(self):
        """Custom agent names are listed alongside the builtin ones."""
        from deerflow.subagents.registry import get_subagent_names

        custom_specs = {
            "analysis": {"description": "Analysis", "system_prompt": "Analyze."},
            "researcher": {"description": "Research", "system_prompt": "Research."},
        }
        load_subagents_config_from_dict({"custom_agents": custom_specs})
        listed = get_subagent_names()
        for expected_name in ("general-purpose", "bash", "analysis", "researcher"):
            assert expected_name in listed

    def test_no_duplicates_when_custom_name_matches_builtin(self):
        """A custom agent sharing a builtin's name does not produce a second entry."""
        from deerflow.subagents.registry import get_subagent_names

        clash_spec = {"description": "Duplicate name", "system_prompt": "test"}
        load_subagents_config_from_dict({"custom_agents": {"general-purpose": clash_spec}})
        listed = get_subagent_names()
        occurrences = listed.count("general-purpose")
        assert occurrences == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry: list_subagents includes custom agents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryListSubagentsWithCustom:
    """Custom agents as they appear in ``registry.list_subagents``."""

    def teardown_method(self):
        """Reload the config with no settings after each test."""
        _reset_subagents_config()

    def test_list_includes_custom_agents(self):
        """The listing contains both builtin names and the custom agent."""
        from deerflow.subagents.registry import list_subagents

        analysis_spec = {
            "description": "Analysis",
            "system_prompt": "Analyze.",
            "skills": ["data-analysis"],
        }
        load_subagents_config_from_dict({"custom_agents": {"analysis": analysis_spec}})
        listed_names = set()
        for entry in list_subagents():
            listed_names.add(entry.name)
        assert "general-purpose" in listed_names
        assert "bash" in listed_names
        assert "analysis" in listed_names

    def test_list_custom_agent_has_correct_skills(self):
        """The listed custom agent carries the skills it was configured with."""
        from deerflow.subagents.registry import list_subagents

        analysis_spec = {
            "description": "Analysis",
            "system_prompt": "Analyze.",
            "skills": ["data-analysis", "visualization"],
        }
        load_subagents_config_from_dict({"custom_agents": {"analysis": analysis_spec}})
        # Index the listing by name; a later entry with the same name overwrites an earlier one.
        entries = {}
        for entry in list_subagents():
            entries[entry.name] = entry
        assert entries["analysis"].skills == ["data-analysis", "visualization"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skills filter passthrough: verify config.skills is used in task_tool assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillsFilterPassthrough:
    """Test that SubagentConfig.skills is correctly passed to get_skills_prompt_section.

    Each test rebuilds the ``skills`` -> ``available_skills`` conversion inline
    and checks its result; no prompt-building function is called here.
    """

    def test_none_skills_passes_none_to_prompt(self):
        """When config.skills is None, available_skills=None should be passed (inherit all)."""
        subagent = SubagentConfig(name="test", description="test", system_prompt="test", skills=None)
        # set(None) would raise, so None has to be checked before converting.
        if subagent.skills is not None:
            available = set(subagent.skills)
        else:
            available = None
        assert available is None

    def test_empty_skills_passes_empty_set(self):
        """When config.skills is [], available_skills=set() should be passed (no skills)."""
        subagent = SubagentConfig(name="test", description="test", system_prompt="test", skills=[])
        if subagent.skills is not None:
            available = set(subagent.skills)
        else:
            available = None
        assert available == set()

    def test_skills_whitelist_passes_correct_set(self):
        """When config.skills has values, those should be passed as available_skills."""
        subagent = SubagentConfig(
            name="test",
            description="test",
            system_prompt="test",
            skills=["data-analysis", "web-search"],
        )
        if subagent.skills is not None:
            available = set(subagent.skills)
        else:
            available = None
        assert available == {"data-analysis", "web-search"}
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
|
||||
- get_timeout_for() / get_max_turns_for() resolution logic
|
||||
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
|
||||
- AppConfig.subagents field access
|
||||
- registry.get_subagent_config() applies config overrides
|
||||
- registry.list_subagents() applies overrides for all agents
|
||||
- Polling timeout calculation in task_tool is consistent with config
|
||||
@@ -11,32 +11,28 @@ Covers:
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.subagents_config import (
|
||||
SubagentOverrideConfig,
|
||||
SubagentsAppConfig,
|
||||
get_subagents_app_config,
|
||||
load_subagents_config_from_dict,
|
||||
)
|
||||
from deerflow.subagents.config import SubagentConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reset_subagents_config(
|
||||
def _make_config(
|
||||
timeout_seconds: int = 900,
|
||||
*,
|
||||
max_turns: int | None = None,
|
||||
agents: dict | None = None,
|
||||
) -> None:
|
||||
"""Reset global subagents config to a known state."""
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"max_turns": max_turns,
|
||||
"agents": agents or {},
|
||||
}
|
||||
) -> AppConfig:
|
||||
"""Build an AppConfig with the given subagents settings."""
|
||||
return AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
subagents=SubagentsAppConfig(
|
||||
timeout_seconds=timeout_seconds,
|
||||
max_turns=max_turns,
|
||||
agents={k: SubagentOverrideConfig(**v) for k, v in (agents or {}).items()},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -50,523 +46,131 @@ class TestSubagentOverrideConfig:
|
||||
override = SubagentOverrideConfig()
|
||||
assert override.timeout_seconds is None
|
||||
assert override.max_turns is None
|
||||
assert override.model is None
|
||||
|
||||
def test_explicit_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42, model="gpt-5.4")
|
||||
assert override.timeout_seconds == 300
|
||||
assert override.max_turns == 42
|
||||
assert override.model == "gpt-5.4"
|
||||
|
||||
def test_model_accepts_any_non_empty_string(self):
|
||||
"""Model name is a free-form non-empty string; cross-reference validation
|
||||
against the `models:` section happens at registry lookup time."""
|
||||
override = SubagentOverrideConfig(model="any-arbitrary-model-name")
|
||||
assert override.model == "any-arbitrary-model-name"
|
||||
|
||||
def test_rejects_zero(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(timeout_seconds=0)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=0)
|
||||
|
||||
def test_rejects_negative(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(timeout_seconds=-1)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=-1)
|
||||
|
||||
def test_rejects_empty_model(self):
|
||||
"""Empty-string model would silently bypass the `is not None` check and
|
||||
reach `create_chat_model(name="")` as a runtime error. Reject at load time
|
||||
instead, symmetric with the `ge=1` guard on timeout_seconds / max_turns."""
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(model="")
|
||||
|
||||
def test_minimum_valid_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
|
||||
assert override.timeout_seconds == 1
|
||||
assert override.max_turns == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig – defaults and validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentsAppConfigDefaults:
|
||||
def test_default_timeout(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.timeout_seconds == 900
|
||||
|
||||
def test_default_max_turns_override_is_none(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.max_turns is None
|
||||
|
||||
def test_default_agents_empty(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.agents == {}
|
||||
|
||||
def test_custom_global_runtime_overrides(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
|
||||
assert config.timeout_seconds == 1800
|
||||
assert config.max_turns == 120
|
||||
|
||||
def test_rejects_zero_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(timeout_seconds=0)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(max_turns=0)
|
||||
def test_explicit_values(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=120, max_turns=50)
|
||||
assert override.timeout_seconds == 120
|
||||
assert override.max_turns == 50
|
||||
|
||||
def test_rejects_negative_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(timeout_seconds=-60)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(max_turns=-60)
|
||||
with pytest.raises(Exception):
|
||||
SubagentOverrideConfig(timeout_seconds=-1)
|
||||
|
||||
def test_rejects_zero_timeout(self):
|
||||
with pytest.raises(Exception):
|
||||
SubagentOverrideConfig(timeout_seconds=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig resolution helpers
|
||||
# SubagentsAppConfig model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRuntimeResolution:
|
||||
def test_returns_global_default_when_no_override(self):
|
||||
class TestSubagentsAppConfig:
|
||||
def test_default_timeout_is_900(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.timeout_seconds == 900
|
||||
assert config.max_turns is None
|
||||
assert config.agents == {}
|
||||
|
||||
def test_custom_defaults(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=300, max_turns=50)
|
||||
assert config.timeout_seconds == 300
|
||||
assert config.max_turns == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_timeout_for / get_max_turns_for
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTimeoutResolution:
|
||||
def test_global_timeout_for_unknown_agent(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=600)
|
||||
assert config.get_timeout_for("unknown") == 600
|
||||
|
||||
def test_per_agent_timeout_overrides_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=600,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=120)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 120
|
||||
assert config.get_timeout_for("general-purpose") == 600
|
||||
|
||||
def test_per_agent_override_none_falls_back_to_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=600,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=None)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 600
|
||||
assert config.get_timeout_for("unknown-agent") == 600
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 100
|
||||
|
||||
|
||||
class TestMaxTurnsResolution:
|
||||
def test_builtin_default_when_no_override(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.get_max_turns_for("bash", 60) == 60
|
||||
|
||||
def test_returns_per_agent_override_when_set(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 300
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
def test_global_max_turns_overrides_builtin(self):
|
||||
config = SubagentsAppConfig(max_turns=100)
|
||||
assert config.get_max_turns_for("bash", 60) == 100
|
||||
|
||||
def test_other_agents_still_use_global_default(self):
|
||||
def test_per_agent_max_turns_overrides_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=140,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
|
||||
max_turns=100,
|
||||
agents={"bash": SubagentOverrideConfig(max_turns=30)},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 900
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 140
|
||||
assert config.get_max_turns_for("bash", 60) == 30
|
||||
assert config.get_max_turns_for("general-purpose", 60) == 100
|
||||
|
||||
def test_agent_with_none_override_falls_back_to_global(self):
|
||||
def test_per_agent_override_none_falls_back(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=150,
|
||||
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
|
||||
max_turns=100,
|
||||
agents={"bash": SubagentOverrideConfig(max_turns=None)},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 900
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 150
|
||||
|
||||
def test_multiple_per_agent_overrides(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={
|
||||
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
|
||||
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
|
||||
},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 1800
|
||||
assert config.get_timeout_for("bash") == 120
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_get_model_for_returns_none_when_no_override(self):
|
||||
"""No per-agent model override -> returns None so callers fall back to builtin/parent."""
|
||||
config = SubagentsAppConfig(timeout_seconds=900)
|
||||
assert config.get_model_for("general-purpose") is None
|
||||
assert config.get_model_for("bash") is None
|
||||
assert config.get_model_for("unknown-agent") is None
|
||||
|
||||
def test_get_model_for_returns_override_when_set(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={
|
||||
"general-purpose": SubagentOverrideConfig(model="qwen3.5-35b-a3b"),
|
||||
"bash": SubagentOverrideConfig(model="gpt-5.4"),
|
||||
},
|
||||
)
|
||||
assert config.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
|
||||
assert config.get_model_for("bash") == "gpt-5.4"
|
||||
|
||||
def test_get_model_for_returns_none_for_omitted_agent(self):
|
||||
"""An agent not listed in overrides returns None even when other agents have model overrides."""
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"bash": SubagentOverrideConfig(model="gpt-5.4")},
|
||||
)
|
||||
assert config.get_model_for("general-purpose") is None
|
||||
|
||||
def test_get_model_for_handles_explicit_none(self):
|
||||
"""Explicit model=None in the override is equivalent to no override."""
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, model=None)},
|
||||
)
|
||||
assert config.get_model_for("bash") is None
|
||||
# Timeout override is still applied even when model is None.
|
||||
assert config.get_timeout_for("bash") == 300
|
||||
assert config.get_max_turns_for("bash", 60) == 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_subagents_config_from_dict / get_subagents_app_config singleton
|
||||
# AppConfig.subagents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadSubagentsConfig:
|
||||
def teardown_method(self):
|
||||
"""Restore defaults after each test."""
|
||||
_reset_subagents_config()
|
||||
|
||||
class TestAppConfigSubagents:
|
||||
def test_load_global_timeout(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
|
||||
assert get_subagents_app_config().timeout_seconds == 300
|
||||
assert get_subagents_app_config().max_turns == 120
|
||||
cfg = _make_config(timeout_seconds=300, max_turns=120)
|
||||
sub = cfg.subagents
|
||||
assert sub.timeout_seconds == 300
|
||||
assert sub.max_turns == 120
|
||||
|
||||
def test_load_with_per_agent_overrides(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
}
|
||||
cfg = _make_config(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_timeout_for("general-purpose") == 1800
|
||||
assert cfg.get_timeout_for("bash") == 60
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert cfg.get_max_turns_for("bash", 60) == 80
|
||||
sub = cfg.subagents
|
||||
assert sub.get_timeout_for("general-purpose") == 1800
|
||||
assert sub.get_timeout_for("bash") == 60
|
||||
assert sub.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert sub.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_load_partial_override(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 600,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
|
||||
}
|
||||
cfg = _make_config(
|
||||
timeout_seconds=600,
|
||||
agents={"bash": {"timeout_seconds": 120, "max_turns": 70}},
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_timeout_for("general-purpose") == 600
|
||||
assert cfg.get_timeout_for("bash") == 120
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert cfg.get_max_turns_for("bash", 60) == 70
|
||||
sub = cfg.subagents
|
||||
assert sub.get_timeout_for("general-purpose") == 600
|
||||
assert sub.get_timeout_for("bash") == 120
|
||||
assert sub.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert sub.get_max_turns_for("bash", 60) == 70
|
||||
|
||||
def test_load_with_model_overrides(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {
|
||||
"general-purpose": {"model": "qwen3.5-35b-a3b"},
|
||||
"bash": {"model": "gpt-5.4", "timeout_seconds": 300},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_model_for("general-purpose") == "qwen3.5-35b-a3b"
|
||||
assert cfg.get_model_for("bash") == "gpt-5.4"
|
||||
# Other override fields on the same agent must still load correctly.
|
||||
assert cfg.get_timeout_for("bash") == 300
|
||||
|
||||
def test_load_empty_dict_uses_defaults(self):
|
||||
load_subagents_config_from_dict({})
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.timeout_seconds == 900
|
||||
assert cfg.max_turns is None
|
||||
assert cfg.agents == {}
|
||||
|
||||
def test_load_replaces_previous_config(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
|
||||
assert get_subagents_app_config().timeout_seconds == 100
|
||||
assert get_subagents_app_config().max_turns == 90
|
||||
|
||||
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
|
||||
assert get_subagents_app_config().timeout_seconds == 200
|
||||
assert get_subagents_app_config().max_turns == 110
|
||||
|
||||
def test_singleton_returns_same_instance_between_calls(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
|
||||
assert get_subagents_app_config() is get_subagents_app_config()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# registry.get_subagent_config – runtime overrides applied
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryGetSubagentConfig:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_returns_none_for_unknown_agent(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
assert get_subagent_config("nonexistent") is None
|
||||
|
||||
def test_returns_config_for_builtin_agents(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
assert get_subagent_config("general-purpose") is not None
|
||||
assert get_subagent_config("bash") is not None
|
||||
|
||||
def test_default_timeout_preserved_when_no_config(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=900)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.timeout_seconds == 900
|
||||
assert config.max_turns == 100
|
||||
|
||||
def test_global_timeout_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.timeout_seconds == 1800
|
||||
assert config.max_turns == 140
|
||||
|
||||
def test_per_agent_runtime_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.timeout_seconds == 120
|
||||
assert bash_config.max_turns == 80
|
||||
|
||||
def test_per_agent_override_does_not_affect_other_agents(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
|
||||
}
|
||||
)
|
||||
gp_config = get_subagent_config("general-purpose")
|
||||
assert gp_config.timeout_seconds == 900
|
||||
assert gp_config.max_turns == 120
|
||||
|
||||
def test_per_agent_model_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4-mini"}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.model == "gpt-5.4-mini"
|
||||
|
||||
def test_omitted_model_keeps_builtin_value(self):
|
||||
"""When config.yaml has no `model` field for an agent, the builtin default must be preserved."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"timeout_seconds": 300}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.model == builtin_bash_model
|
||||
|
||||
def test_explicit_null_model_keeps_builtin_value(self):
|
||||
"""An explicit `model: null` in config.yaml is equivalent to omission — builtin wins."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
builtin_bash_model = BUILTIN_SUBAGENTS["bash"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": None}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.model == builtin_bash_model
|
||||
|
||||
def test_model_override_does_not_affect_other_agents(self):
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
builtin_gp_model = BUILTIN_SUBAGENTS["general-purpose"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4"}},
|
||||
}
|
||||
)
|
||||
gp_config = get_subagent_config("general-purpose")
|
||||
assert gp_config.model == builtin_gp_model
|
||||
|
||||
def test_model_override_preserves_other_fields(self):
|
||||
"""Applying a model override must leave timeout_seconds / max_turns / name intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original = BUILTIN_SUBAGENTS["bash"]
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4-mini"}},
|
||||
}
|
||||
)
|
||||
overridden = get_subagent_config("bash")
|
||||
assert overridden.model == "gpt-5.4-mini"
|
||||
assert overridden.name == original.name
|
||||
assert overridden.description == original.description
|
||||
# No timeout / max_turns override was set, so they use global default / builtin.
|
||||
assert overridden.timeout_seconds == 900
|
||||
assert overridden.max_turns == original.max_turns
|
||||
|
||||
def test_model_override_does_not_mutate_builtin(self):
|
||||
"""Registry must return a new object, leaving the builtin default intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original_bash_model = BUILTIN_SUBAGENTS["bash"].model
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"model": "gpt-5.4-mini"}},
|
||||
}
|
||||
)
|
||||
_ = get_subagent_config("bash")
|
||||
assert BUILTIN_SUBAGENTS["bash"].model == original_bash_model
|
||||
|
||||
def test_builtin_config_object_is_not_mutated(self):
|
||||
"""Registry must return a new object, leaving the builtin default intact."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
|
||||
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
|
||||
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
|
||||
|
||||
returned = get_subagent_config("bash")
|
||||
assert returned.timeout_seconds == 42
|
||||
assert returned.max_turns == 88
|
||||
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
|
||||
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
|
||||
|
||||
def test_config_preserves_other_fields(self):
|
||||
"""Applying runtime overrides must not change other SubagentConfig fields."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=300, max_turns=140)
|
||||
original = BUILTIN_SUBAGENTS["general-purpose"]
|
||||
overridden = get_subagent_config("general-purpose")
|
||||
|
||||
assert overridden.name == original.name
|
||||
assert overridden.description == original.description
|
||||
assert overridden.max_turns == 140
|
||||
assert overridden.model == original.model
|
||||
assert overridden.tools == original.tools
|
||||
assert overridden.disallowed_tools == original.disallowed_tools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# registry.list_subagents – all agents get overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryListSubagents:
|
||||
def teardown_method(self):
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_lists_both_builtin_agents(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
names = {cfg.name for cfg in list_subagents()}
|
||||
assert "general-purpose" in names
|
||||
assert "bash" in names
|
||||
|
||||
def test_all_returned_configs_get_global_override(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
_reset_subagents_config(timeout_seconds=123, max_turns=77)
|
||||
for cfg in list_subagents():
|
||||
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
|
||||
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
|
||||
|
||||
def test_per_agent_overrides_reflected_in_list(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
}
|
||||
)
|
||||
by_name = {cfg.name: cfg for cfg in list_subagents()}
|
||||
assert by_name["general-purpose"].timeout_seconds == 1800
|
||||
assert by_name["bash"].timeout_seconds == 60
|
||||
assert by_name["general-purpose"].max_turns == 200
|
||||
assert by_name["bash"].max_turns == 80
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Polling timeout calculation (logic extracted from task_tool)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPollingTimeoutCalculation:
|
||||
"""Verify the formula (timeout_seconds + 60) // 5 is correct for various inputs."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"timeout_seconds, expected_max_polls",
|
||||
[
|
||||
(900, 192), # default 15 min → (900+60)//5 = 192
|
||||
(300, 72), # 5 min → (300+60)//5 = 72
|
||||
(1800, 372), # 30 min → (1800+60)//5 = 372
|
||||
(60, 24), # 1 min → (60+60)//5 = 24
|
||||
(1, 12), # minimum → (1+60)//5 = 12
|
||||
],
|
||||
)
|
||||
def test_polling_timeout_formula(self, timeout_seconds: int, expected_max_polls: int):
|
||||
dummy_config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
|
||||
assert max_poll_count == expected_max_polls
|
||||
|
||||
def test_polling_timeout_exceeds_execution_timeout(self):
|
||||
"""Safety-net polling window must always be longer than the execution timeout."""
|
||||
for timeout_seconds in [60, 300, 900, 1800]:
|
||||
dummy_config = SubagentConfig(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
max_poll_count = (dummy_config.timeout_seconds + 60) // 5
|
||||
polling_window_seconds = max_poll_count * 5
|
||||
assert polling_window_seconds > timeout_seconds
|
||||
def test_load_empty_uses_defaults(self):
|
||||
cfg = _make_config()
|
||||
sub = cfg.subagents
|
||||
assert sub.timeout_seconds == 900
|
||||
assert sub.max_turns is None
|
||||
assert sub.agents == {}
|
||||
|
||||
@@ -46,7 +46,9 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2", "Q3"]
|
||||
fake_model.ainvoke.assert_awaited_once()
|
||||
@@ -66,7 +68,9 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
fake_model.ainvoke.assert_awaited_once()
|
||||
@@ -86,7 +90,9 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
fake_model.ainvoke.assert_awaited_once()
|
||||
@@ -103,6 +109,8 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
|
||||
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == []
|
||||
|
||||
@@ -8,6 +8,9 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.subagents.config import SubagentConfig
|
||||
|
||||
# Use module import so tests can patch the exact symbols referenced inside task_tool().
|
||||
@@ -24,6 +27,13 @@ class FakeSubagentStatus(Enum):
|
||||
TIMED_OUT = "timed_out"
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime() -> SimpleNamespace:
|
||||
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
|
||||
return SimpleNamespace(
|
||||
@@ -35,7 +45,7 @@ def _make_runtime() -> SimpleNamespace:
|
||||
"outputs_path": "/tmp/outputs",
|
||||
},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
context=_make_context("thread-1"),
|
||||
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
|
||||
)
|
||||
|
||||
@@ -83,11 +93,11 @@ class _DummyScheduledTask:
|
||||
|
||||
|
||||
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
|
||||
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda: ["general-purpose"])
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: None)
|
||||
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
|
||||
|
||||
result = _run_task_tool(
|
||||
runtime=None,
|
||||
runtime=_make_runtime(),
|
||||
description="执行任务",
|
||||
prompt="do work",
|
||||
subagent_type="general-purpose",
|
||||
@@ -98,8 +108,8 @@ def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||
|
||||
|
||||
def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: _make_subagent_config())
|
||||
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda: False)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: _make_subagent_config())
|
||||
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
result = _run_task_tool(
|
||||
runtime=_make_runtime(),
|
||||
@@ -142,9 +152,9 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "Skills Appendix")
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
|
||||
@@ -165,225 +175,18 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
|
||||
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
|
||||
assert captured["executor_kwargs"]["config"].max_turns == 7
|
||||
# Skills are no longer appended to system_prompt; they are loaded per-session
|
||||
# by SubagentExecutor and injected as conversation items (Codex pattern).
|
||||
assert captured["executor_kwargs"]["config"].system_prompt == "Base system prompt"
|
||||
assert "Skills Appendix" in captured["executor_kwargs"]["config"].system_prompt
|
||||
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=None, subagent_enabled=False)
|
||||
from unittest.mock import ANY
|
||||
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False, app_config=ANY)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
|
||||
assert events[-1]["result"] == "all done"
|
||||
|
||||
|
||||
def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch):
|
||||
"""Verify tool_groups from parent metadata are passed to get_available_tools(groups=...)."""
|
||||
config = _make_subagent_config()
|
||||
parent_tool_groups = ["file:read", "file:write", "bash"]
|
||||
runtime = SimpleNamespace(
|
||||
state={
|
||||
"sandbox": {"sandbox_id": "local"},
|
||||
"thread_data": {"workspace_path": "/tmp/workspace"},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1", "tool_groups": parent_tool_groups}},
|
||||
)
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=["tool-a"])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="file work only",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-groups",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
# The key assertion: groups should be propagated from parent metadata
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)
|
||||
|
||||
|
||||
def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
runtime = _make_runtime()
|
||||
runtime.config["metadata"]["available_skills"] = ["safe-skill"]
|
||||
events = []
|
||||
captured = {}
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
captured["config"] = kwargs["config"]
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="use skills",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-skills",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
assert captured["config"].skills == ["safe-skill"]
|
||||
|
||||
|
||||
def test_task_tool_intersects_parent_and_subagent_skill_allowlists(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
config = SubagentConfig(
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
system_prompt=config.system_prompt,
|
||||
max_turns=config.max_turns,
|
||||
timeout_seconds=config.timeout_seconds,
|
||||
skills=["safe-skill", "other-skill"],
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
runtime.config["metadata"]["available_skills"] = ["safe-skill"]
|
||||
events = []
|
||||
captured = {}
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
captured["config"] = kwargs["config"]
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="use skills",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-skills-intersection",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: done"
|
||||
assert captured["config"].skills == ["safe-skill"]
|
||||
|
||||
|
||||
def test_task_tool_no_tool_groups_passes_none(monkeypatch):
|
||||
"""Verify that when metadata has no tool_groups, groups=None is passed (backward compat)."""
|
||||
config = _make_subagent_config()
|
||||
# Default _make_runtime() has no tool_groups in metadata
|
||||
runtime = _make_runtime()
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="ok"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=runtime,
|
||||
description="执行任务",
|
||||
prompt="normal work",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-no-groups",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: ok"
|
||||
# No tool_groups in metadata → groups=None (default behavior preserved)
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", groups=None, subagent_enabled=False)
|
||||
|
||||
|
||||
def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
"""Verify that when runtime is None, groups=None is passed (e.g., unknown subagent path exits early, but tools still load correctly)."""
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
get_available_tools = MagicMock(return_value=[])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="ok"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=None,
|
||||
description="执行任务",
|
||||
prompt="no runtime",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-no-runtime",
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: ok"
|
||||
# runtime is None → metadata is empty dict → groups=None
|
||||
get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False)
|
||||
|
||||
def test_task_tool_returns_failed_message(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
|
||||
@@ -393,12 +196,12 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -427,12 +230,12 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -463,12 +266,12 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -499,12 +302,12 @@ def test_cleanup_called_on_completed(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -539,12 +342,12 @@ def test_cleanup_called_on_failed(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="error"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -579,12 +382,12 @@ def test_cleanup_called_on_timed_out(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -626,12 +429,12 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
@@ -679,8 +482,8 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||
@@ -730,12 +533,12 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||
@@ -785,12 +588,12 @@ def test_cancellation_calls_request_cancel(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||
@@ -843,9 +646,9 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
|
||||
@@ -1,58 +1,55 @@
|
||||
import pytest
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def _as_posix(path: str) -> str:
|
||||
return path.replace("\\", "/")
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
|
||||
return DeerFlowContext(
|
||||
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadDataMiddleware:
|
||||
def test_before_agent_returns_paths_when_thread_id_present_in_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context={"thread_id": "thread-123"}))
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-123")))
|
||||
|
||||
assert result is not None
|
||||
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-123/user-data/workspace")
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-123/user-data/uploads")
|
||||
assert _as_posix(result["thread_data"]["outputs_path"]).endswith("threads/thread-123/user-data/outputs")
|
||||
|
||||
def test_before_agent_uses_thread_id_from_configurable_when_context_is_none(self, tmp_path, monkeypatch):
|
||||
def test_before_agent_uses_thread_id_from_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
runtime = Runtime(context=None)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
||||
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
result = middleware.before_agent(state={}, runtime=runtime)
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-config")))
|
||||
|
||||
assert result is not None
|
||||
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-from-config/user-data/workspace")
|
||||
assert runtime.context is None
|
||||
|
||||
def test_before_agent_uses_thread_id_from_configurable_when_context_missing_thread_id(self, tmp_path, monkeypatch):
|
||||
def test_before_agent_uses_thread_id_from_typed_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
runtime = Runtime(context={})
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
||||
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
result = middleware.before_agent(state={}, runtime=runtime)
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context=_make_context("thread-from-dict")))
|
||||
|
||||
assert result is not None
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-config/user-data/uploads")
|
||||
assert runtime.context == {}
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-dict/user-data/uploads")
|
||||
|
||||
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
|
||||
def test_before_agent_raises_clear_error_when_thread_id_missing(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
||||
lambda: {"configurable": {}},
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
with pytest.raises(ValueError, match="Thread ID is required in runtime context or config.configurable"):
|
||||
middleware.before_agent(state={}, runtime=Runtime(context=None))
|
||||
with pytest.raises(ValueError, match="Thread ID is required"):
|
||||
middleware.before_agent(state={}, runtime=Runtime(context=_make_context("")))
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Tests for ThreadMetaRepository (SQLAlchemy-backed)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
return ThreadMetaRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _cleanup():
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_and_get(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1")
|
||||
assert record["thread_id"] == "t1"
|
||||
assert record["status"] == "idle"
|
||||
assert "created_at" in record
|
||||
|
||||
fetched = await repo.get("t1")
|
||||
assert fetched is not None
|
||||
assert fetched["thread_id"] == "t1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_assistant_id(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", assistant_id="agent1")
|
||||
assert record["assistant_id"] == "agent1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner_and_display_name(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", user_id="user1", display_name="My Thread")
|
||||
assert record["user_id"] == "user1"
|
||||
assert record["display_name"] == "My Thread"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_metadata(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", metadata={"key": "value"})
|
||||
assert record["metadata"] == {"key": "value"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_nonexistent(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.get("nonexistent") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_record_allows(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.check_access("unknown", "user1") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_matches(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_mismatch(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2") is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
# Explicit user_id=None to bypass the new AUTO default that
|
||||
# would otherwise pick up the test user from the autouse fixture.
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_missing_row_denied(self, tmp_path):
|
||||
"""require_existing=True flips the missing-row case to *denied*.
|
||||
|
||||
Closes the delete-idempotence cross-user gap: after a thread is
|
||||
deleted, the row is gone, and the permissive default would let any
|
||||
caller "claim" it as untracked. The strict mode demands a row.
|
||||
"""
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.check_access("never-existed", "user1", require_existing=True) is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2", require_existing=True) is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
||||
"""Even in strict mode, a row with NULL user_id stays shared.
|
||||
|
||||
The strict flag tightens the *missing row* case, not the *shared
|
||||
row* case — legacy pre-auth rows that survived a clean migration
|
||||
without an owner are still everyone's.
|
||||
"""
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1")
|
||||
await repo.update_status("t1", "busy")
|
||||
record = await repo.get("t1")
|
||||
assert record["status"] == "busy"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1")
|
||||
await repo.delete("t1")
|
||||
assert await repo.get("t1") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.delete("nonexistent") # should not raise
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_merges(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", metadata={"a": 1, "b": 2})
|
||||
await repo.update_metadata("t1", {"b": 99, "c": 3})
|
||||
record = await repo.get("t1")
|
||||
# Existing key preserved, overlapping key overwritten, new key added
|
||||
assert record["metadata"] == {"a": 1, "b": 99, "c": 3}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_on_empty(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1")
|
||||
await repo.update_metadata("t1", {"k": "v"})
|
||||
record = await repo.get("t1")
|
||||
assert record["metadata"] == {"k": "v"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
||||
await _cleanup()
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests for paginated GET /api/threads/{thread_id}/runs/{run_id}/messages endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import thread_runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(event_store=None):
|
||||
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||
app = make_authed_test_app()
|
||||
app.include_router(thread_runs.router)
|
||||
|
||||
if event_store is not None:
|
||||
app.state.run_event_store = event_store
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _make_event_store(rows: list[dict]):
|
||||
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
|
||||
store = MagicMock()
|
||||
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||
return store
|
||||
|
||||
|
||||
def _make_message(seq: int) -> dict:
|
||||
return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_returns_paginated_envelope():
|
||||
"""GET /api/threads/{tid}/runs/{rid}/messages returns {data: [...], has_more: bool}."""
|
||||
rows = [_make_message(i) for i in range(1, 4)]
|
||||
app = _make_app(event_store=_make_event_store(rows))
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-1/runs/run-1/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "data" in body
|
||||
assert "has_more" in body
|
||||
assert body["has_more"] is False
|
||||
assert len(body["data"]) == 3
|
||||
|
||||
|
||||
def test_has_more_true_when_extra_row_returned():
|
||||
"""has_more=True when event store returns limit+1 rows."""
|
||||
# Default limit is 50; provide 51 rows
|
||||
rows = [_make_message(i) for i in range(1, 52)] # 51 rows
|
||||
app = _make_app(event_store=_make_event_store(rows))
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-2/runs/run-2/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["has_more"] is True
|
||||
assert len(body["data"]) == 50 # trimmed to limit
|
||||
|
||||
|
||||
def test_after_seq_forwarded_to_event_store():
|
||||
"""after_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(10)]
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(event_store=event_store)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-3/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
)
|
||||
|
||||
|
||||
def test_before_seq_forwarded_to_event_store():
|
||||
"""before_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(3)]
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(event_store=event_store)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-4/runs/run-4/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4", "run-4",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_custom_limit_forwarded_to_event_store():
|
||||
"""Custom limit is forwarded as limit+1 to the event store."""
|
||||
rows = [_make_message(i) for i in range(1, 6)]
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(event_store=event_store)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-5/runs/run-5/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5", "run-5",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_data_when_no_messages():
|
||||
"""Returns empty data list with has_more=False when no messages exist."""
|
||||
app = _make_app(event_store=_make_event_store([]))
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-6/runs/run-6/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["data"] == []
|
||||
assert body["has_more"] is False
|
||||
@@ -1,7 +1,8 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import threads
|
||||
@@ -49,12 +50,15 @@ def test_delete_thread_data_rejects_invalid_thread_id(tmp_path):
|
||||
|
||||
|
||||
def test_delete_thread_route_cleans_thread_directory(tmp_path):
|
||||
paths = Paths(tmp_path)
|
||||
thread_dir = paths.thread_dir("thread-route")
|
||||
paths.sandbox_work_dir("thread-route").mkdir(parents=True, exist_ok=True)
|
||||
(paths.sandbox_work_dir("thread-route") / "notes.txt").write_text("hello", encoding="utf-8")
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
app = FastAPI()
|
||||
paths = Paths(tmp_path)
|
||||
user_id = get_effective_user_id()
|
||||
thread_dir = paths.thread_dir("thread-route", user_id=user_id)
|
||||
paths.sandbox_work_dir("thread-route", user_id=user_id).mkdir(parents=True, exist_ok=True)
|
||||
(paths.sandbox_work_dir("thread-route", user_id=user_id) / "notes.txt").write_text("hello", encoding="utf-8")
|
||||
|
||||
app = make_authed_test_app()
|
||||
app.include_router(threads.router)
|
||||
|
||||
with patch("app.gateway.routers.threads.get_paths", return_value=paths):
|
||||
@@ -69,7 +73,7 @@ def test_delete_thread_route_cleans_thread_directory(tmp_path):
|
||||
def test_delete_thread_route_rejects_invalid_thread_id(tmp_path):
|
||||
paths = Paths(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app = make_authed_test_app()
|
||||
app.include_router(threads.router)
|
||||
|
||||
with patch("app.gateway.routers.threads.get_paths", return_value=paths):
|
||||
@@ -82,7 +86,7 @@ def test_delete_thread_route_rejects_invalid_thread_id(tmp_path):
|
||||
def test_delete_thread_route_returns_422_for_route_safe_invalid_id(tmp_path):
|
||||
paths = Paths(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app = make_authed_test_app()
|
||||
app.include_router(threads.router)
|
||||
|
||||
with patch("app.gateway.routers.threads.get_paths", return_value=paths):
|
||||
@@ -107,3 +111,28 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path):
|
||||
assert exc_info.value.detail == "Failed to delete local thread data."
|
||||
assert "/secret/path" not in exc_info.value.detail
|
||||
log_exception.assert_called_once_with("Failed to delete thread data for %s", "thread-cleanup")
|
||||
|
||||
|
||||
# ── Server-reserved metadata key stripping ──────────────────────────────────
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_removes_user_id():
|
||||
"""Client-supplied user_id is dropped to prevent reflection attacks."""
|
||||
out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"})
|
||||
assert out == {"title": "ok"}
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_passes_through_safe_keys():
|
||||
"""Non-reserved keys are preserved verbatim."""
|
||||
md = {"title": "ok", "tags": ["a", "b"], "custom": {"x": 1}}
|
||||
assert threads._strip_reserved_metadata(md) == md
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_empty_input():
|
||||
"""Empty / None metadata returns same object — no crash."""
|
||||
assert threads._strip_reserved_metadata({}) == {}
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_strips_all_reserved_keys():
|
||||
out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
|
||||
assert out == {"keep": "me"}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
|
||||
class TestTitleConfig:
|
||||
@@ -44,21 +44,6 @@ class TestTitleConfig:
|
||||
with pytest.raises(ValueError):
|
||||
TitleConfig(max_chars=201)
|
||||
|
||||
def test_get_set_config(self):
|
||||
"""Test global config getter and setter."""
|
||||
original_config = get_title_config()
|
||||
|
||||
# Set new config
|
||||
new_config = TitleConfig(enabled=False, max_words=10)
|
||||
set_title_config(new_config)
|
||||
|
||||
# Verify it was set
|
||||
assert get_title_config().enabled is False
|
||||
assert get_title_config().max_words == 10
|
||||
|
||||
# Restore original config
|
||||
set_title_config(original_config)
|
||||
|
||||
|
||||
class TestTitleMiddleware:
|
||||
"""Tests for TitleMiddleware."""
|
||||
@@ -68,23 +53,3 @@ class TestTitleMiddleware:
|
||||
middleware = TitleMiddleware()
|
||||
assert middleware is not None
|
||||
assert middleware.state_schema is not None
|
||||
|
||||
# TODO: Add integration tests with mock Runtime
|
||||
# def test_should_generate_title(self):
|
||||
# """Test title generation trigger logic."""
|
||||
# pass
|
||||
|
||||
# def test_generate_title(self):
|
||||
# """Test title generation."""
|
||||
# pass
|
||||
|
||||
# def test_after_agent_hook(self):
|
||||
# """Test after_agent hook."""
|
||||
# pass
|
||||
|
||||
|
||||
# TODO: Add integration tests
|
||||
# - Test with real LangGraph runtime
|
||||
# - Test title persistence with checkpointer
|
||||
# - Test fallback behavior when LLM fails
|
||||
# - Test concurrent title generation
|
||||
|
||||
@@ -1,38 +1,32 @@
|
||||
"""Core behavior tests for TitleMiddleware."""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
|
||||
def _clone_title_config(config: TitleConfig) -> TitleConfig:
|
||||
# Avoid mutating shared global config objects across tests.
|
||||
return TitleConfig(**config.model_dump())
|
||||
def _make_title_config(**overrides) -> TitleConfig:
|
||||
return TitleConfig(**overrides)
|
||||
|
||||
|
||||
def _set_test_title_config(**overrides) -> TitleConfig:
|
||||
config = _clone_title_config(get_title_config())
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
set_title_config(config)
|
||||
return config
|
||||
def _make_runtime(**title_overrides) -> SimpleNamespace:
|
||||
"""Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
|
||||
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
|
||||
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
|
||||
return SimpleNamespace(context=ctx)
|
||||
|
||||
|
||||
class TestTitleMiddlewareCoreLogic:
|
||||
def setup_method(self):
|
||||
# Title config is a global singleton; snapshot and restore for test isolation.
|
||||
self._original = _clone_title_config(get_title_config())
|
||||
|
||||
def teardown_method(self):
|
||||
set_title_config(self._original)
|
||||
|
||||
def test_should_generate_title_for_first_complete_exchange(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -41,27 +35,24 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is True
|
||||
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
|
||||
|
||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
_set_test_title_config(enabled=False)
|
||||
disabled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": None,
|
||||
}
|
||||
assert middleware._should_generate_title(disabled_state) is False
|
||||
assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
|
||||
|
||||
_set_test_title_config(enabled=True)
|
||||
titled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": "Existing Title",
|
||||
}
|
||||
assert middleware._should_generate_title(titled_state) is False
|
||||
assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
|
||||
|
||||
def test_should_not_generate_title_after_second_user_turn(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -72,10 +63,9 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is False
|
||||
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
|
||||
|
||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=12)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
@@ -87,16 +77,17 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="好的,先确认需求"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=12))))
|
||||
title = result["title"]
|
||||
|
||||
assert title == "短标题"
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||
from unittest.mock import ANY
|
||||
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, app_config=ANY)
|
||||
model.ainvoke.assert_awaited_once()
|
||||
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "title_agent"}
|
||||
|
||||
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
||||
@@ -109,13 +100,12 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
|
||||
title = result["title"]
|
||||
|
||||
assert title == "请帮我总结这段代码"
|
||||
|
||||
def test_generate_title_fallback_for_long_message(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
||||
@@ -127,7 +117,7 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="收到"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
|
||||
title = result["title"]
|
||||
|
||||
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||
@@ -138,25 +128,24 @@ class TestTitleMiddlewareCoreLogic:
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
|
||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
|
||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
|
||||
assert result == {"title": "异步标题"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
|
||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
|
||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
|
||||
|
||||
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
|
||||
result = middleware.after_model({"messages": []}, runtime=MagicMock())
|
||||
result = middleware.after_model({"messages": []}, runtime=_make_runtime())
|
||||
assert result == {"title": "同步标题"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
|
||||
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
|
||||
assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
|
||||
|
||||
def test_sync_generate_title_uses_fallback_without_model(self):
|
||||
"""Sync path avoids LLM calls and derives a local fallback title."""
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
state = {
|
||||
@@ -165,12 +154,11 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="好的"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
|
||||
assert result == {"title": "请帮我写测试"}
|
||||
|
||||
def test_sync_generate_title_respects_fallback_truncation(self):
|
||||
"""Sync fallback path still respects max_chars truncation rules."""
|
||||
_set_test_title_config(max_chars=50)
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
state = {
|
||||
@@ -179,7 +167,7 @@ class TestTitleMiddlewareCoreLogic:
|
||||
AIMessage(content="回复"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
|
||||
assert result["title"].endswith("...")
|
||||
assert result["title"].startswith("这是一个非常长的问题描述")
|
||||
|
||||
|
||||
@@ -154,8 +154,7 @@ class TestStreamUsageIntegration:
|
||||
"""Test that stream() emits usage_metadata in messages-tuple and end events."""
|
||||
|
||||
def _make_client(self):
|
||||
with patch("deerflow.client.get_app_config", return_value=_mock_app_config()):
|
||||
return DeerFlowClient()
|
||||
return DeerFlowClient()
|
||||
|
||||
def test_stream_emits_usage_in_messages_tuple(self):
|
||||
"""messages-tuple AI event should include usage_metadata when present."""
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool as langchain_tool
|
||||
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig
|
||||
from deerflow.tools.builtins.tool_search import (
|
||||
DeferredToolRegistry,
|
||||
get_deferred_registry,
|
||||
@@ -64,12 +64,12 @@ class TestToolSearchConfig:
|
||||
config = ToolSearchConfig(enabled=True)
|
||||
assert config.enabled is True
|
||||
|
||||
def test_load_from_dict(self):
|
||||
config = load_tool_search_config_from_dict({"enabled": True})
|
||||
def test_validate_from_dict(self):
|
||||
config = ToolSearchConfig.model_validate({"enabled": True})
|
||||
assert config.enabled is True
|
||||
|
||||
def test_load_from_empty_dict(self):
|
||||
config = load_tool_search_config_from_dict({})
|
||||
def test_validate_from_empty_dict(self):
|
||||
config = ToolSearchConfig.model_validate({})
|
||||
assert config.enabled is False
|
||||
|
||||
|
||||
@@ -266,48 +266,42 @@ class TestToolSearchTool:
|
||||
|
||||
|
||||
class TestDeferredToolsPromptSection:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_app_config(self, monkeypatch):
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Provide a minimal AppConfig mock so tests don't need config.yaml."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.tool_search = ToolSearchConfig() # disabled by default
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: mock_config)
|
||||
config = MagicMock()
|
||||
config.tool_search = ToolSearchConfig() # disabled by default
|
||||
return config
|
||||
|
||||
def test_empty_when_disabled(self):
|
||||
def test_empty_when_disabled(self, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
|
||||
# tool_search.enabled defaults to False
|
||||
section = get_deferred_tools_prompt_section()
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert section == ""
|
||||
|
||||
def test_empty_when_enabled_but_no_registry(self, monkeypatch):
|
||||
def test_empty_when_enabled_but_no_registry(self, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
section = get_deferred_tools_prompt_section()
|
||||
mock_config.tool_search = ToolSearchConfig(enabled=True)
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert section == ""
|
||||
|
||||
def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
|
||||
def test_empty_when_enabled_but_empty_registry(self, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
mock_config.tool_search = ToolSearchConfig(enabled=True)
|
||||
set_deferred_registry(DeferredToolRegistry())
|
||||
section = get_deferred_tools_prompt_section()
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert section == ""
|
||||
|
||||
def test_lists_tool_names(self, registry, monkeypatch):
|
||||
def test_lists_tool_names(self, registry, mock_config):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
mock_config.tool_search = ToolSearchConfig(enabled=True)
|
||||
set_deferred_registry(registry)
|
||||
section = get_deferred_tools_prompt_section()
|
||||
section = get_deferred_tools_prompt_section(mock_config)
|
||||
assert "<available-deferred-tools>" in section
|
||||
assert "</available-deferred-tools>" in section
|
||||
assert "github_create_issue" in section
|
||||
|
||||
@@ -13,7 +13,10 @@ from unittest.mock import MagicMock
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
THREAD_ID = "thread-abc123"
|
||||
|
||||
@@ -23,18 +26,27 @@ THREAD_ID = "thread-abc123"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_context(thread_id: str) -> DeerFlowContext:
    """Build a DeerFlowContext for ``thread_id`` around a ``"test"`` sandbox AppConfig."""
    sandbox_config = SandboxConfig(use="test")
    config = AppConfig(sandbox=sandbox_config)
    return DeerFlowContext(app_config=config, thread_id=thread_id)
|
||||
|
||||
|
||||
def _middleware(tmp_path: Path) -> UploadsMiddleware:
    """Create an UploadsMiddleware whose ``base_dir`` is the given path, as a string."""
    base_dir = str(tmp_path)
    return UploadsMiddleware(base_dir=base_dir)
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock:
|
||||
rt = MagicMock()
|
||||
rt.context = {"thread_id": thread_id}
|
||||
rt.context = _make_context(thread_id or "")
|
||||
return rt
|
||||
|
||||
|
||||
def _uploads_dir(tmp_path: Path, thread_id: str = THREAD_ID) -> Path:
|
||||
d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id)
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from _router_auth_helpers import call_unwrapped
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.gateway.routers import uploads
|
||||
@@ -25,7 +26,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.files) == 1
|
||||
@@ -107,7 +108,7 @@ def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
|
||||
patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)),
|
||||
):
|
||||
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
|
||||
result = asyncio.run(uploads.upload_files("thread-aio", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.files) == 1
|
||||
@@ -146,7 +147,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
|
||||
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
|
||||
):
|
||||
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
|
||||
result = asyncio.run(uploads.upload_files("thread-aio", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
make_writable.assert_any_call(thread_uploads_dir / "report.pdf")
|
||||
@@ -170,7 +171,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
|
||||
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
|
||||
):
|
||||
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
make_writable.assert_not_called()
|
||||
@@ -221,13 +222,13 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
|
||||
# These filenames must be rejected outright
|
||||
for bad_name in ["..", "."]:
|
||||
file = UploadFile(filename=bad_name, file=BytesIO(b"data"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
|
||||
assert result.success is True
|
||||
assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}"
|
||||
|
||||
# Path-traversal prefixes are stripped to the basename and accepted safely
|
||||
file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
|
||||
assert result.success is True
|
||||
assert len(result.files) == 1
|
||||
assert result.files[0]["filename"] == "passwd"
|
||||
@@ -243,7 +244,7 @@ def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
|
||||
(thread_uploads_dir / "report.md").write_text("converted", encoding="utf-8")
|
||||
|
||||
with patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir):
|
||||
result = asyncio.run(uploads.delete_uploaded_file("thread-aio", "report.pdf"))
|
||||
result = asyncio.run(call_unwrapped(uploads.delete_uploaded_file, "thread-aio", "report.pdf", request=MagicMock()))
|
||||
|
||||
assert result == {"success": True, "message": "Deleted report.pdf"}
|
||||
assert not (thread_uploads_dir / "report.pdf").exists()
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Tests for runtime.user_context — contextvar three-state semantics.
|
||||
|
||||
These tests opt out of the autouse contextvar fixture (added in
|
||||
commit 6) because they explicitly test the cases where the contextvar
|
||||
is set or unset.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
CurrentUser,
|
||||
DEFAULT_USER_ID,
|
||||
get_current_user,
|
||||
get_effective_user_id,
|
||||
require_current_user,
|
||||
reset_current_user,
|
||||
set_current_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_default_is_none():
    """With no user set, get_current_user() yields None."""
    current = get_current_user()
    assert current is None
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_set_and_reset_roundtrip():
    """Setting a user yields a token; resetting with that token clears the user."""
    fake_user = SimpleNamespace(id="user-1")
    reset_token = set_current_user(fake_user)
    try:
        # While the token is live, the getter returns the very same object.
        observed = get_current_user()
        assert observed is fake_user
    finally:
        # Undo the set even if the assertion above fails.
        reset_current_user(reset_token)
    assert get_current_user() is None
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_require_current_user_raises_when_unset():
    """require_current_user() raises RuntimeError when no user is set."""
    # Precondition: nothing has been set for this test.
    current = get_current_user()
    assert current is None

    expect_failure = pytest.raises(RuntimeError, match="without user context")
    with expect_failure:
        require_current_user()
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_require_current_user_returns_user_when_set():
    """require_current_user() hands back the exact object that was set."""
    fake_user = SimpleNamespace(id="user-2")
    reset_token = set_current_user(fake_user)
    try:
        returned = require_current_user()
        assert returned is fake_user
    finally:
        # Undo the set even if the assertion above fails.
        reset_current_user(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_protocol_accepts_duck_typed():
    """An object exposing ``id`` passes an isinstance() check against CurrentUser."""
    duck = SimpleNamespace(id="user-3")
    matches_protocol = isinstance(duck, CurrentUser)
    assert matches_protocol
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_protocol_rejects_no_id():
    """An object lacking ``id`` fails an isinstance() check against CurrentUser."""
    id_less = SimpleNamespace(email="no-id@example.com")
    matches_protocol = isinstance(id_less, CurrentUser)
    assert not matches_protocol
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_user_id / DEFAULT_USER_ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_user_id_is_default():
    """DEFAULT_USER_ID is the literal string ``"default"``."""
    expected = "default"
    assert DEFAULT_USER_ID == expected
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_effective_user_id_returns_default_when_no_user():
    """With no user set, get_effective_user_id() returns ``"default"``."""
    effective = get_effective_user_id()
    assert effective == "default"
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_effective_user_id_returns_user_id_when_set():
    """With a user set, get_effective_user_id() returns that user's id."""
    fake_user = SimpleNamespace(id="u-abc-123")
    reset_token = set_current_user(fake_user)
    try:
        effective = get_effective_user_id()
        assert effective == "u-abc-123"
    finally:
        # Undo the set even if the assertion above fails.
        reset_current_user(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_effective_user_id_coerces_to_str():
    """When the user's id is a UUID object, the effective id equals its str() form."""
    import uuid

    raw_id = uuid.uuid4()
    fake_user = SimpleNamespace(id=raw_id)
    reset_token = set_current_user(fake_user)
    try:
        effective = get_effective_user_id()
        assert effective == str(raw_id)
    finally:
        # Undo the set even if the assertion above fails.
        reset_current_user(reset_token)
|
||||
Reference in New Issue
Block a user