Merge branch 'main' into fix-2788

This commit is contained in:
Willem Jiang
2026-05-27 08:29:21 +08:00
committed by GitHub
282 changed files with 25568 additions and 2124 deletions
+68
View File
@@ -0,0 +1,68 @@
"""Shared helpers for user-isolation e2e tests on the custom-agent tooling.
Centralises the small fake-LLM shim and a few test-data builders that the
three e2e files in this PR (``test_setup_agent_e2e_user_isolation``,
``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``)
all need. The shim is what lets a real ``langchain.agents.create_agent``
graph run without an API key — every other layer in those tests is real
production code, which is the entire point of the test design.
"""
from __future__ import annotations
from typing import Any
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import Runnable
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent.
``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to
expose the tool schemas to the model; the upstream fake raises
``NotImplementedError`` there. We just return ``self`` because we
drive deterministic tool_call output via ``responses=...``, no schema
handling needed.
"""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
def build_single_tool_call_model(
*,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str = "call_e2e_1",
final_text: str = "done",
) -> FakeToolCallingModel:
"""Build a fake model that emits exactly one tool_call then finishes.
Two-turn behaviour, identical across our e2e tests:
turn 1 → AIMessage with a single tool_call for *tool_name*
turn 2 → AIMessage with *final_text* (terminates the agent loop)
"""
return FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": tool_name,
"args": tool_args,
"id": tool_call_id,
"type": "tool_call",
}
],
),
AIMessage(content=final_text),
]
)
+37
View File
@@ -0,0 +1,37 @@
"""Pytest conftest for the strict Blockbuster runtime gate.
Activates `detect_blocking_io_strict()` around the entire pytest item
protocol (setup + call + teardown) so blocking IO in async fixtures and
lifespan code is also caught, not just blocking IO inside the test body.
Scope: only applies to items whose path is under `backend/tests/blocking_io/`.
Pytest registers conftest hookwrappers globally once the file is loaded,
so an explicit path filter is required to keep the strict gate from
firing on unrelated tests when the full suite is collected.
Opt-out: mark a test with `@pytest.mark.allow_blocking_io` to skip the gate.
"""
from __future__ import annotations
from collections.abc import Generator
from pathlib import Path
import pytest
from support.detectors.blocking_io_runtime import detect_blocking_io_strict
_BLOCKING_IO_TEST_ROOT = Path(__file__).resolve().parent
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_protocol(item: pytest.Item, nextitem: pytest.Item | None) -> Generator[None, None, None]:
if not _is_blocking_io_item(item) or item.get_closest_marker("allow_blocking_io") is not None:
yield
return
with detect_blocking_io_strict():
yield
def _is_blocking_io_item(item: pytest.Item) -> bool:
return Path(item.path).resolve().is_relative_to(_BLOCKING_IO_TEST_ROOT)
@@ -0,0 +1,55 @@
"""Smoke test: the strict Blockbuster gate is wired up and actively catching.
Independent of any specific production code path, asserts that calling a
known blocking IO function directly from an `async def` (without an
`asyncio.to_thread` wrapper) raises `BlockingError`. If this test ever
stops raising, the gate machinery itself is broken — typical causes are
`scanned_modules` misconfiguration, accidental removal of the Blockbuster
dev dependency, or the conftest hookwrapper no longer firing.
This is the meta-test that protects every other test in this directory
from silent regressions (a green gate that no longer catches anything is
worse than no gate at all).
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from blockbuster import BlockingError
from support.detectors.blocking_io_runtime import detect_blocking_io_strict
pytestmark = pytest.mark.asyncio
async def test_gate_catches_unoffloaded_blocking_io_in_deerflow_module(tmp_path: Path) -> None:
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir
db_file = tmp_path / "subdir" / "store.db"
with pytest.raises(BlockingError):
ensure_sqlite_parent_dir(str(db_file))
async def test_gate_restores_blockbuster_patches_after_exceptions() -> None:
original_stat = os.stat
with pytest.raises(RuntimeError, match="boom"):
with detect_blocking_io_strict():
raise RuntimeError("boom")
assert os.stat is original_stat
@pytest.mark.allow_blocking_io
async def test_allow_blocking_io_marker_opts_out_of_gate(tmp_path: Path) -> None:
"""Verify the @pytest.mark.allow_blocking_io opt-out actually disables the gate."""
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir
db_file = tmp_path / "subdir" / "store.db"
ensure_sqlite_parent_dir(str(db_file))
assert db_file.parent.exists()
@@ -0,0 +1,102 @@
"""Regression test: skill loading must remain releasable to a worker thread.
Anchors the production offload from `subagents/executor.py:_load_skills`,
where both `get_or_new_skill_storage` and the sync `storage.load_skills(...)`
method are dispatched via `asyncio.to_thread`. That fix addressed #1917,
where `os.walk` inside `load_skills` blocked the LangGraph async event loop.
This test invokes the production `_load_skills()` call path under the strict
Blockbuster context against a real `LocalSkillStorage` instance pointed at
a tmp directory. If the production `asyncio.to_thread` offload is removed,
Blockbuster raises `BlockingError` and this test fails.
"""
from __future__ import annotations
import importlib
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
pytestmark = pytest.mark.asyncio
_MISSING = object()
_EXECUTOR_IMPORT_MOCKS = (
"deerflow.agents",
"deerflow.agents.thread_state",
"deerflow.models",
)
def _seed_skill(skills_root: Path) -> None:
skill = skills_root / "public" / "demo"
skill.mkdir(parents=True, exist_ok=True)
(skill / "SKILL.md").write_text(
"---\nname: demo\ndescription: regression-test skill\n---\n# demo\n",
encoding="utf-8",
)
@contextmanager
def _real_subagent_executor() -> Iterator[type]:
"""Import the real executor despite the suite-level circular-import mock."""
original_modules = {name: sys.modules.get(name, _MISSING) for name in _EXECUTOR_IMPORT_MOCKS}
original_executor = sys.modules.get("deerflow.subagents.executor", _MISSING)
parent_module = sys.modules.get("deerflow.subagents")
original_parent_executor = getattr(parent_module, "executor", _MISSING) if parent_module is not None else _MISSING
sys.modules.pop("deerflow.subagents.executor", None)
for name in _EXECUTOR_IMPORT_MOCKS:
sys.modules[name] = MagicMock()
try:
executor_module = importlib.import_module("deerflow.subagents.executor")
yield executor_module.SubagentExecutor
finally:
if original_executor is _MISSING:
sys.modules.pop("deerflow.subagents.executor", None)
else:
sys.modules["deerflow.subagents.executor"] = original_executor
if parent_module is not None:
if original_parent_executor is _MISSING:
try:
delattr(parent_module, "executor")
except AttributeError:
pass
else:
parent_module.executor = original_parent_executor
for name, module in original_modules.items():
if module is _MISSING:
sys.modules.pop(name, None)
else:
sys.modules[name] = module
async def test_load_skills_via_to_thread_does_not_block_event_loop(tmp_path: Path) -> None:
from deerflow.config.skills_config import SkillsConfig
from deerflow.subagents.config import SubagentConfig
_seed_skill(tmp_path)
with _real_subagent_executor() as SubagentExecutor:
executor = SubagentExecutor(
config=SubagentConfig(
name="demo",
description="Loads skills through the production async path.",
),
tools=[],
app_config=SimpleNamespace(skills=SkillsConfig(path=str(tmp_path))),
parent_model="test-model",
)
skills = await executor._load_skills()
assert isinstance(skills, list)
assert any(s.name == "demo" for s in skills)
@@ -0,0 +1,52 @@
"""Regression test: sqlite path setup must run off the event loop.
Anchors the production offload from
`runtime/checkpointer/async_provider.py:_async_checkpointer`, where SQLite
path resolution and `ensure_sqlite_parent_dir` are dispatched via
`await asyncio.to_thread(...)`.
That fix addressed #1912, where the sync `Path.mkdir` / `os.mkdir` inside
`ensure_sqlite_parent_dir` ran on the FastAPI lifespan event loop thread
and blocked startup.
This test invokes the production `_async_checkpointer()` path under the
strict Blockbuster context. The target path's parent does not yet exist, so
the underlying path resolution and `os.mkdir` both execute. If either step is
regressed to run directly on the event loop, Blockbuster raises
`BlockingError` and this test fails.
"""
from __future__ import annotations
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.asyncio
async def test_async_checkpointer_sqlite_setup_does_not_block_event_loop(tmp_path: Path) -> None:
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.checkpointer.async_provider import _async_checkpointer
db_file = tmp_path / "subdir" / "store.db"
mock_saver = AsyncMock()
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__.return_value = mock_saver
mock_context_manager.__aexit__.return_value = False
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string.return_value = mock_context_manager
mock_module = MagicMock()
mock_module.AsyncSqliteSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}):
async with _async_checkpointer(CheckpointerConfig(type="sqlite", connection_string=str(db_file))) as saver:
assert saver is mock_saver
assert db_file.parent.exists()
mock_saver_cls.from_conn_string.assert_called_once_with(str(db_file.resolve()))
mock_saver.setup.assert_awaited_once()
+27
View File
@@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
issues when unit-testing lightweight config/registry code in isolation.
"""
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
@@ -83,6 +85,31 @@ def _reset_skill_storage_singleton():
reset_skill_storage()
@pytest.fixture(autouse=True)
def _restore_title_config_singleton():
"""Reset ``_title_config`` to its pristine default after every test.
``AppConfig.from_file()`` writes the on-disk ``title`` block into the
module-level singleton (``config/app_config.py`` calls
``load_title_config_from_dict``). Any test that loads the real
``config.yaml`` therefore leaves the singleton in a state that
``test_title_middleware_core_logic.py`` does not expect; that suite
relies on the pristine ``TitleConfig()`` default (``enabled=True``).
We restore the default after every test so test files stay
independent regardless of order.
"""
try:
from deerflow.config.title_config import reset_title_config
except ImportError:
yield
return
try:
yield
finally:
reset_title_config()
@pytest.fixture(autouse=True)
def _auto_user_context(request):
"""Inject a default ``test-user-autouse`` into the contextvar.
+1
View File
@@ -0,0 +1 @@
"""Shared test support helpers."""
@@ -0,0 +1 @@
"""Runtime and static detectors used by tests."""
@@ -0,0 +1,44 @@
"""Strict Blockbuster runtime context scoped to DeerFlow business code.
Creates a `BlockBuster` instance with `scanned_modules=("app", "deerflow")`
so that test infrastructure (pytest, langchain, importlib, third-party libs)
is out of scope and does not produce false positives. Only loop-blocking
sync IO whose caller stack passes through `app.*` or `deerflow.*` raises
`BlockingError`.
Used by `backend/tests/blocking_io/conftest.py` to gate the regression suite.
"""
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from blockbuster import BlockBuster, BlockBusterFunction, BlockingError
_SCANNED_MODULES: tuple[str, ...] = ("app", "deerflow")
# Add DeerFlow-local rules here only when Blockbuster's default rule set misses
# a generic blocking primitive used by production code. If a path is invisible
# because no test exercises it, add a production-path runtime anchor instead.
_PROJECT_BLOCKING_RULES: tuple[tuple[str, BlockBusterFunction], ...] = ()
def _install_project_rules(bb: BlockBuster) -> None:
for name, rule in _PROJECT_BLOCKING_RULES:
bb.functions[name] = rule
@contextmanager
def detect_blocking_io_strict() -> Iterator[BlockBuster]:
"""Activate Blockbuster scoped to app.* and deerflow.* callers only."""
bb = BlockBuster(scanned_modules=list(_SCANNED_MODULES))
_install_project_rules(bb)
try:
bb.activate()
yield bb
finally:
bb.deactivate()
__all__ = ["BlockingError", "detect_blocking_io_strict"]
@@ -0,0 +1,892 @@
#!/usr/bin/env python3
"""Static inventory for likely backend event-loop blocking IO.
This detector parses backend business source with AST so untested paths are
still visible during review. Findings are prioritized static candidates, not
automatic bug decisions.
"""
from __future__ import annotations
import argparse
import ast
import json
import os
import sys
from collections import Counter, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[4]
DEFAULT_SCAN_PATHS = (
REPO_ROOT / "backend" / "app",
REPO_ROOT / "backend" / "packages" / "harness" / "deerflow",
REPO_ROOT / "backend" / "scripts",
)
IGNORED_DIR_NAMES = {
".git",
".mypy_cache",
".pytest_cache",
".ruff_cache",
".venv",
"__pycache__",
"node_modules",
}
CODE_SNIPPET_LIMIT = 200
PATH_METHOD_NAMES = {
"exists",
"glob",
"hardlink_to",
"is_dir",
"is_file",
"iterdir",
"mkdir",
"open",
"readlink",
"read_bytes",
"read_text",
"rename",
"resolve",
"rglob",
"rmdir",
"samefile",
"stat",
"symlink_to",
"touch",
"unlink",
"write_bytes",
"write_text",
}
AMBIGUOUS_PATH_METHOD_NAMES = {"replace"}
HTTP_METHOD_NAMES = {
"delete",
"get",
"head",
"options",
"patch",
"post",
"put",
"request",
"stream",
}
BUILTIN_OPEN_NAMES = {"builtins.open", "io.open", "open"}
BLOCKING_SLEEP_NAMES = {"time.sleep"}
BLOCKING_OS_FILE_NAMES = {
"os.listdir",
"os.lstat",
"os.makedirs",
"os.mkdir",
"os.remove",
"os.rename",
"os.replace",
"os.rmdir",
"os.scandir",
"os.stat",
"os.unlink",
"os.walk",
"os.path.exists",
"os.path.getsize",
"os.path.isdir",
"os.path.isfile",
}
BLOCKING_SUBPROCESS_NAMES = {
"subprocess.Popen",
"subprocess.check_call",
"subprocess.check_output",
"subprocess.run",
}
BLOCKING_HTTP_NAMES = {
"requests.delete",
"requests.get",
"requests.head",
"requests.options",
"requests.patch",
"requests.post",
"requests.put",
"requests.request",
"requests.sessions.Session.request",
"httpx.delete",
"httpx.get",
"httpx.head",
"httpx.options",
"httpx.patch",
"httpx.post",
"httpx.put",
"httpx.request",
"httpx.stream",
"urllib.request.urlopen",
}
SYNC_HTTP_CLIENT_FACTORIES = {
"httpx.Client": "httpx.Client",
"requests.Session": "requests.Session",
"requests.sessions.Session": "requests.Session",
"requests.session": "requests.Session",
}
BLOCKING_SHUTIL_NAMES = {
"shutil.copy",
"shutil.copyfile",
"shutil.copytree",
"shutil.move",
"shutil.rmtree",
}
SYNC_AGENT_MIDDLEWARE_HOOKS = {
"before_agent": "abefore_agent",
"before_model": "abefore_model",
"after_model": "aafter_model",
"after_agent": "aafter_agent",
}
PATH_METHOD_OPERATIONS = {
"exists": "FILE_METADATA",
"glob": "FILE_ENUMERATION",
"hardlink_to": "FILE_WRITE",
"is_dir": "FILE_METADATA",
"is_file": "FILE_METADATA",
"iterdir": "FILE_ENUMERATION",
"mkdir": "FILE_WRITE",
"open": "FILE_OPEN",
"readlink": "FILE_METADATA",
"read_bytes": "FILE_READ",
"read_text": "FILE_READ",
"rename": "FILE_COPY_MOVE",
"replace": "FILE_COPY_MOVE",
"resolve": "FILE_METADATA",
"rglob": "FILE_ENUMERATION",
"rmdir": "FILE_DELETE",
"samefile": "FILE_METADATA",
"stat": "FILE_METADATA",
"symlink_to": "FILE_WRITE",
"touch": "FILE_WRITE",
"unlink": "FILE_DELETE",
"write_bytes": "FILE_WRITE",
"write_text": "FILE_WRITE",
}
OS_FILE_OPERATIONS = {
"os.listdir": "FILE_ENUMERATION",
"os.lstat": "FILE_METADATA",
"os.makedirs": "FILE_WRITE",
"os.mkdir": "FILE_WRITE",
"os.remove": "FILE_DELETE",
"os.rename": "FILE_COPY_MOVE",
"os.replace": "FILE_COPY_MOVE",
"os.rmdir": "FILE_DELETE",
"os.scandir": "FILE_ENUMERATION",
"os.stat": "FILE_METADATA",
"os.unlink": "FILE_DELETE",
"os.walk": "FILE_ENUMERATION",
"os.path.exists": "FILE_METADATA",
"os.path.getsize": "FILE_METADATA",
"os.path.isdir": "FILE_METADATA",
"os.path.isfile": "FILE_METADATA",
}
SHUTIL_OPERATIONS = {
"shutil.copy": "FILE_COPY_MOVE",
"shutil.copyfile": "FILE_COPY_MOVE",
"shutil.copytree": "FILE_TREE_COPY",
"shutil.move": "FILE_COPY_MOVE",
"shutil.rmtree": "FILE_TREE_DELETE",
}
OPERATION_BASE_PRIORITY = {
"FILE_METADATA": "LOW",
"FILE_OPEN": "MEDIUM",
"FILE_READ": "MEDIUM",
"FILE_WRITE": "MEDIUM",
"FILE_ENUMERATION": "HIGH",
"FILE_DELETE": "MEDIUM",
"FILE_COPY_MOVE": "HIGH",
"FILE_TREE_COPY": "HIGH",
"FILE_TREE_DELETE": "HIGH",
"HTTP_REQUEST": "HIGH",
"SUBPROCESS": "HIGH",
"SLEEP": "HIGH",
"PARSE_ERROR": "MEDIUM",
}
@dataclass(frozen=True)
class BlockingIOStaticFinding:
category: str
operation: str
priority: str
path: str
line: int
column: int
function: str
exposure: str
symbol: str
code: str
def to_dict(self) -> dict[str, object]:
return {
"priority": self.priority,
"location": {
"path": self.path,
"line": self.line,
"column": self.column + 1,
"function": self.function,
},
"blocking_call": {
"category": self.category,
"operation": self.operation,
"symbol": self.symbol,
},
"event_loop_exposure": self.exposure,
"reason": _finding_reason(self.operation, self.exposure),
"code": self.code,
}
@dataclass(frozen=True)
class _FunctionContext:
qualname: str
class_name: str | None
is_async: bool
@dataclass(frozen=True)
class _FunctionInfo:
is_async: bool
@dataclass(frozen=True)
class _CallRef:
name: str
class_name: str | None
self_method: bool
@dataclass(frozen=True)
class _PotentialFinding:
category: str
operation: str
path: str
line: int
column: int
function: str
symbol: str
code: str
@dataclass(frozen=True)
class _BlockingRule:
category: str
operation: str
symbol: str
def dotted_name(node: ast.AST | None) -> str | None:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
parent = dotted_name(node.value)
if parent:
return f"{parent}.{node.attr}"
return node.attr
if isinstance(node, ast.Call):
return dotted_name(node.func)
if isinstance(node, ast.Subscript):
return dotted_name(node.value)
return None
def relative_to_repo(path: Path, repo_root: Path = REPO_ROOT) -> str:
try:
return path.resolve().relative_to(repo_root.resolve()).as_posix()
except ValueError:
return path.as_posix()
def _source_snippet(source_lines: Sequence[str], line: int) -> str:
if not 0 < line <= len(source_lines):
return ""
snippet = source_lines[line - 1].strip()
if len(snippet) <= CODE_SNIPPET_LIMIT:
return snippet
return f"{snippet[:CODE_SNIPPET_LIMIT]}..."
class BlockingIOStaticVisitor(ast.NodeVisitor):
def __init__(self, relative_path: str, source_lines: Sequence[str]) -> None:
self.relative_path = relative_path
self.source_lines = source_lines
self.import_aliases: dict[str, str] = {}
self.class_stack: list[str] = []
self.function_stack: list[_FunctionContext] = []
self.module_context = _FunctionContext("<module>", None, False)
self.module_sync_http_clients: dict[str, str] = {}
self.sync_http_client_stack: list[dict[str, str]] = []
self.class_bases: dict[str, set[str]] = defaultdict(set)
self.class_methods: dict[str, set[str]] = defaultdict(set)
self.function_defs: dict[str, _FunctionInfo] = {}
self.functions_by_name: dict[str, list[str]] = defaultdict(list)
self.call_refs: dict[str, list[_CallRef]] = defaultdict(list)
self.path_like_name_stack: list[set[str]] = []
self.potential_findings: list[_PotentialFinding] = []
@property
def current_function(self) -> _FunctionContext | None:
return self.function_stack[-1] if self.function_stack else None
@property
def current_context(self) -> _FunctionContext:
return self.current_function or self.module_context
@property
def current_sync_http_clients(self) -> dict[str, str]:
return self.sync_http_client_stack[-1] if self.sync_http_client_stack else self.module_sync_http_clients
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
local_name = alias.asname or alias.name.split(".", 1)[0]
canonical_name = alias.name if alias.asname else local_name
self.import_aliases[local_name] = canonical_name
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module is None:
return
for alias in node.names:
local_name = alias.asname or alias.name
self.import_aliases[local_name] = f"{node.module}.{alias.name}"
def visit_ClassDef(self, node: ast.ClassDef) -> None:
class_name = ".".join((*self.class_stack, node.name)) if self.class_stack else node.name
self.class_bases[class_name].update(canonical_name for base in node.bases if (canonical_name := self._canonical_name(dotted_name(base))) is not None)
self.class_stack.append(node.name)
self.generic_visit(node)
self.class_stack.pop()
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self._visit_function(node, is_async=False)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self._visit_function(node, is_async=True)
def visit_Assign(self, node: ast.Assign) -> None:
self._record_sync_http_client_targets(node.value, node.targets)
self.generic_visit(node)
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self._record_path_like_annotation(node.annotation, [node.target])
if node.value is not None:
self._record_sync_http_client_targets(node.value, [node.target])
self.generic_visit(node)
def visit_With(self, node: ast.With) -> None:
temporary_clients: dict[str, str | None] = {}
current_clients = self.current_sync_http_clients
for item in node.items:
self.visit(item.context_expr)
client_base = self._sync_http_client_factory_base(item.context_expr)
if client_base is None or not isinstance(item.optional_vars, ast.Name):
continue
name = item.optional_vars.id
temporary_clients[name] = current_clients.get(name)
current_clients[name] = client_base
try:
for statement in node.body:
self.visit(statement)
finally:
for name, previous in temporary_clients.items():
if previous is None:
current_clients.pop(name, None)
else:
current_clients[name] = previous
def visit_Call(self, node: ast.Call) -> None:
current = self.current_context
call_name = self._canonical_name(dotted_name(node.func))
if call_name is not None:
self._record_call_ref(node, call_name, current)
self._record_blocking_candidate(node, call_name, current)
self.generic_visit(node)
def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, *, is_async: bool) -> None:
qualname = ".".join((*self.class_stack, node.name)) if self.class_stack else node.name
class_name = self.class_stack[-1] if self.class_stack else None
context = _FunctionContext(qualname, class_name, is_async)
self.function_defs[qualname] = _FunctionInfo(is_async)
self.functions_by_name[node.name].append(qualname)
if class_name is not None:
self.class_methods[class_name].add(node.name)
self.function_stack.append(context)
self.sync_http_client_stack.append({})
self.path_like_name_stack.append(set(_path_like_argument_names(node.args, self._canonical_name)))
self.generic_visit(node)
self.path_like_name_stack.pop()
self.sync_http_client_stack.pop()
self.function_stack.pop()
def _canonical_name(self, name: str | None) -> str | None:
if name is None:
return None
parts = name.split(".")
if parts and parts[0] in self.import_aliases:
return ".".join((self.import_aliases[parts[0]], *parts[1:]))
return name
def _record_call_ref(self, node: ast.Call, call_name: str, current: _FunctionContext) -> None:
if current.qualname == "<module>":
return
if isinstance(node.func, ast.Name):
self.call_refs[current.qualname].append(_CallRef(node.func.id, current.class_name, self_method=False))
return
if not isinstance(node.func, ast.Attribute):
return
receiver = dotted_name(node.func.value)
if receiver in {"self", "cls"}:
self.call_refs[current.qualname].append(_CallRef(node.func.attr, current.class_name, self_method=True))
return
# Keep same-module direct calls through canonical aliases out of the call graph.
# External calls are handled as blocking candidates instead.
if "." not in call_name:
self.call_refs[current.qualname].append(_CallRef(call_name, current.class_name, self_method=False))
def _record_blocking_candidate(self, node: ast.Call, call_name: str, current: _FunctionContext) -> None:
rule = self._blocking_rule(node, call_name)
if rule is None:
return
line = getattr(node, "lineno", 0)
column = getattr(node, "col_offset", 0)
code = _source_snippet(self.source_lines, line)
self.potential_findings.append(
_PotentialFinding(
category=rule.category,
operation=rule.operation,
path=self.relative_path,
line=line,
column=column,
function=current.qualname,
symbol=rule.symbol,
code=code,
)
)
def _blocking_rule(self, node: ast.Call, call_name: str) -> _BlockingRule | None:
sync_client_symbol = self._sync_http_client_method_symbol(call_name)
if sync_client_symbol is not None:
return _BlockingRule("BLOCKING_HTTP_IO", "HTTP_REQUEST", sync_client_symbol)
chained_client_symbol = _sync_http_client_chained_method_symbol(call_name)
if chained_client_symbol is not None:
return _BlockingRule("BLOCKING_HTTP_IO", "HTTP_REQUEST", chained_client_symbol)
leaf_name = call_name.rsplit(".", 1)[-1]
if call_name in BUILTIN_OPEN_NAMES:
return _BlockingRule("BLOCKING_FILE_IO", "FILE_OPEN", call_name)
if leaf_name in PATH_METHOD_NAMES | AMBIGUOUS_PATH_METHOD_NAMES:
if self._is_path_method_call(node):
return _BlockingRule("BLOCKING_FILE_IO", _path_method_operation(leaf_name), call_name)
if call_name in BLOCKING_OS_FILE_NAMES:
return _BlockingRule("BLOCKING_FILE_IO", OS_FILE_OPERATIONS[call_name], call_name)
if call_name in BLOCKING_SLEEP_NAMES:
return _BlockingRule("BLOCKING_SLEEP", "SLEEP", call_name)
if call_name in BLOCKING_SUBPROCESS_NAMES:
return _BlockingRule("BLOCKING_SUBPROCESS", "SUBPROCESS", call_name)
if call_name in BLOCKING_HTTP_NAMES:
return _BlockingRule("BLOCKING_HTTP_IO", "HTTP_REQUEST", call_name)
if call_name in BLOCKING_SHUTIL_NAMES:
return _BlockingRule("BLOCKING_FILE_IO", SHUTIL_OPERATIONS[call_name], call_name)
return None
def _is_path_method_call(self, node: ast.Call) -> bool:
if not isinstance(node.func, ast.Attribute):
return False
if node.func.attr in AMBIGUOUS_PATH_METHOD_NAMES and node.func.attr == "replace" and len(node.args) >= 2:
return False
receiver = node.func.value
if _is_constructed_path(receiver):
return True
receiver_name = dotted_name(receiver)
if receiver_name in self.current_path_like_names:
return True
if _looks_like_path_receiver_name(receiver_name):
return True
if node.func.attr in PATH_METHOD_NAMES and isinstance(receiver, ast.Attribute):
return True
return False
@property
def current_path_like_names(self) -> set[str]:
return self.path_like_name_stack[-1] if self.path_like_name_stack else set()
def _record_path_like_annotation(self, annotation: ast.AST, targets: Iterable[ast.AST]) -> None:
if not self.path_like_name_stack or not _is_path_annotation(annotation, self._canonical_name):
return
self.current_path_like_names.update(name for target in targets for name in _iter_assigned_names(target))
def _record_sync_http_client_targets(self, value: ast.AST, targets: Iterable[ast.AST]) -> None:
client_base = self._sync_http_client_factory_base(value)
if client_base is None:
return
current_clients = self.current_sync_http_clients
for target in targets:
for name in _iter_assigned_names(target):
current_clients[name] = client_base
def _sync_http_client_factory_base(self, node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
call_name = self._canonical_name(dotted_name(node.func))
if call_name is None:
return None
return SYNC_HTTP_CLIENT_FACTORIES.get(call_name)
def _sync_http_client_method_symbol(self, call_name: str) -> str | None:
parts = call_name.split(".")
if len(parts) != 2 or parts[1] not in HTTP_METHOD_NAMES:
return None
client_base = self.current_sync_http_clients.get(parts[0])
if client_base is None:
return None
return f"{client_base}.{parts[1]}"
def _path_method_operation(method_name: str) -> str:
return PATH_METHOD_OPERATIONS.get(method_name, "FILE_METADATA")
def _is_constructed_path(node: ast.AST) -> bool:
return isinstance(node, ast.Call) and dotted_name(node.func) in {"Path", "pathlib.Path"}
def _looks_like_path_receiver_name(receiver_name: str | None) -> bool:
if receiver_name is None:
return False
leaf = receiver_name.rsplit(".", 1)[-1].lower()
return leaf in {"path", "file_path", "dir_path", "target", "dest", "destination", "source"} or leaf.endswith(("_path", "_dir", "_file", "_root")) or "path" in leaf
def _is_path_annotation(annotation: ast.AST | None, canonical_name: Callable[[str | None], str | None]) -> bool:
if annotation is None:
return False
if isinstance(annotation, ast.BinOp) and isinstance(annotation.op, ast.BitOr):
return _is_path_annotation(annotation.left, canonical_name) or _is_path_annotation(annotation.right, canonical_name)
name = dotted_name(annotation)
canonical = canonical_name(name)
if canonical in {"pathlib.Path", "Path"}:
return True
if isinstance(annotation, ast.Subscript):
return _is_path_annotation(annotation.slice, canonical_name)
return False
def _path_like_argument_names(arguments: ast.arguments, canonical_name: Callable[[str | None], str | None]) -> Iterable[str]:
candidates = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
if arguments.vararg is not None:
candidates.append(arguments.vararg)
if arguments.kwarg is not None:
candidates.append(arguments.kwarg)
for argument in candidates:
if _is_path_annotation(argument.annotation, canonical_name):
yield argument.arg
def _iter_assigned_names(target: ast.AST) -> Iterable[str]:
if isinstance(target, ast.Name):
yield target.id
return
if isinstance(target, (ast.Tuple, ast.List)):
for element in target.elts:
yield from _iter_assigned_names(element)
def _sync_http_client_chained_method_symbol(call_name: str) -> str | None:
for factory_name, client_base in SYNC_HTTP_CLIENT_FACTORIES.items():
prefix = f"{factory_name}."
if not call_name.startswith(prefix):
continue
method_name = call_name[len(prefix) :]
if method_name in HTTP_METHOD_NAMES:
return f"{client_base}.{method_name}"
return None
def _resolve_call_ref(visitor: BlockingIOStaticVisitor, ref: _CallRef) -> list[str]:
if ref.self_method and ref.class_name is not None:
qualname = f"{ref.class_name}.{ref.name}"
return [qualname] if qualname in visitor.function_defs else []
return list(visitor.functions_by_name.get(ref.name, ()))
def _reachable_functions(visitor: BlockingIOStaticVisitor, roots: Iterable[str]) -> set[str]:
reachable = set(roots)
queue: deque[str] = deque(reachable)
while queue:
qualname = queue.popleft()
for ref in visitor.call_refs.get(qualname, ()):
for target in _resolve_call_ref(visitor, ref):
if target in reachable:
continue
reachable.add(target)
queue.append(target)
return reachable
def _async_reachable_functions(visitor: BlockingIOStaticVisitor) -> set[str]:
return _reachable_functions(
visitor,
(qualname for qualname, info in visitor.function_defs.items() if info.is_async),
)
def _agent_middleware_classes(visitor: BlockingIOStaticVisitor) -> set[str]:
middleware_classes: set[str] = set()
changed = True
while changed:
changed = False
for class_name, bases in visitor.class_bases.items():
if class_name in middleware_classes:
continue
if any(_is_agent_middleware_base(base, middleware_classes) for base in bases):
middleware_classes.add(class_name)
changed = True
return middleware_classes
def _is_agent_middleware_base(base: str, known_middleware_classes: set[str]) -> bool:
leaf = base.rsplit(".", 1)[-1]
return leaf == "AgentMiddleware" or leaf in known_middleware_classes
def _sync_only_agent_middleware_entrypoints(visitor: BlockingIOStaticVisitor) -> set[str]:
entrypoints: set[str] = set()
middleware_classes = _agent_middleware_classes(visitor)
for class_name in middleware_classes:
methods = visitor.class_methods.get(class_name, set())
for sync_hook, async_hook in SYNC_AGENT_MIDDLEWARE_HOOKS.items():
if sync_hook in methods and async_hook not in methods:
qualname = f"{class_name}.{sync_hook}"
if qualname in visitor.function_defs:
entrypoints.add(qualname)
return entrypoints
def _event_loop_exposures(
visitor: BlockingIOStaticVisitor,
async_reachable: set[str],
middleware_reachable: set[str],
) -> dict[str, str]:
exposures: dict[str, str] = {}
for qualname, info in visitor.function_defs.items():
if info.is_async:
exposures[qualname] = "DIRECT_ASYNC"
for qualname in async_reachable:
exposures.setdefault(qualname, "ASYNC_REACHABLE_SAME_FILE")
for qualname in middleware_reachable:
exposures.setdefault(qualname, "SYNC_AGENT_MIDDLEWARE_HOOK")
return exposures
def _priority(operation: str) -> str:
return OPERATION_BASE_PRIORITY[operation]
def _finding_reason(operation: str, exposure: str) -> str:
if exposure == "DIRECT_ASYNC":
return f"{operation} is called directly inside an async function."
if exposure == "ASYNC_REACHABLE_SAME_FILE":
return f"{operation} is statically reachable from an async function in the same file."
if exposure == "SYNC_AGENT_MIDDLEWARE_HOOK":
return f"{operation} is statically reachable from a sync AgentMiddleware hook used by the async graph."
return "Source could not be parsed; scan coverage is incomplete for this file."
def _finalize_findings(visitor: BlockingIOStaticVisitor) -> list[BlockingIOStaticFinding]:
reachable = _async_reachable_functions(visitor)
middleware_reachable = _reachable_functions(visitor, _sync_only_agent_middleware_entrypoints(visitor))
event_loop_exposures = _event_loop_exposures(visitor, reachable, middleware_reachable)
findings: list[BlockingIOStaticFinding] = []
for candidate in visitor.potential_findings:
exposure = event_loop_exposures.get(candidate.function)
if exposure is None:
continue
findings.append(
BlockingIOStaticFinding(
category=candidate.category,
operation=candidate.operation,
priority=_priority(candidate.operation),
path=candidate.path,
line=candidate.line,
column=candidate.column,
function=candidate.function,
exposure=exposure,
symbol=candidate.symbol,
code=candidate.code,
)
)
return findings
def scan_file(path: Path, *, repo_root: Path = REPO_ROOT) -> list[BlockingIOStaticFinding]:
source = path.read_text(encoding="utf-8")
source_lines = source.splitlines()
relative_path = relative_to_repo(path, repo_root)
try:
tree = ast.parse(source, filename=str(path))
except SyntaxError as exc:
line = exc.lineno or 0
code = _source_snippet(source_lines, line)
return [
BlockingIOStaticFinding(
category="PARSE_ERROR",
operation="PARSE_ERROR",
priority="MEDIUM",
path=relative_path,
line=line,
column=max((exc.offset or 1) - 1, 0),
function="<module>",
exposure="PARSE_INCOMPLETE",
symbol="SyntaxError",
code=code,
)
]
visitor = BlockingIOStaticVisitor(relative_path, source_lines)
visitor.visit(tree)
return sorted(_finalize_findings(visitor), key=lambda finding: (finding.path, finding.line, finding.column, finding.category))
def is_ignored_path(path: Path) -> bool:
return any(part in IGNORED_DIR_NAMES for part in path.parts)
def iter_python_files(paths: Iterable[Path]) -> Iterable[Path]:
for path in paths:
if not path.exists() or is_ignored_path(path):
continue
if path.is_file():
if path.suffix == ".py" and not is_ignored_path(path):
yield path
continue
for dirpath, dirnames, filenames in os.walk(path):
dirnames[:] = [dirname for dirname in dirnames if dirname not in IGNORED_DIR_NAMES]
for filename in filenames:
if filename.endswith(".py"):
yield Path(dirpath) / filename
def scan_paths(paths: Iterable[Path], *, repo_root: Path = REPO_ROOT) -> list[BlockingIOStaticFinding]:
findings: list[BlockingIOStaticFinding] = []
for path in sorted(iter_python_files(paths)):
findings.extend(scan_file(path, repo_root=repo_root))
return sorted(findings, key=lambda finding: (finding.path, finding.line, finding.column, finding.category))
def findings_to_json(findings: Sequence[BlockingIOStaticFinding]) -> str:
return json.dumps([finding.to_dict() for finding in findings], indent=2) + "\n"
def write_json_report(findings: Sequence[BlockingIOStaticFinding], output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(findings_to_json(findings), encoding="utf-8")
def _scan_root(path: str) -> str:
parts = path.split("/")
if parts[:4] == ["backend", "packages", "harness", "deerflow"]:
return "backend/packages/harness/deerflow"
if len(parts) >= 2 and parts[0] == "backend":
return "/".join(parts[:2])
return parts[0] if parts else path
def _format_counter(title: str, counter: Counter[str], *, limit: int | None = None, order: Sequence[str] | None = None) -> list[str]:
lines = [title]
if order is None:
items = sorted(counter.items(), key=lambda item: (-item[1], item[0]))
else:
ordered = [(name, counter[name]) for name in order if counter.get(name)]
ordered_names = {name for name, _ in ordered}
extras = sorted((item for item in counter.items() if item[0] not in ordered_names), key=lambda item: (-item[1], item[0]))
items = ordered + extras
if limit is not None:
items = items[:limit]
width = max((len(str(count)) for _, count in items), default=1)
lines.extend(f" {count:>{width}} {name}" for name, count in items)
return lines
def format_summary(findings: Sequence[BlockingIOStaticFinding], *, output_path: Path | None = None) -> str:
if not findings:
lines = ["No static blocking IO event-loop risk findings in backend business code."]
else:
lines = [
f"Static blocking IO event-loop risk findings: {len(findings)}",
"",
*_format_counter("By category:", Counter(finding.category for finding in findings)),
"",
*_format_counter("By priority:", Counter(finding.priority for finding in findings), order=("HIGH", "MEDIUM", "LOW")),
"",
*_format_counter("By operation:", Counter(finding.operation for finding in findings)),
"",
*_format_counter("By event-loop exposure:", Counter(finding.exposure for finding in findings)),
"",
*_format_counter("By scan root:", Counter(_scan_root(finding.path) for finding in findings)),
"",
*_format_counter("Top files:", Counter(finding.path for finding in findings), limit=10),
]
if output_path is not None:
lines.extend(["", f"Full JSON report: {relative_to_repo(output_path.resolve())}"])
else:
lines.extend(["", "Use --format json for full structured findings."])
return "\n".join(lines)
def format_text(findings: Sequence[BlockingIOStaticFinding]) -> str:
if not findings:
return "No static blocking IO event-loop risk findings in backend business code."
lines: list[str] = []
for finding in findings:
lines.append(f"{finding.priority} {finding.category}/{finding.operation} {finding.path}:{finding.line}:{finding.column + 1} in {finding.function} exposure={finding.exposure}")
lines.append(f" symbol: {finding.symbol}")
lines.append(f" reason: {_finding_reason(finding.operation, finding.exposure)}")
if finding.code:
lines.append(f" code: {finding.code}")
return "\n".join(lines)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=("Statically inventory blocking IO calls that may block the backend asyncio event loop. Findings are prioritized review candidates, not automatic bug decisions."))
parser.add_argument(
"paths",
nargs="*",
type=Path,
help="Files or directories to scan. Defaults to backend app and harness sources.",
)
parser.add_argument(
"--format",
choices=("summary", "text", "json"),
default="summary",
help="Output format.",
)
parser.add_argument(
"--output",
type=Path,
help="Write the complete finding list as JSON to this file.",
)
return parser
def main(argv: Sequence[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
paths = args.paths or list(DEFAULT_SCAN_PATHS)
findings = scan_paths(paths)
output_path = args.output
if output_path is not None:
write_json_report(findings, output_path)
if args.format == "summary":
print(format_summary(findings, output_path=output_path))
elif args.format == "json":
print(findings_to_json(findings), end="")
else:
print(format_text(findings))
return 0
if __name__ == "__main__":
sys.exit(main())
@@ -0,0 +1,507 @@
#!/usr/bin/env python3
"""Inventory async/thread boundary points for developer review.
This detector is intentionally non-invasive: it parses Python source with AST
and reports places where code crosses sync/async/thread boundaries. Findings
are review evidence, not automatic bug decisions.
"""
from __future__ import annotations
import argparse
import ast
import json
import os
import sys
from collections.abc import Iterable, Sequence
from dataclasses import asdict, dataclass
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[4]
DEFAULT_SCAN_PATHS = (
REPO_ROOT / "backend" / "app",
REPO_ROOT / "backend" / "packages" / "harness" / "deerflow",
)
IGNORED_DIR_NAMES = {
".git",
".mypy_cache",
".pytest_cache",
".ruff_cache",
".venv",
"__pycache__",
"node_modules",
}
SEVERITY_ORDER = {"INFO": 0, "WARN": 1, "FAIL": 2}
@dataclass(frozen=True)
class BoundaryFinding:
severity: str
category: str
path: str
line: int
column: int
function: str
async_context: bool
symbol: str
message: str
code: str
def to_dict(self) -> dict[str, object]:
return asdict(self)
@dataclass(frozen=True)
class _FunctionContext:
name: str
is_async: bool
@dataclass(frozen=True)
class _CallRule:
severity: str
category: str
message: str
EXACT_CALL_RULES: dict[str, _CallRule] = {
"asyncio.run": _CallRule(
"WARN",
"SYNC_ASYNC_BRIDGE",
"Runs a coroutine from synchronous code by creating an event loop boundary.",
),
"asyncio.to_thread": _CallRule(
"INFO",
"ASYNC_THREAD_OFFLOAD",
"Offloads synchronous work from an async context into a worker thread.",
),
"asyncio.new_event_loop": _CallRule(
"WARN",
"NEW_EVENT_LOOP",
"Creates a separate event loop; review resource ownership across loops.",
),
"asyncio.run_coroutine_threadsafe": _CallRule(
"WARN",
"CROSS_THREAD_COROUTINE",
"Submits a coroutine to an event loop from another thread.",
),
"concurrent.futures.ThreadPoolExecutor": _CallRule(
"INFO",
"THREAD_POOL",
"Creates a thread pool boundary.",
),
"threading.Thread": _CallRule(
"INFO",
"RAW_THREAD",
"Creates a raw thread; ContextVar values do not propagate automatically.",
),
"threading.Timer": _CallRule(
"INFO",
"RAW_TIMER_THREAD",
"Creates a timer-backed raw thread; ContextVar values do not propagate automatically.",
),
"make_sync_tool_wrapper": _CallRule(
"INFO",
"SYNC_TOOL_WRAPPER",
"Adapts an async tool coroutine for synchronous tool invocation.",
),
}
THREAD_POOL_CONSTRUCTORS = {"concurrent.futures.ThreadPoolExecutor"}
ASYNC_TOOL_FACTORY_CALLS = {
"StructuredTool.from_function",
"langchain.tools.StructuredTool.from_function",
"langchain_core.tools.StructuredTool.from_function",
}
LANGCHAIN_INVOKE_RECEIVER_NAMES = {
"agent",
"chain",
"chat_model",
"graph",
"llm",
"model",
"runnable",
}
LANGCHAIN_INVOKE_RECEIVER_SUFFIXES = (
"_agent",
"_chain",
"_graph",
"_llm",
"_model",
"_runnable",
)
ASYNC_BLOCKING_CALL_RULES: dict[str, _CallRule] = {
"time.sleep": _CallRule(
"WARN",
"BLOCKING_CALL_IN_ASYNC",
"Blocks the event loop when called directly inside async code.",
),
"subprocess.run": _CallRule(
"WARN",
"BLOCKING_SUBPROCESS_IN_ASYNC",
"Runs a blocking subprocess from async code.",
),
"subprocess.check_call": _CallRule(
"WARN",
"BLOCKING_SUBPROCESS_IN_ASYNC",
"Runs a blocking subprocess from async code.",
),
"subprocess.check_output": _CallRule(
"WARN",
"BLOCKING_SUBPROCESS_IN_ASYNC",
"Runs a blocking subprocess from async code.",
),
"subprocess.Popen": _CallRule(
"WARN",
"BLOCKING_SUBPROCESS_IN_ASYNC",
"Starts a subprocess from async code; review whether it blocks later.",
),
}
def dotted_name(node: ast.AST | None) -> str | None:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
parent = dotted_name(node.value)
if parent:
return f"{parent}.{node.attr}"
return node.attr
return None
def call_receiver_name(node: ast.Call) -> str | None:
if not isinstance(node.func, ast.Attribute):
return None
return dotted_name(node.func.value)
def is_none_node(node: ast.AST | None) -> bool:
return isinstance(node, ast.Constant) and node.value is None
class BoundaryVisitor(ast.NodeVisitor):
def __init__(self, path: Path, relative_path: str, source_lines: Sequence[str]) -> None:
self.path = path
self.relative_path = relative_path
self.source_lines = source_lines
self.findings: list[BoundaryFinding] = []
self.function_stack: list[_FunctionContext] = []
self.import_aliases: dict[str, str] = {}
self.executor_names: set[str] = set()
@property
def current_function(self) -> str:
if not self.function_stack:
return "<module>"
return ".".join(context.name for context in self.function_stack)
@property
def in_async_context(self) -> bool:
return bool(self.function_stack and self.function_stack[-1].is_async)
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
local_name = alias.asname or alias.name.split(".", 1)[0]
canonical_name = alias.name if alias.asname else local_name
self.import_aliases[local_name] = canonical_name
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module is None:
return
for alias in node.names:
local_name = alias.asname or alias.name
self.import_aliases[local_name] = f"{node.module}.{alias.name}"
def visit_Assign(self, node: ast.Assign) -> None:
self._record_executor_targets(node.value, node.targets)
self.generic_visit(node)
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value is not None:
self._record_executor_targets(node.value, [node.target])
self.generic_visit(node)
def visit_With(self, node: ast.With) -> None:
for item in node.items:
if item.optional_vars is not None:
self._record_executor_targets(item.context_expr, [item.optional_vars])
self.generic_visit(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.function_stack.append(_FunctionContext(node.name, is_async=False))
self.generic_visit(node)
self.function_stack.pop()
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.function_stack.append(_FunctionContext(node.name, is_async=True))
try:
self._check_async_tool_definition(node)
self.generic_visit(node)
finally:
self.function_stack.pop()
def visit_Call(self, node: ast.Call) -> None:
call_name = self._canonical_name(dotted_name(node.func))
if call_name:
self._check_call(node, call_name)
self.generic_visit(node)
def _check_async_tool_definition(self, node: ast.AsyncFunctionDef) -> None:
for decorator in node.decorator_list:
decorator_call = decorator.func if isinstance(decorator, ast.Call) else decorator
decorator_name = self._canonical_name(dotted_name(decorator_call))
if decorator_name in {"langchain.tools.tool", "langchain_core.tools.tool"}:
self._emit(
node,
severity="INFO",
category="ASYNC_TOOL_DEFINITION",
symbol=decorator_name,
message="Defines an async LangChain tool; sync clients need a wrapper before invoke().",
)
return
def _check_call(self, node: ast.Call, call_name: str) -> None:
rule = EXACT_CALL_RULES.get(call_name)
if rule:
self._emit_rule(node, call_name, rule)
if call_name.endswith(".run_until_complete"):
self._emit(
node,
severity="WARN",
category="RUN_UNTIL_COMPLETE",
symbol=call_name,
message="Drives an event loop from synchronous code; review nested-loop behavior.",
)
if self._is_executor_submit(node, call_name):
self._emit(
node,
severity="INFO",
category="EXECUTOR_SUBMIT",
symbol=call_name,
message="Submits work to an executor; review context propagation and cancellation.",
)
if call_name in ASYNC_TOOL_FACTORY_CALLS:
if any(keyword.arg == "coroutine" and not is_none_node(keyword.value) for keyword in node.keywords):
self._emit(
node,
severity="INFO",
category="ASYNC_ONLY_TOOL_FACTORY",
symbol=call_name,
message="Creates a StructuredTool from a coroutine; sync clients need a wrapper.",
)
if self.in_async_context and call_name in ASYNC_BLOCKING_CALL_RULES:
self._emit_rule(node, call_name, ASYNC_BLOCKING_CALL_RULES[call_name])
if self.in_async_context and self._is_langchain_invoke(node, call_name, method_name="invoke"):
self._emit(
node,
severity="WARN",
category="SYNC_INVOKE_IN_ASYNC",
symbol=call_name,
message="Calls a synchronous invoke() from async code; review event-loop blocking.",
)
if not self.in_async_context and self._is_langchain_invoke(node, call_name, method_name="ainvoke"):
self._emit(
node,
severity="WARN",
category="ASYNC_INVOKE_IN_SYNC",
symbol=call_name,
message="Calls async ainvoke() from sync code; review how the coroutine is awaited.",
)
def _canonical_name(self, name: str | None) -> str | None:
if name is None:
return None
parts = name.split(".")
if parts and parts[0] in self.import_aliases:
return ".".join((self.import_aliases[parts[0]], *parts[1:]))
return name
def _record_executor_targets(self, value: ast.AST, targets: Sequence[ast.AST]) -> None:
if not isinstance(value, ast.Call):
return
call_name = self._canonical_name(dotted_name(value.func))
if call_name not in THREAD_POOL_CONSTRUCTORS:
return
for target in targets:
for name in self._target_names(target):
self.executor_names.add(name)
def _target_names(self, target: ast.AST) -> Iterable[str]:
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, (ast.Tuple, ast.List)):
for element in target.elts:
yield from self._target_names(element)
def _is_executor_submit(self, node: ast.Call, call_name: str) -> bool:
if not call_name.endswith(".submit"):
return False
receiver_name = call_receiver_name(node)
return receiver_name in self.executor_names
def _is_langchain_invoke(self, node: ast.Call, call_name: str, *, method_name: str) -> bool:
if not call_name.endswith(f".{method_name}"):
return False
receiver_name = call_receiver_name(node)
if receiver_name is None:
return False
receiver_leaf = receiver_name.rsplit(".", 1)[-1]
return receiver_leaf in LANGCHAIN_INVOKE_RECEIVER_NAMES or receiver_leaf.endswith(LANGCHAIN_INVOKE_RECEIVER_SUFFIXES)
def _emit_rule(self, node: ast.AST, symbol: str, rule: _CallRule) -> None:
self._emit(
node,
severity=rule.severity,
category=rule.category,
symbol=symbol,
message=rule.message,
)
def _emit(self, node: ast.AST, *, severity: str, category: str, symbol: str, message: str) -> None:
line = getattr(node, "lineno", 0)
column = getattr(node, "col_offset", 0)
code = ""
if line > 0 and line <= len(self.source_lines):
code = self.source_lines[line - 1].strip()
self.findings.append(
BoundaryFinding(
severity=severity,
category=category,
path=self.relative_path,
line=line,
column=column,
function=self.current_function,
async_context=self.in_async_context,
symbol=symbol,
message=message,
code=code,
)
)
def relative_to_repo(path: Path, repo_root: Path = REPO_ROOT) -> str:
try:
return path.resolve().relative_to(repo_root.resolve()).as_posix()
except ValueError:
return path.as_posix()
def scan_file(path: Path, *, repo_root: Path = REPO_ROOT) -> list[BoundaryFinding]:
source = path.read_text(encoding="utf-8")
source_lines = source.splitlines()
relative_path = relative_to_repo(path, repo_root)
try:
tree = ast.parse(source, filename=str(path))
except SyntaxError as exc:
line = exc.lineno or 0
code = source_lines[line - 1].strip() if line > 0 and line <= len(source_lines) else ""
return [
BoundaryFinding(
severity="WARN",
category="PARSE_ERROR",
path=relative_path,
line=line,
column=max((exc.offset or 1) - 1, 0),
function="<module>",
async_context=False,
symbol="SyntaxError",
message=str(exc),
code=code,
)
]
visitor = BoundaryVisitor(path, relative_path, source_lines)
visitor.visit(tree)
return visitor.findings
def is_ignored_path(path: Path) -> bool:
return any(part in IGNORED_DIR_NAMES for part in path.parts)
def iter_python_files(paths: Iterable[Path]) -> Iterable[Path]:
for path in paths:
if not path.exists() or is_ignored_path(path):
continue
if path.is_file():
if path.suffix == ".py" and not is_ignored_path(path):
yield path
continue
for dirpath, dirnames, filenames in os.walk(path):
dirnames[:] = [dirname for dirname in dirnames if dirname not in IGNORED_DIR_NAMES]
for filename in filenames:
if filename.endswith(".py"):
yield Path(dirpath) / filename
def scan_paths(paths: Iterable[Path], *, repo_root: Path = REPO_ROOT) -> list[BoundaryFinding]:
findings: list[BoundaryFinding] = []
for path in sorted(iter_python_files(paths)):
findings.extend(scan_file(path, repo_root=repo_root))
return sorted(findings, key=lambda finding: (finding.path, finding.line, finding.column, finding.category))
def filter_findings(findings: Iterable[BoundaryFinding], min_severity: str) -> list[BoundaryFinding]:
threshold = SEVERITY_ORDER[min_severity]
return [finding for finding in findings if SEVERITY_ORDER[finding.severity] >= threshold]
def format_text(findings: Sequence[BoundaryFinding]) -> str:
if not findings:
return "No async/thread boundary findings."
lines: list[str] = []
for finding in findings:
lines.append(f"{finding.severity} {finding.category} {finding.path}:{finding.line}:{finding.column + 1} in {finding.function} async={str(finding.async_context).lower()}")
lines.append(f" symbol: {finding.symbol}")
lines.append(f" note: {finding.message}")
if finding.code:
lines.append(f" code: {finding.code}")
return "\n".join(lines)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=("Detect async/thread boundary points for developer review. Findings are an inventory, not automatic bug decisions."))
parser.add_argument(
"paths",
nargs="*",
type=Path,
help="Files or directories to scan. Defaults to backend app and harness sources.",
)
parser.add_argument(
"--format",
choices=("text", "json"),
default="text",
help="Output format.",
)
parser.add_argument(
"--min-severity",
choices=tuple(SEVERITY_ORDER),
default="INFO",
help="Only show findings at or above this severity.",
)
return parser
def main(argv: Sequence[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
paths = args.paths or list(DEFAULT_SCAN_PATHS)
findings = filter_findings(scan_paths(paths), args.min_severity)
if args.format == "json":
print(json.dumps([finding.to_dict() for finding in findings], indent=2, sort_keys=True))
else:
print(format_text(findings))
return 0
if __name__ == "__main__":
sys.exit(main())
+85
View File
@@ -233,3 +233,88 @@ class TestConcurrentFileWrites:
thread.join()
assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
class TestDownloadFile:
"""Tests for AioSandbox.download_file."""
def test_returns_concatenated_bytes(self, sandbox):
"""download_file should join chunks from the client iterator into bytes."""
sandbox._client.file.download_file = MagicMock(return_value=[b"hel", b"lo"])
result = sandbox.download_file("/mnt/user-data/outputs/file.bin")
assert result == b"hello"
sandbox._client.file.download_file.assert_called_once_with(path="/mnt/user-data/outputs/file.bin")
def test_returns_empty_bytes_for_empty_file(self, sandbox):
"""download_file should return b'' when the iterator yields nothing."""
sandbox._client.file.download_file = MagicMock(return_value=iter([]))
result = sandbox.download_file("/mnt/user-data/outputs/empty.bin")
assert result == b""
def test_uses_lock_during_download(self, sandbox):
"""download_file should hold the lock while calling the client."""
lock_was_held = []
def tracking_download(path):
lock_was_held.append(sandbox._lock.locked())
return iter([b"data"])
sandbox._client.file.download_file = tracking_download
sandbox.download_file("/mnt/user-data/outputs/file.bin")
assert lock_was_held == [True], "download_file must hold the lock during client call"
def test_raises_oserror_on_client_error(self, sandbox):
"""download_file should wrap client exceptions as OSError."""
sandbox._client.file.download_file = MagicMock(side_effect=RuntimeError("network error"))
with pytest.raises(OSError, match="network error"):
sandbox.download_file("/mnt/user-data/outputs/file.bin")
def test_preserves_oserror_from_client(self, sandbox):
"""OSError raised by the client should propagate without re-wrapping."""
sandbox._client.file.download_file = MagicMock(side_effect=OSError("disk error"))
with pytest.raises(OSError, match="disk error"):
sandbox.download_file("/mnt/user-data/outputs/file.bin")
def test_rejects_path_outside_virtual_prefix_and_logs_error(self, sandbox, caplog):
"""download_file must reject downloads outside /mnt/user-data and log the reason."""
sandbox._client.file.download_file = MagicMock()
with caplog.at_level("ERROR"):
with pytest.raises(PermissionError, match="must be under"):
sandbox.download_file("/etc/passwd")
assert "outside allowed directory" in caplog.text
sandbox._client.file.download_file.assert_not_called()
@pytest.mark.parametrize(
"path",
[
"/mnt/workspace/../../etc/passwd",
"../secret",
"/a/b/../../../etc/shadow",
],
)
def test_rejects_path_traversal(self, sandbox, path):
"""download_file must reject paths containing '..' before calling the client."""
sandbox._client.file.download_file = MagicMock()
with pytest.raises(PermissionError, match="path traversal"):
sandbox.download_file(path)
sandbox._client.file.download_file.assert_not_called()
def test_single_chunk(self, sandbox):
"""download_file should work correctly with a single-chunk response."""
sandbox._client.file.download_file = MagicMock(return_value=[b"single-chunk"])
result = sandbox.download_file("/mnt/user-data/outputs/single.bin")
assert result == b"single-chunk"
+212
View File
@@ -1,11 +1,14 @@
"""Tests for AioSandboxProvider mount helpers."""
import asyncio
import importlib
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.paths import Paths, join_host_path
from deerflow.runtime.user_context import reset_current_user, set_current_user
# ── ensure_thread_dirs ───────────────────────────────────────────────────────
@@ -136,3 +139,212 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
provider._discover_or_create_with_lock("thread-5", "sandbox-5")
assert unlock_calls == []
@pytest.mark.anyio
async def test_acquire_async_uses_async_readiness_polling(monkeypatch):
"""AioSandboxProvider async creation must not use sync readiness polling."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(None)
provider._config = {"replicas": 3}
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
provider._backend = SimpleNamespace(
create=MagicMock(return_value=aio_mod.SandboxInfo(sandbox_id="sandbox-async", sandbox_url="http://sandbox")),
destroy=MagicMock(),
discover=MagicMock(return_value=None),
)
async_readiness_calls: list[tuple[str, int]] = []
async def fake_wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool:
async_readiness_calls.append((sandbox_url, timeout))
return True
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready_async", fake_wait_for_sandbox_ready_async)
monkeypatch.setattr(
aio_mod,
"wait_for_sandbox_ready",
lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("sync readiness should not be used")),
)
sandbox_id = await provider._create_sandbox_async("thread-async", "sandbox-async")
assert sandbox_id == "sandbox-async"
assert async_readiness_calls == [("http://sandbox", 60)]
assert provider._backend.destroy.call_count == 0
assert provider._thread_sandboxes["thread-async"] == "sandbox-async"
@pytest.mark.anyio
async def test_discover_or_create_with_lock_async_offloads_lock_file_open_and_close(tmp_path, monkeypatch):
"""Async lock path must not open or close lock files on the event loop."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._discover_or_create_with_lock_async = aio_mod.AioSandboxProvider._discover_or_create_with_lock_async.__get__(
provider,
aio_mod.AioSandboxProvider,
)
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {"thread-async-lock": "sandbox-async-lock"}
provider._sandboxes = {"sandbox-async-lock": aio_mod.AioSandbox(id="sandbox-async-lock", base_url="http://sandbox")}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
provider._backend = SimpleNamespace(discover=MagicMock(return_value=None))
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
to_thread_calls: list[object] = []
async def fake_to_thread(func, /, *args, **kwargs):
to_thread_calls.append(func)
return func(*args, **kwargs)
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread)
sandbox_id = await provider._discover_or_create_with_lock_async("thread-async-lock", "sandbox-async-lock")
assert sandbox_id == "sandbox-async-lock"
assert aio_mod._open_lock_file in to_thread_calls
assert any(getattr(func, "__name__", "") == "close" for func in to_thread_calls)
@pytest.mark.anyio
async def test_acquire_thread_lock_async_uses_dedicated_executor(monkeypatch):
"""Per-thread lock waits should not consume the default asyncio.to_thread pool."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
lock = aio_mod.threading.Lock()
async def fail_to_thread(*_args, **_kwargs):
raise AssertionError("thread-lock acquisition must not use asyncio.to_thread")
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fail_to_thread)
await aio_mod._acquire_thread_lock_async(lock)
try:
assert not lock.acquire(blocking=False)
finally:
lock.release()
@pytest.mark.anyio
async def test_acquire_async_cancellation_does_not_leak_thread_lock(tmp_path):
"""Cancelled async lock waiters must not leave the per-thread lock held."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
thread_id = "thread-cancel-lock"
thread_lock = provider._get_thread_lock(thread_id)
thread_lock.acquire()
task = asyncio.create_task(provider.acquire_async(thread_id))
await asyncio.sleep(0.05)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
thread_lock.release()
deadline = asyncio.get_running_loop().time() + 1
while asyncio.get_running_loop().time() < deadline:
acquired = thread_lock.acquire(blocking=False)
if acquired:
thread_lock.release()
return
await asyncio.sleep(0.01)
pytest.fail("provider thread lock was leaked after cancelling acquire_async")
@pytest.mark.anyio
async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path, monkeypatch):
"""A cancelled waiter must not prevent the next live waiter from acquiring."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
async def fake_acquire_internal_async(thread_id: str | None) -> str:
assert thread_id == "thread-successor-lock"
await asyncio.sleep(0)
return "sandbox-successor"
monkeypatch.setattr(provider, "_acquire_internal_async", fake_acquire_internal_async)
thread_id = "thread-successor-lock"
thread_lock = provider._get_thread_lock(thread_id)
thread_lock.acquire()
cancelled_waiter = asyncio.create_task(provider.acquire_async(thread_id))
await asyncio.sleep(0.05)
cancelled_waiter.cancel()
try:
await cancelled_waiter
except asyncio.CancelledError:
pass
live_waiter = asyncio.create_task(provider.acquire_async(thread_id))
thread_lock.release()
assert await asyncio.wait_for(live_waiter, timeout=1) == "sandbox-successor"
deadline = asyncio.get_running_loop().time() + 1
while asyncio.get_running_loop().time() < deadline:
acquired = thread_lock.acquire(blocking=False)
if acquired:
thread_lock.release()
return
await asyncio.sleep(0.01)
pytest.fail("provider thread lock was not released after successor acquire_async")
def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
"""Provisioner mode must receive user_id so PVC subPath matches user isolation."""
remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend")
backend = remote_mod.RemoteSandboxBackend("http://provisioner:8002")
token = set_current_user(SimpleNamespace(id="user-7"))
posted: dict = {}
class _Response:
def raise_for_status(self):
return None
def json(self):
return {"sandbox_url": "http://sandbox.local"}
def _post(url, json, timeout): # noqa: A002 - mirrors requests.post kwarg
posted.update({"url": url, "json": json, "timeout": timeout})
return _Response()
monkeypatch.setattr(remote_mod.requests, "post", _post)
try:
backend.create("thread-42", "sandbox-42")
finally:
reset_current_user(token)
assert posted["url"] == "http://provisioner:8002/api/sandboxes"
assert posted["json"] == {
"sandbox_id": "sandbox-42",
"thread_id": "thread-42",
"user_id": "user-7",
}
+119
View File
@@ -0,0 +1,119 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from deerflow.community.aio_sandbox import backend as readiness
class _FakeAsyncClient:
def __init__(self, *, responses: list[object], calls: list[str], timeout: float, request_timeouts: list[float] | None = None) -> None:
self._responses = responses
self._calls = calls
self._timeout = timeout
self._request_timeouts = request_timeouts
async def __aenter__(self) -> _FakeAsyncClient:
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
return None
async def get(self, url: str, *, timeout: float):
self._calls.append(url)
if self._request_timeouts is not None:
self._request_timeouts.append(timeout)
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
return response
class _FakeLoop:
def __init__(self, times: list[float]) -> None:
self._times = times
self._index = 0
def time(self) -> float:
value = self._times[self._index]
self._index += 1
return value
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_uses_nonblocking_polling(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[str] = []
sleeps: list[float] = []
def fake_client(*, timeout: float):
return _FakeAsyncClient(
responses=[SimpleNamespace(status_code=503), SimpleNamespace(status_code=200)],
calls=calls,
timeout=timeout,
)
async def fake_sleep(delay: float) -> None:
sleeps.append(delay)
monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client)
monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep)
monkeypatch.setattr(readiness.requests, "get", lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("requests.get should not be used")))
monkeypatch.setattr(readiness.time, "sleep", lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("time.sleep should not be used")))
assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.05) is True
assert calls == ["http://sandbox/v1/sandbox", "http://sandbox/v1/sandbox"]
assert sleeps == [0.05]
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_retries_request_errors(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[str] = []
sleeps: list[float] = []
def fake_client(*, timeout: float):
return _FakeAsyncClient(
responses=[readiness.httpx.ConnectError("not ready"), SimpleNamespace(status_code=200)],
calls=calls,
timeout=timeout,
)
async def fake_sleep(delay: float) -> None:
sleeps.append(delay)
monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client)
monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep)
assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.01) is True
assert len(calls) == 2
assert sleeps == [0.01]
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_clamps_request_and_sleep_to_deadline(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[str] = []
request_timeouts: list[float] = []
sleeps: list[float] = []
def fake_client(*, timeout: float):
return _FakeAsyncClient(
responses=[SimpleNamespace(status_code=503)],
calls=calls,
timeout=timeout,
request_timeouts=request_timeouts,
)
async def fake_sleep(delay: float) -> None:
sleeps.append(delay)
monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client)
monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep)
monkeypatch.setattr(readiness.asyncio, "get_running_loop", lambda: _FakeLoop([100.0, 100.5, 101.75, 102.0]))
assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=2, poll_interval=1.0) is False
assert calls == ["http://sandbox/v1/sandbox"]
assert request_timeouts == [1.5]
assert sleeps == [0.25]
+15
View File
@@ -4,6 +4,7 @@ from pathlib import Path
import pytest
from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException
from fastapi.testclient import TestClient
from starlette.requests import Request
from starlette.responses import FileResponse
@@ -102,3 +103,17 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
assert response.status_code == 200
assert response.text == "hello"
assert response.headers.get("content-disposition", "").startswith("attachment;")
def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None:
skill_path = tmp_path / "sample.skill"
payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1)
with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref:
zip_ref.writestr("SKILL.md", payload)
assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES
with pytest.raises(HTTPException) as exc_info:
artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md")
assert exc_info.value.status_code == 413
+47 -11
View File
@@ -5,28 +5,26 @@ from unittest.mock import patch
import pytest
from app.gateway.auth.config import AuthConfig
import app.gateway.auth.config as cfg
def test_auth_config_defaults():
config = AuthConfig(jwt_secret="test-secret-key-123")
config = cfg.AuthConfig(jwt_secret="test-secret-key-123")
assert config.token_expiry_days == 7
def test_auth_config_token_expiry_range():
AuthConfig(jwt_secret="s", token_expiry_days=1)
AuthConfig(jwt_secret="s", token_expiry_days=30)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=1)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=30)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=0)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=0)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=31)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=31)
def test_auth_config_from_env():
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
with patch.dict(os.environ, env, clear=False):
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
@@ -36,19 +34,57 @@ def test_auth_config_from_env():
cfg._auth_config = old
def test_auth_config_missing_secret_generates_ephemeral(caplog):
def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog):
import logging
import app.gateway.auth.config as cfg
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
secret_file = tmp_path / ".jwt_secret"
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
assert secret_file.exists()
assert secret_file.read_text().strip() == config.jwt_secret
finally:
cfg._auth_config = old
def test_auth_config_reuses_persisted_secret(tmp_path):
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
persisted = "persisted-secret-from-file-min-32-chars!!"
(tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8")
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
config = cfg.get_auth_config()
assert config.jwt_secret == persisted
finally:
cfg._auth_config = old
def test_auth_config_empty_secret_file_generates_new(tmp_path):
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
(tmp_path / ".jwt_secret").write_text("", encoding="utf-8")
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
config = cfg.get_auth_config()
assert config.jwt_secret
assert len(config.jwt_secret) > 20
assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret
finally:
cfg._auth_config = old
+142
View File
@@ -0,0 +1,142 @@
"""Tests for idempotent run cancellation (issue #3055).
RunManager.cancel() returns True when a run is already interrupted so that
a second cancel request from the same worker is treated as a no-op success
(202) rather than a conflict (409). Both the POST cancel endpoint and the
POST stream endpoint share this behaviour through the same cancel() call.
"""
from __future__ import annotations
import asyncio
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.gateway.routers import thread_runs
from deerflow.runtime import RunManager, RunStatus
THREAD_ID = "thread-cancel-test"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app(mgr: RunManager) -> TestClient:
app = make_authed_test_app()
app.include_router(thread_runs.router)
app.state.run_manager = mgr
return TestClient(app, raise_server_exceptions=False)
def _create_interrupted_run(mgr: RunManager) -> str:
"""Create a run and cancel it, returning its run_id."""
async def _setup():
record = await mgr.create(THREAD_ID)
await mgr.set_status(record.run_id, RunStatus.running)
await mgr.cancel(record.run_id)
return record.run_id
return asyncio.run(_setup())
# ---------------------------------------------------------------------------
# RunManager.cancel() unit tests
# ---------------------------------------------------------------------------
class TestRunManagerCancelIdempotency:
def test_cancel_returns_true_for_already_interrupted_run(self):
"""cancel() must return True when the run is already interrupted."""
async def run():
mgr = RunManager()
record = await mgr.create(THREAD_ID)
await mgr.set_status(record.run_id, RunStatus.running)
first = await mgr.cancel(record.run_id)
assert first is True
second = await mgr.cancel(record.run_id)
assert second is True # idempotent
asyncio.run(run())
def test_cancel_returns_false_for_successful_run(self):
"""cancel() must still return False for runs that completed successfully."""
async def run():
mgr = RunManager()
record = await mgr.create(THREAD_ID)
await mgr.set_status(record.run_id, RunStatus.running)
await mgr.set_status(record.run_id, RunStatus.success)
result = await mgr.cancel(record.run_id)
assert result is False
asyncio.run(run())
def test_cancel_returns_false_for_unknown_run(self):
async def run():
mgr = RunManager()
result = await mgr.cancel("nonexistent-run-id")
assert result is False
asyncio.run(run())
# ---------------------------------------------------------------------------
# POST /cancel endpoint — idempotent 202
# ---------------------------------------------------------------------------
class TestCancelRunEndpointIdempotency:
def test_double_cancel_returns_202_not_409(self):
"""Second cancel on an already-interrupted run must return 202, not 409."""
mgr = RunManager()
run_id = _create_interrupted_run(mgr)
client = _make_app(mgr)
resp = client.post(f"/api/threads/{THREAD_ID}/runs/{run_id}/cancel")
assert resp.status_code == 202, f"Expected 202, got {resp.status_code}: {resp.text}"
def test_cancel_unknown_run_returns_404(self):
mgr = RunManager()
client = _make_app(mgr)
resp = client.post(f"/api/threads/{THREAD_ID}/runs/no-such-run/cancel")
assert resp.status_code == 404
def test_cancel_successful_run_returns_409(self):
"""Successfully-completed runs cannot be cancelled — must return 409."""
async def _setup():
mgr = RunManager()
record = await mgr.create(THREAD_ID)
await mgr.set_status(record.run_id, RunStatus.running)
await mgr.set_status(record.run_id, RunStatus.success)
return mgr, record.run_id
mgr, run_id = asyncio.run(_setup())
client = _make_app(mgr)
resp = client.post(f"/api/threads/{THREAD_ID}/runs/{run_id}/cancel")
assert resp.status_code == 409
# ---------------------------------------------------------------------------
# POST /{thread_id}/runs/{run_id}/join (stream_existing_run) — idempotent cancel
# ---------------------------------------------------------------------------
class TestStreamExistingRunIdempotentCancel:
def test_stream_cancel_already_interrupted_returns_not_409(self):
"""stream_existing_run with action=interrupt on an already-interrupted run
must not raise 409 — the idempotent cancel path returns 202/SSE."""
mgr = RunManager()
run_id = _create_interrupted_run(mgr)
client = _make_app(mgr)
resp = client.post(
f"/api/threads/{THREAD_ID}/runs/{run_id}/join",
params={"action": "interrupt"},
)
assert resp.status_code != 409, f"Should not 409 on idempotent cancel, got {resp.status_code}"
+1 -32
View File
@@ -372,37 +372,6 @@ class TestExtractResponseText:
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
assert _extract_response_text(result) == ""
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "search the repo"},
{
"type": "ai",
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == ""
def test_preserves_visible_text_when_stripping_loop_warning(self):
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "prepare the report"},
{
"type": "ai",
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == "Here is the report."
# ---------------------------------------------------------------------------
# ChannelManager tests
@@ -761,7 +730,7 @@ class TestChannelManager:
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None):
del assistant_id, context # unused in this test, kept for signature parity
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
+47 -6
View File
@@ -291,7 +291,7 @@ class TestAsyncCheckpointer:
@pytest.mark.anyio
async def test_sqlite_creates_parent_dir_via_to_thread(self):
"""Async SQLite setup should move mkdir off the event loop."""
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.checkpointer.async_provider import _prepare_sqlite_checkpointer_path, make_checkpointer
mock_config = MagicMock()
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
@@ -310,22 +310,63 @@ class TestAsyncCheckpointer:
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
patch(
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
"deerflow.runtime.checkpointer.async_provider.asyncio.to_thread",
new_callable=AsyncMock,
return_value="/tmp/resolved/test.db",
),
) as mock_to_thread,
):
async with make_checkpointer() as saver:
assert saver is mock_saver
mock_to_thread.assert_awaited_once()
called_fn, called_path = mock_to_thread.await_args.args
assert called_fn.__name__ == "ensure_sqlite_parent_dir"
assert called_path == "/tmp/resolved/test.db"
assert called_fn is _prepare_sqlite_checkpointer_path
assert called_path == "relative/test.db"
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
mock_saver.setup.assert_awaited_once()
@pytest.mark.anyio
async def test_database_sqlite_creates_parent_dir_via_to_thread(self):
"""Unified database SQLite setup should also move path IO off the event loop."""
from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.checkpointer.async_provider import _prepare_database_sqlite_checkpointer_path, make_checkpointer
db_config = DatabaseConfig(backend="sqlite", sqlite_dir="relative-data")
mock_config = MagicMock()
mock_config.checkpointer = None
mock_config.database = db_config
mock_saver = AsyncMock()
mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = mock_saver
mock_cm.__aexit__.return_value = False
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string.return_value = mock_cm
mock_module = MagicMock()
mock_module.AsyncSqliteSaver = mock_saver_cls
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch(
"deerflow.runtime.checkpointer.async_provider.asyncio.to_thread",
new_callable=AsyncMock,
return_value="/tmp/data/deerflow.db",
) as mock_to_thread,
):
async with make_checkpointer() as saver:
assert saver is mock_saver
mock_to_thread.assert_awaited_once()
called_fn, called_db_config = mock_to_thread.await_args.args
assert called_fn is _prepare_database_sqlite_checkpointer_path
assert called_db_config is db_config
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/data/deerflow.db")
mock_saver.setup.assert_awaited_once()
# ---------------------------------------------------------------------------
# app_config.py integration
@@ -0,0 +1,159 @@
"""Tests for DeerFlowClient's graph-root tracing wiring.
Regression coverage for the Copilot review on PR #2944: when the title
and summarization middlewares request ``attach_tracing=False`` we must
make sure ``DeerFlowClient`` injects the tracing callbacks at the graph
invocation root instead, otherwise those middlewares produce untraced
LLM calls.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from deerflow.client import DeerFlowClient
class _FakeAgent:
"""Capture the ``config`` handed to ``agent.stream``."""
def __init__(self) -> None:
self.captured_config: dict | None = None
self.checkpointer = None
self.store = None
def stream(self, state, *, config, context, stream_mode):
self.captured_config = config
return iter(()) # empty stream
@pytest.fixture(autouse=True)
def _clear_langfuse_env(monkeypatch):
from deerflow.config.tracing_config import reset_tracing_config
for name in ("LANGFUSE_TRACING", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY", "LANGFUSE_BASE_URL"):
monkeypatch.delenv(name, raising=False)
reset_tracing_config()
yield
reset_tracing_config()
def _stub_agent_creation(monkeypatch, fake_agent: _FakeAgent) -> dict[str, Any]:
"""Short-circuit the heavy parts of ``_ensure_agent`` so we can drive
``stream()`` against a fake graph without touching real models, tools
or middleware factories.
"""
captured: dict[str, Any] = {}
def _stub_ensure_agent(self, config):
captured["config"] = config
self._agent = fake_agent
self._agent_config_key = ("stub",)
monkeypatch.setattr(DeerFlowClient, "_ensure_agent", _stub_ensure_agent)
return captured
def _make_client(_monkeypatch) -> DeerFlowClient:
"""Build a client without going through ``__init__`` so we never load
config.yaml or perform any other side-effectful startup work."""
fake_app_config = SimpleNamespace(models=[SimpleNamespace(name="stub-model")])
client = DeerFlowClient.__new__(DeerFlowClient)
client._app_config = fake_app_config
client._extensions_config = None
client._model_name = "stub-model"
client._thinking_enabled = False
client._plan_mode = False
client._subagent_enabled = False
client._agent_name = None
client._available_skills = None
client._middlewares = None
client._checkpointer = None
client._agent = None
client._agent_config_key = None
client._environment = None
return client
def test_stream_injects_langfuse_metadata_when_enabled(monkeypatch):
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
from deerflow.config.tracing_config import reset_tracing_config
reset_tracing_config()
class _SentinelHandler:
pass
sentinel = _SentinelHandler()
monkeypatch.setattr("deerflow.client.build_tracing_callbacks", lambda: [sentinel])
fake_agent = _FakeAgent()
captured = _stub_agent_creation(monkeypatch, fake_agent)
client = _make_client(monkeypatch)
list(client.stream("hi", thread_id="thread-client-1"))
config = captured["config"]
metadata = config.get("metadata") or {}
assert metadata.get("langfuse_session_id") == "thread-client-1"
assert metadata.get("langfuse_trace_name") == "lead-agent"
# Default no-auth context falls back to ``"default"`` user.
assert metadata.get("langfuse_user_id") in {"default", "test-user-autouse"}
callbacks = config.get("callbacks") or []
assert sentinel in callbacks
def test_stream_is_inert_when_langfuse_disabled(monkeypatch):
monkeypatch.setattr("deerflow.client.build_tracing_callbacks", lambda: [])
fake_agent = _FakeAgent()
captured = _stub_agent_creation(monkeypatch, fake_agent)
client = _make_client(monkeypatch)
list(client.stream("hi", thread_id="thread-client-2"))
config = captured["config"]
assert "callbacks" not in config or not config["callbacks"]
metadata = config.get("metadata") or {}
assert "langfuse_session_id" not in metadata
assert "langfuse_user_id" not in metadata
def test_stream_preserves_caller_metadata_overrides(monkeypatch):
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
from deerflow.config.tracing_config import reset_tracing_config
reset_tracing_config()
monkeypatch.setattr("deerflow.client.build_tracing_callbacks", lambda: [])
fake_agent = _FakeAgent()
captured = _stub_agent_creation(monkeypatch, fake_agent)
client = _make_client(monkeypatch)
# Drive stream with a pre-populated metadata so the worker-equivalent
# ``setdefault`` semantics are exercised.
original_get_config = DeerFlowClient._get_runnable_config
def patched_get_runnable_config(self, thread_id, **overrides):
cfg = original_get_config(self, thread_id, **overrides)
cfg["metadata"] = {
"langfuse_session_id": "explicit-session-override",
"langfuse_user_id": "explicit-user",
}
return cfg
monkeypatch.setattr(DeerFlowClient, "_get_runnable_config", patched_get_runnable_config)
list(client.stream("hi", thread_id="thread-client-3"))
metadata = captured["config"].get("metadata") or {}
assert metadata["langfuse_session_id"] == "explicit-session-override"
assert metadata["langfuse_user_id"] == "explicit-user"
# ``trace_name`` was not supplied by caller so the worker still fills it.
assert metadata["langfuse_trace_name"] == "lead-agent"
@@ -14,6 +14,10 @@ def _ai_with_tool_calls(tool_calls):
return AIMessage(content="", tool_calls=tool_calls)
def _ai_with_invalid_tool_calls(invalid_tool_calls):
return AIMessage(content="", tool_calls=[], invalid_tool_calls=invalid_tool_calls)
def _tool_msg(tool_call_id, name="test_tool"):
return ToolMessage(content="result", tool_call_id=tool_call_id, name=name)
@@ -22,6 +26,16 @@ def _tc(name="bash", tc_id="call_1"):
return {"name": name, "id": tc_id, "args": {}}
def _invalid_tc(name="write_file", tc_id="write_file:36", error="Failed to parse tool arguments: malformed JSON"):
return {
"type": "invalid_tool_call",
"name": name,
"id": tc_id,
"args": '{"description":"write report","path":"/mnt/user-data/outputs/report.md","content":"bad {"json"}"}',
"error": error,
}
class TestBuildPatchedMessagesNoPatch:
def test_empty_messages(self):
mw = DanglingToolCallMiddleware()
@@ -144,6 +158,207 @@ class TestBuildPatchedMessagesPatching:
assert patched[1].name == "bash"
assert patched[1].status == "error"
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
middleware = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = middleware._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
HumanMessage(content="interruption"),
_tool_msg("call_2", "read"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert isinstance(patched[2], ToolMessage)
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
assert isinstance(patched[3], HumanMessage)
def test_non_tool_message_inserted_between_partial_tool_results_is_regrouped(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
_tool_msg("call_1", "bash"),
HumanMessage(content="interruption"),
_tool_msg("call_2", "read"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert isinstance(patched[2], ToolMessage)
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
assert isinstance(patched[3], HumanMessage)
def test_valid_adjacent_tool_results_are_unchanged(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
_tool_msg("call_1", "bash"),
HumanMessage(content="next"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_ai_with_tool_calls([_tc("read", "call_2")]),
_tool_msg("call_1", "bash"),
_tool_msg("call_2", "read"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
assert isinstance(patched[3], AIMessage)
assert isinstance(patched[4], ToolMessage)
assert patched[4].tool_call_id == "call_2"
def test_orphan_tool_message_is_preserved_during_grouping(self):
mw = DanglingToolCallMiddleware()
orphan = _tool_msg("orphan_call", "orphan")
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
orphan,
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert patched[2] is orphan
assert isinstance(patched[3], HumanMessage)
assert patched.count(orphan) == 1
def test_invalid_tool_call_is_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert len(patched) == 2
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "write_file:36"
assert patched[1].name == "write_file"
assert patched[1].status == "error"
assert "arguments were invalid" in patched[1].content
assert "Failed to parse tool arguments" in patched[1].content
def test_valid_and_invalid_tool_calls_are_both_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [
AIMessage(
content="",
tool_calls=[_tc("bash", "call_1")],
invalid_tool_calls=[_invalid_tc()],
)
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
tool_msgs = [m for m in patched if isinstance(m, ToolMessage)]
assert len(tool_msgs) == 2
assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "write_file:36"}
def test_invalid_tool_call_already_responded_is_not_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_invalid_tool_calls([_invalid_tc()]),
_tool_msg("write_file:36", "write_file"),
]
assert mw._build_patched_messages(msgs) is None
class TestWrapModelCall:
def test_no_patch_passthrough(self):
@@ -0,0 +1,222 @@
"""Real-LLM end-to-end verification for issue #2884.
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
and the production ``get_available_tools`` pipeline. The only thing we mock is
the MCP tool source we hand-roll two ``@tool``s and inject them through
``deerflow.mcp.cache.get_cached_mcp_tools``.
The flow exercised:
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
that re-enters ``get_available_tools`` on the same task this is the
code path issue #2884 reports). It must call ``tool_search`` to
discover the deferred ``fake_calculator`` tool.
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
``fake_subagent_trigger`` re-enters ``get_available_tools``.
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
so it can actually call it. Without this PR's fix, the re-entry wipes
the promotion and the model can no longer invoke the tool.
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
test run. Run with::
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
PYTHONPATH=. uv run pytest \
tests/test_deferred_tool_promotion_real_llm.py -v -s
"""
from __future__ import annotations
import os
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool as as_tool
# ---------------------------------------------------------------------------
# Skip control: only run when explicitly opted in.
# ---------------------------------------------------------------------------
pytestmark = pytest.mark.skipif(
os.getenv("ONEAPI_E2E") != "1",
reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
# ---------------------------------------------------------------------------
# Fake "MCP" tools the agent should discover via tool_search.
# Keep them obviously synthetic so the model can pattern-match the search.
# ---------------------------------------------------------------------------
_calls: list[str] = []
@as_tool
def fake_calculator(expression: str) -> str:
"""Evaluate a tiny arithmetic expression like '2 + 2'.
Reserved for the user only call this if the user asks for arithmetic.
"""
_calls.append(f"fake_calculator:{expression}")
try:
# Trivially safe-eval just for the e2e check
allowed = set("0123456789+-*/() .")
if not set(expression) <= allowed:
return "expression contains disallowed characters"
return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307
except Exception as e:
return f"error: {e}"
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
"""Translate text into the given language code. Decorative — not used."""
_calls.append(f"fake_translator:{text}:{target_lang}")
return f"[{target_lang}] {text}"
# ---------------------------------------------------------------------------
# Pipeline wiring (same shape as the in-process tests).
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Build a minimal mock AppConfig and patch the symbol — never call the
real loader, which would trigger ``_apply_singleton_configs`` and
permanently mutate cross-test singletons (memory, title, )."""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Real-LLM e2e test
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
"""End-to-end against a real OpenAI-compatible LLM.
The model must:
Turn 1 see ``tool_search`` (deferred tools aren't bound yet) and
batch-call BOTH ``tool_search(select:fake_calculator)`` AND
``fake_subagent_trigger(...)``.
Turn 2 call ``fake_calculator`` and finish.
Pass criterion: ``fake_calculator`` actually gets invoked at the tool
layer recorded in ``_calls`` which proves the model received the
promoted schema after the re-entrant ``get_available_tools`` call.
"""
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
_force_tool_search_enabled(monkeypatch)
_calls.clear()
@as_tool
async def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset.
Use this whenever the user asks you to delegate work pass a short
description as ``prompt``.
"""
# ``task_tool`` does this internally. Whether the registry-reset that
# used to happen here actually leaks back to the parent task depends
# on asyncio's implicit context-copying semantics (gather creates
# child tasks with copied contexts, so reset_deferred_registry is
# task-local) — but the fix in this PR is what GUARANTEES the
# promotion sticks regardless of which integration path triggers a
# re-entrant ``get_available_tools`` call.
get_available_tools(subagent_enabled=False)
_calls.append(f"fake_subagent_trigger:{prompt}")
return "subagent completed"
tools = get_available_tools() + [fake_subagent_trigger]
model = ChatOpenAI(
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
temperature=0,
max_retries=1,
)
system_prompt = (
"You are a meticulous assistant. Available deferred tools include a "
"calculator and a translator — their schemas are hidden until you "
"search for them via tool_search.\n\n"
"Procedure for the user's request:\n"
" 1. Call tool_search with query 'select:fake_calculator' AND "
"in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
"to delegate the side work. Put both tool_calls in your first response.\n"
" 2. After both tool messages come back, call fake_calculator with "
"the user's expression.\n"
" 3. Reply with just the numeric result."
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt=system_prompt,
)
result = await graph.ainvoke(
{"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]},
config={"recursion_limit": 12},
)
print("\n=== tool calls recorded ===")
for c in _calls:
print(f" {c}")
print("\n=== final message ===")
final_text = result["messages"][-1].content if result["messages"] else "(none)"
print(f" {final_text!r}")
# The smoking-gun assertion: fake_calculator was actually invoked at the
# tool layer. This is only possible if the promoted schema reached the
# model in turn 2, despite the subagent-style re-entry in turn 1.
calc_calls = [c for c in _calls if c.startswith("fake_calculator:")]
assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"
# And the math should actually be done correctly (sanity that the LLM
# really used the result, not just hallucinated the answer).
assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
@@ -0,0 +1,390 @@
"""Reproduce + regression-guard issue #2884.
Hypothesis from the issue:
``tools.tools.get_available_tools`` unconditionally calls
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
every time it is invoked. If anything calls ``get_available_tools`` again
during the same async context (after the agent has promoted tools via
``tool_search``), the promotion is wiped and the next model call hides the
tool's schema again.
These tests pin two things:
A. **At the unit boundary** verify the failure mode directly. Promote a
tool in the registry, then call ``get_available_tools`` again and observe
that the ContextVar registry is reset and the promotion is lost.
B. **At the graph-execution boundary** drive a real ``create_agent`` graph
with the real ``DeferredToolFilterMiddleware`` through two model turns.
The first turn calls ``tool_search`` which promotes a tool. The second
turn must see that tool's schema in ``request.tools``. If
``get_available_tools`` were to run again between the two turns and reset
the registry, the second turn's filter would strip the tool.
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
unmodified; mock only the LLM and the MCP tool source. Patch
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
``get_available_tools`` resolves via lazy import) to return our fixture
tools so we don't need a real MCP server.
"""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
# ---------------------------------------------------------------------------
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
# ---------------------------------------------------------------------------
@as_tool
def fake_mcp_search(query: str) -> str:
"""Pretend to search a knowledge base for the given query."""
return f"results for {query}"
@as_tool
def fake_mcp_fetch(url: str) -> str:
"""Pretend to fetch a page at the given URL."""
return f"content of {url}"
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
"""Each test must start with a clean ContextVar.
The registry lives in a module-level ContextVar with no per-task isolation
in a synchronous test runner, so one test's promotion can leak into the
next and silently break filter assertions.
"""
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
"""Make get_available_tools believe an MCP server is registered.
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
that both ``AppConfig.from_file`` (which calls
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
see a valid instance. Then point the MCP tool cache at our fixture tools.
"""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Force config.tool_search.enabled=True without touching the yaml.
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
which permanently mutates module-level singletons (``_memory_config``,
``_title_config``, ) to match the developer's ``config.yaml`` — even
after pytest restores our patch. That leaks across tests later in the
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
require ``_memory_config.enabled = True``, which is the dataclass default
but FALSE in the actual yaml).
Build a minimal mock AppConfig instead and never call the real loader.
"""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Section A — direct unit-level reproduction
# ---------------------------------------------------------------------------
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
Step 1: call get_available_tools() registers MCP tools as deferred.
Step 2: simulate the agent calling tool_search by promoting one tool.
Step 3: call get_available_tools() again (the same code path
``task_tool`` exercises mid-run).
Assertion: after step 3, the promoted tool is STILL promoted (not
re-deferred). On ``main`` before the fix, step 3's
``reset_deferred_registry()`` wiped the promotion and re-registered
every MCP tool as deferred this assertion fired with
``REGRESSION (#2884)``.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# Step 1: first call — both MCP tools start deferred
get_available_tools()
reg1 = get_deferred_registry()
assert reg1 is not None
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
# Step 2: simulate tool_search promoting one of them
reg1.promote({"fake_mcp_search"})
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
# Step 3: second call — registry must NOT silently undo the promotion
get_available_tools()
reg2 = get_deferred_registry()
assert reg2 is not None
deferred_after = {e.name for e in reg2.entries}
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
# ---------------------------------------------------------------------------
# Section B — graph-execution reproduction
# ---------------------------------------------------------------------------
class _ToolSearchPromotingModel(FakeToolCallingModel):
"""Two-turn model that:
Turn 1 emit a tool_call for ``tool_search`` (the real one)
Turn 2 emit a tool_call for ``fake_mcp_search`` (the promoted tool)
Records the tools it received on each turn so the test can inspect what
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
"""
bound_tools_per_turn: list[list[str]] = []
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
# Record the tool names the model would see in this turn
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
self.bound_tools_per_turn.append(names)
return self
def _build_promoting_model() -> _ToolSearchPromotingModel:
return _ToolSearchPromotingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
}
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
"""End-to-end: drive a real create_agent graph through two turns.
Without the fix, the second-turn bind_tools call should NOT contain
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
registry and strips it). With the fix, the model sees the schema and can
invoke it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
tools = get_available_tools()
# Sanity: the assembled tool list includes the deferred tools (they're in
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
# they reach the model)
tool_names = {getattr(t, "name", "") for t in tools}
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
model = _build_promoting_model()
model.bound_tools_per_turn = [] # reset class-level recorder
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
turn1 = set(model.bound_tools_per_turn[0])
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
# This is the load-bearing assertion for issue #2884.
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
turn2 = set(model.bound_tools_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
# ---------------------------------------------------------------------------
# Section C — the actual issue #2884 trigger: a re-entrant
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
# wipe the parent's promotion.
# ---------------------------------------------------------------------------
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
(the same pattern that happens when ``task_tool`` builds a subagent's
toolset mid-run) must not wipe the parent agent's tool_search promotions.
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
``get_available_tools`` again exactly what ``task_tool`` does when it
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
promoted tool. Without the fix, the re-entry wipes the registry and
the filter re-hides it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# The trigger tool simulates what task_tool does internally: rebuild the
# toolset by calling get_available_tools while the registry is live.
@as_tool
def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
get_available_tools(subagent_enabled=False)
return f"spawned subagent for: {prompt}"
tools = get_available_tools() + [fake_subagent_trigger]
bound_per_turn: list[list[str]] = []
class _Model(FakeToolCallingModel):
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
return self
model = _Model(
responses=[
# Turn 1: do both in one batch — promote AND trigger the
# subagent-style rebuild. LangGraph executes them in order in the
# same agent step.
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
},
{
"name": "fake_subagent_trigger",
"args": {"prompt": "go"},
"id": "call_trigger_1",
"type": "tool_call",
},
],
),
# Turn 2: try to invoke the promoted tool. The model gets this
# turn only if turn 1's bind_tools recorded what the filter sent.
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-subagent-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1 sanity: deferred tool not visible yet
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
# re-entrant get_available_tools call that happened in turn 1's tool batch.
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
turn2 = set(bound_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
@@ -0,0 +1,421 @@
from __future__ import annotations
import json
import textwrap
from pathlib import Path
from support.detectors import blocking_io_static as detector
def _write_python(path: Path, source: str) -> Path:
path.write_text(textwrap.dedent(source).strip() + "\n", encoding="utf-8")
return path
def _payload(path: Path, repo_root: Path) -> list[dict[str, object]]:
return [finding.to_dict() for finding in detector.scan_file(path, repo_root=repo_root)]
def test_scan_file_detects_direct_blocking_calls_in_async_code(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import subprocess
import time
import urllib.request
from pathlib import Path
async def handler(path: Path):
time.sleep(1)
subprocess.run(["echo", "ok"])
path.read_text(encoding="utf-8")
with open(path, encoding="utf-8") as handle:
return urllib.request.urlopen(handle.read())
""",
)
findings = _payload(source_file, tmp_path)
categories = {finding["blocking_call"]["category"] for finding in findings}
symbols = {finding["blocking_call"]["symbol"] for finding in findings}
assert categories == {
"BLOCKING_FILE_IO",
"BLOCKING_HTTP_IO",
"BLOCKING_SLEEP",
"BLOCKING_SUBPROCESS",
}
assert {"time.sleep", "subprocess.run", "path.read_text", "open", "urllib.request.urlopen"}.issubset(symbols)
assert {finding["event_loop_exposure"] for finding in findings} == {"DIRECT_ASYNC"}
def test_scan_file_detects_blocking_calls_in_sync_helper_reached_from_async_code(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
from pathlib import Path
def load_payload(path: Path) -> bytes:
return path.read_bytes()
async def route(path: Path) -> bytes:
return load_payload(path)
""",
)
findings = _payload(source_file, tmp_path)
assert len(findings) == 1
assert findings[0]["blocking_call"]["category"] == "BLOCKING_FILE_IO"
assert findings[0]["location"]["function"] == "load_payload"
assert findings[0]["event_loop_exposure"] == "ASYNC_REACHABLE_SAME_FILE"
assert findings[0]["blocking_call"]["symbol"] == "path.read_bytes"
def test_scan_file_omits_sync_only_blocking_calls_from_default_results(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
from pathlib import Path
def load_payload(path: Path) -> str:
return path.read_text()
""",
)
assert detector.scan_file(source_file, repo_root=tmp_path) == []
def test_scan_file_detects_self_helper_reached_from_async_method(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
class ArtifactRouter:
def read_payload(self, path):
return path.read_text(encoding="utf-8")
async def get(self, path):
return self.read_payload(path)
""",
)
findings = _payload(source_file, tmp_path)
assert len(findings) == 1
assert findings[0]["location"]["function"] == "ArtifactRouter.read_payload"
assert findings[0]["event_loop_exposure"] == "ASYNC_REACHABLE_SAME_FILE"
def test_json_output_uses_concise_review_record_schema(tmp_path: Path, capsys) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import subprocess
async def handler():
subprocess.run(["echo", "ok"])
""",
)
exit_code = detector.main(["--format", "json", str(source_file)])
assert exit_code == 0
payload = json.loads(capsys.readouterr().out)
assert payload == [
{
"priority": "HIGH",
"location": {
"path": str(source_file),
"line": 4,
"column": 5,
"function": "handler",
},
"blocking_call": {
"category": "BLOCKING_SUBPROCESS",
"operation": "SUBPROCESS",
"symbol": "subprocess.run",
},
"event_loop_exposure": "DIRECT_ASYNC",
"reason": "SUBPROCESS is called directly inside an async function.",
"code": 'subprocess.run(["echo", "ok"])',
}
]
assert "confidence" not in payload[0]
assert "severity" not in payload[0]
assert "event_loop_risk" not in payload[0]
def test_summary_output_writes_json_report(tmp_path: Path, capsys) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import subprocess
async def handler():
subprocess.run(["echo", "ok"])
""",
)
output_path = tmp_path / "reports" / "blocking-io.json"
exit_code = detector.main(["--output", str(output_path), str(source_file)])
assert exit_code == 0
stdout = capsys.readouterr().out
assert "Static blocking IO event-loop risk findings: 1" in stdout
assert "By category:" in stdout
assert "BLOCKING_SUBPROCESS" in stdout
assert "Full JSON report:" in stdout
payload = json.loads(output_path.read_text(encoding="utf-8"))
assert [finding["blocking_call"]["category"] for finding in payload] == ["BLOCKING_SUBPROCESS"]
def test_json_output_ranks_operations_without_confidence_noise(tmp_path: Path, capsys) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import shutil
async def handler(path):
path.exists()
path.read_text()
shutil.rmtree(path)
""",
)
exit_code = detector.main(["--format", "json", str(source_file)])
assert exit_code == 0
payload = json.loads(capsys.readouterr().out)
by_symbol = {finding["blocking_call"]["symbol"]: finding for finding in payload}
assert by_symbol["path.exists"]["blocking_call"]["operation"] == "FILE_METADATA"
assert by_symbol["path.exists"]["priority"] == "LOW"
assert by_symbol["path.read_text"]["blocking_call"]["operation"] == "FILE_READ"
assert by_symbol["path.read_text"]["priority"] == "MEDIUM"
assert by_symbol["shutil.rmtree"]["blocking_call"]["operation"] == "FILE_TREE_DELETE"
assert by_symbol["shutil.rmtree"]["priority"] == "HIGH"
assert {finding["event_loop_exposure"] for finding in payload} == {"DIRECT_ASYNC"}
assert all("confidence" not in finding for finding in payload)
def test_path_receiver_detection_uses_path_annotations(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
from pathlib import Path
async def typed(path: Path):
return path.read_text()
async def constructed():
return Path("payload.txt").read_text()
""",
)
findings = _payload(source_file, tmp_path)
assert {finding["blocking_call"]["symbol"] for finding in findings} == {"path.read_text", "pathlib.Path.read_text"}
assert {finding["priority"] for finding in findings} == {"MEDIUM"}
def test_summary_groups_findings_by_priority_and_operation(tmp_path: Path, capsys) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import os
from pathlib import Path
def load_payload(path: Path) -> str:
return path.read_text()
async def handler(path: Path) -> str:
path.exists()
list(os.walk(path))
return load_payload(path)
""",
)
exit_code = detector.main([str(source_file)])
assert exit_code == 0
stdout = capsys.readouterr().out
assert "By priority:" in stdout
assert "HIGH" in stdout
assert "MEDIUM" in stdout
assert "By operation:" in stdout
assert "FILE_ENUMERATION" in stdout
assert "FILE_METADATA" in stdout
assert "FILE_READ" in stdout
assert "By event-loop exposure:" in stdout
assert "DIRECT_ASYNC" in stdout
assert "ASYNC_REACHABLE_SAME_FILE" in stdout
def test_source_code_snippet_is_truncated_for_json_output(tmp_path: Path) -> None:
long_suffix = " + ".join('"chunk"' for _ in range(80))
source_file = _write_python(
tmp_path / "sample.py",
f"""
async def handler(path):
return path.read_text() + {long_suffix}
""",
)
findings = _payload(source_file, tmp_path)
assert len(findings) == 1
assert len(findings[0]["code"]) <= 203
assert findings[0]["code"].endswith("...")
def test_cli_default_filters_sync_only_inventory_items(tmp_path: Path, capsys) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
from pathlib import Path
def load_payload(path: Path) -> str:
return path.read_text()
""",
)
exit_code = detector.main(["--format", "json", str(source_file)])
assert exit_code == 0
assert json.loads(capsys.readouterr().out) == []
def test_sync_only_agent_middleware_hook_gets_event_loop_exposure(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
from langchain.agents.middleware import AgentMiddleware
from pathlib import Path
class UploadsMiddleware(AgentMiddleware):
def before_agent(self, state, runtime):
return self._load(Path("uploads"))
def _load(self, path: Path) -> str:
return path.read_text()
""",
)
findings = _payload(source_file, tmp_path)
assert len(findings) == 1
assert findings[0]["location"]["function"] == "UploadsMiddleware._load"
assert findings[0]["event_loop_exposure"] == "SYNC_AGENT_MIDDLEWARE_HOOK"
assert "statically reachable from a sync AgentMiddleware hook" in findings[0]["reason"]
def test_sync_agent_middleware_hook_with_async_counterpart_is_not_reported(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
from langchain.agents.middleware import AgentMiddleware
from pathlib import Path
class UploadsMiddleware(AgentMiddleware):
def before_agent(self, state, runtime):
return Path("uploads").read_text()
async def abefore_agent(self, state, runtime):
return None
""",
)
assert detector.scan_file(source_file, repo_root=tmp_path) == []
def test_scan_file_detects_sync_httpx_client_methods_in_async_code(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import httpx
async def search() -> str:
with httpx.Client(timeout=30) as client:
return client.post("https://example.invalid").text
""",
)
findings = _payload(source_file, tmp_path)
assert len(findings) == 1
assert findings[0]["blocking_call"]["category"] == "BLOCKING_HTTP_IO"
assert findings[0]["location"]["function"] == "search"
assert findings[0]["event_loop_exposure"] == "DIRECT_ASYNC"
assert findings[0]["blocking_call"]["symbol"] == "httpx.Client.post"
def test_scan_file_detects_chained_sync_http_client_methods_in_async_code(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import httpx
import requests
async def fetch() -> tuple[object, object]:
return (
httpx.Client().get("https://example.invalid"),
requests.Session().post("https://example.invalid"),
)
""",
)
findings = _payload(source_file, tmp_path)
symbols = {finding["blocking_call"]["symbol"] for finding in findings}
assert symbols == {"httpx.Client.get", "requests.Session.post"}
assert {finding["blocking_call"]["category"] for finding in findings} == {"BLOCKING_HTTP_IO"}
def test_scan_file_detects_os_walk_and_path_resolve_in_async_code(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
import os
from pathlib import Path
async def inspect_tree(path: Path) -> list[str]:
root = path.resolve()
return [name for _, _, names in os.walk(root) for name in names]
""",
)
findings = _payload(source_file, tmp_path)
symbols = {finding["blocking_call"]["symbol"] for finding in findings}
assert symbols == {"path.resolve", "os.walk"}
assert {finding["blocking_call"]["category"] for finding in findings} == {"BLOCKING_FILE_IO"}
def test_scan_file_does_not_treat_string_replace_as_file_io(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "sample.py",
"""
def _path_variants(path: str) -> set[str]:
return {path, path.replace("\\\\", "/"), path.replace("/", "\\\\")}
async def normalize(text: str) -> str:
return text.replace("a", "b")
""",
)
assert detector.scan_file(source_file, repo_root=tmp_path) == []
def test_parse_errors_are_reported_as_findings(tmp_path: Path) -> None:
source_file = _write_python(
tmp_path / "broken.py",
"""
async def broken(:
pass
""",
)
findings = _payload(source_file, tmp_path)
assert len(findings) == 1
assert findings[0]["blocking_call"]["category"] == "PARSE_ERROR"
assert findings[0]["priority"] == "MEDIUM"
assert f"{source_file.name}:1:18" in detector.format_text(detector.scan_file(source_file, repo_root=tmp_path))
@@ -0,0 +1,182 @@
from __future__ import annotations
import json
import textwrap
from pathlib import Path
from support.detectors import thread_boundaries as detector
def _write_python(path: Path, source: str) -> Path:
path.write_text(textwrap.dedent(source).strip() + "\n", encoding="utf-8")
return path
def test_scan_file_detects_async_thread_and_tool_boundaries(tmp_path):
source_file = _write_python(
tmp_path / "sample.py",
"""
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from langchain.tools import tool
from langchain_core.tools import StructuredTool
@tool
async def async_tool(value: int) -> str:
return str(value)
async def handler(model):
await asyncio.to_thread(str, "x")
model.invoke("blocking")
time.sleep(1)
def sync_entry():
asyncio.run(handler(None))
pool = ThreadPoolExecutor(max_workers=1)
pool.submit(str, "x")
threading.Thread(target=sync_entry).start()
return StructuredTool.from_function(
name="factory_tool",
description="factory",
coroutine=async_tool,
)
""",
)
findings = detector.scan_file(source_file, repo_root=tmp_path)
categories = {finding.category for finding in findings}
async_tool_finding = next(finding for finding in findings if finding.category == "ASYNC_TOOL_DEFINITION")
assert "ASYNC_TOOL_DEFINITION" in categories
assert async_tool_finding.function == "async_tool"
assert async_tool_finding.async_context is True
assert "ASYNC_THREAD_OFFLOAD" in categories
assert "SYNC_INVOKE_IN_ASYNC" in categories
assert "BLOCKING_CALL_IN_ASYNC" in categories
assert "SYNC_ASYNC_BRIDGE" in categories
assert "THREAD_POOL" in categories
assert "EXECUTOR_SUBMIT" in categories
assert "RAW_THREAD" in categories
assert "ASYNC_ONLY_TOOL_FACTORY" in categories
def test_scan_file_ignores_unqualified_threads_and_generic_method_names(tmp_path):
source_file = _write_python(
tmp_path / "sample.py",
"""
class Thread:
pass
class Timer:
pass
async def handler(form, runner):
form.submit()
runner.invoke("not a langchain model")
def sync_entry(runner):
Thread()
Timer()
runner.ainvoke("not a langchain model")
""",
)
findings = detector.scan_file(source_file, repo_root=tmp_path)
categories = {finding.category for finding in findings}
assert "RAW_THREAD" not in categories
assert "RAW_TIMER_THREAD" not in categories
assert "EXECUTOR_SUBMIT" not in categories
assert "SYNC_INVOKE_IN_ASYNC" not in categories
assert "ASYNC_INVOKE_IN_SYNC" not in categories
def test_scan_file_uses_import_evidence_for_thread_and_executor_aliases(tmp_path):
source_file = _write_python(
tmp_path / "sample.py",
"""
from concurrent.futures import ThreadPoolExecutor as Pool
from threading import Thread as WorkerThread, Timer
def sync_entry():
pool = Pool(max_workers=1)
pool.submit(str, "x")
WorkerThread(target=sync_entry).start()
Timer(1, sync_entry).start()
""",
)
findings = detector.scan_file(source_file, repo_root=tmp_path)
categories = {finding.category for finding in findings}
assert "THREAD_POOL" in categories
assert "EXECUTOR_SUBMIT" in categories
assert "RAW_THREAD" in categories
assert "RAW_TIMER_THREAD" in categories
def test_scan_paths_ignores_virtualenv_like_directories(tmp_path):
scanned_file = _write_python(
tmp_path / "app.py",
"""
import asyncio
def main():
return asyncio.run(asyncio.sleep(0))
""",
)
ignored_dir = tmp_path / ".venv"
ignored_dir.mkdir()
_write_python(
ignored_dir / "ignored.py",
"""
import threading
thread = threading.Thread(target=lambda: None)
""",
)
findings = detector.scan_paths([tmp_path], repo_root=tmp_path)
assert any(finding.path == scanned_file.name for finding in findings)
assert all(".venv" not in finding.path for finding in findings)
def test_json_output_and_min_severity_filter(tmp_path, capsys):
source_file = _write_python(
tmp_path / "sample.py",
"""
import asyncio
async def handler(model):
await asyncio.to_thread(str, "x")
model.invoke("blocking")
""",
)
exit_code = detector.main(["--format", "json", "--min-severity", "WARN", str(source_file)])
assert exit_code == 0
payload = json.loads(capsys.readouterr().out)
categories = {finding["category"] for finding in payload}
assert categories == {"SYNC_INVOKE_IN_ASYNC"}
def test_parse_errors_are_reported_as_findings(tmp_path):
source_file = _write_python(
tmp_path / "broken.py",
"""
def broken(:
pass
""",
)
findings = detector.scan_file(source_file, repo_root=tmp_path)
assert len(findings) == 1
assert findings[0].category == "PARSE_ERROR"
assert findings[0].severity == "WARN"
assert findings[0].column == 11
assert f"{source_file.name}:1:12" in detector.format_text(findings)
@@ -0,0 +1,189 @@
"""Regression tests for gateway config freshness on the request hot path.
Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path
captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during
runtime were therefore ignored ``get_app_config()``'s mtime-based reload
existed but was bypassed because the snapshot object was passed through
explicitly.
These tests pin the desired behaviour: a request-time ``get_config`` call must
observe the most recent on-disk ``config.yaml`` (mtime reload), and the
runtime ``ContextVar`` override must keep working for per-request injection.
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway import deps as gateway_deps
from app.gateway.deps import get_config
from deerflow.config.app_config import (
AppConfig,
pop_current_app_config,
push_current_app_config,
reset_app_config,
set_app_config,
)
from deerflow.config.sandbox_config import SandboxConfig
@pytest.fixture(autouse=True)
def _isolate_app_config_singleton():
"""Ensure each test starts with a clean module-level cache."""
reset_app_config()
yield
reset_app_config()
def _write_config_yaml(path: Path, *, log_level: str) -> None:
path.write_text(
f"""
sandbox:
use: deerflow.sandbox.local.provider:LocalSandboxProvider
log_level: {log_level}
""".strip()
+ "\n",
encoding="utf-8",
)
def _build_app() -> FastAPI:
app = FastAPI()
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"log_level": cfg.log_level}
return app
def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch):
"""Editing config.yaml at runtime must be visible to /probe without restart.
This is the literal repro for the issue: the gateway must not freeze the
config to whatever was on disk when the process started.
"""
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "info"}
# Edit the file and bump its mtime — simulating a maintainer changing
# max_tokens / model settings in production while the gateway is live.
_write_config_yaml(config_file, log_level="debug")
future_mtime = config_file.stat().st_mtime + 5
os.utime(config_file, (future_mtime, future_mtime))
assert client.get("/probe").json() == {"log_level": "debug"}
def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch):
"""Per-request ``push_current_app_config`` injection must still win."""
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
override = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace")
push_current_app_config(override)
try:
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "trace"}
finally:
pop_current_app_config()
def test_get_config_respects_test_set_app_config():
"""``set_app_config`` (used by upload/skills router tests) keeps working."""
injected = AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning")
set_app_config(injected)
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "warning"}
def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch):
"""``RunContext.app_config`` must follow live `config.yaml` edits.
BUG-001 review feedback: the run-context that feeds worker / lead-agent
factories must observe the same mtime reload that `get_config()` does;
otherwise stale config slips back in through the run path even after the
request dependency is fixed.
"""
from unittest.mock import MagicMock
from app.gateway.deps import get_run_context
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
app = FastAPI()
# Sentinel values for the rest of the RunContext wiring — we only care
# about ``ctx.app_config`` for this assertion.
app.state.checkpointer = MagicMock()
app.state.store = MagicMock()
app.state.run_event_store = MagicMock()
app.state.run_events_config = {"frozen": "startup"}
app.state.thread_store = MagicMock()
@app.get("/run-ctx-log-level")
def probe(ctx=Depends(get_run_context)):
return {
"log_level": ctx.app_config.log_level,
"run_events_config": ctx.run_events_config,
}
client = TestClient(app)
first = client.get("/run-ctx-log-level").json()
assert first == {"log_level": "info", "run_events_config": {"frozen": "startup"}}
_write_config_yaml(config_file, log_level="debug")
future_mtime = config_file.stat().st_mtime + 5
os.utime(config_file, (future_mtime, future_mtime))
second = client.get("/run-ctx-log-level").json()
# app_config follows the edit; run_events_config stays frozen to the
# startup snapshot we wrote onto app.state above.
assert second == {"log_level": "debug", "run_events_config": {"frozen": "startup"}}
@pytest.mark.parametrize(
"exception",
[
FileNotFoundError("config.yaml not found"),
PermissionError("config.yaml not readable"),
ValueError("invalid config"),
RuntimeError("yaml parse error"),
],
)
def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception):
"""Any failure to materialise the config must surface as 503, not 500.
Bytedance/deer-flow issue #3107 BUG-001 review: the original snapshot
contract returned 503 when ``app.state.config is None``. The first cut of
this fix only mapped ``FileNotFoundError`` to 503, which left
``PermissionError`` / ``yaml.YAMLError`` / ``ValidationError`` etc. bubbling
up as 500. Catch every load failure at the request boundary.
"""
def _broken_get_app_config():
raise exception
monkeypatch.setattr(gateway_deps, "get_app_config", _broken_get_app_config)
app = _build_app()
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/probe")
assert response.status_code == 503
assert response.json() == {"detail": "Configuration not available"}
-41
View File
@@ -1,41 +0,0 @@
from __future__ import annotations
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def test_get_config_returns_app_state_config():
"""get_config should return the exact AppConfig stored on app.state."""
app = FastAPI()
config = AppConfig(sandbox=SandboxConfig(use="test"))
app.state.config = config
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"same_identity": cfg is config, "log_level": cfg.log_level}
client = TestClient(app)
response = client.get("/probe")
assert response.status_code == 200
assert response.json() == {"same_identity": True, "log_level": "info"}
def test_get_config_reads_updated_app_state():
"""Swapping app.state.config should be visible to the dependency."""
app = FastAPI()
app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
@app.get("/log-level")
def log_level(cfg: AppConfig = Depends(get_config)):
return {"level": cfg.log_level}
client = TestClient(app)
assert client.get("/log-level").json() == {"level": "info"}
app.state.config = app.state.config.model_copy(update={"log_level": "debug"})
assert client.get("/log-level").json() == {"level": "debug"}
+42
View File
@@ -122,3 +122,45 @@ def test_health_still_works_when_docs_disabled():
resp = client.get("/health")
assert resp.status_code == 200
assert resp.json()["status"] == "healthy"
# ---------------------------------------------------------------------------
# Runtime CORS behavior
# ---------------------------------------------------------------------------
def _make_gateway_client(cors_origins: str) -> TestClient:
with patch.dict(os.environ, {"GATEWAY_CORS_ORIGINS": cors_origins}):
_reset_gateway_config()
from app.gateway.app import create_app
return TestClient(create_app())
def test_gateway_cors_allows_configured_origin():
"""GATEWAY_CORS_ORIGINS should control actual browser CORS responses."""
client = _make_gateway_client("https://app.example")
response = client.get("/health", headers={"Origin": "https://app.example"})
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://app.example"
assert response.headers["access-control-allow-credentials"] == "true"
def test_gateway_cors_rejects_unconfigured_origin():
client = _make_gateway_client("https://app.example")
response = client.get("/health", headers={"Origin": "https://evil.example"})
assert response.status_code == 200
assert "access-control-allow-origin" not in response.headers
def test_gateway_cors_normalizes_configured_default_port():
client = _make_gateway_client("https://app.example:443")
response = client.get("/health", headers={"Origin": "https://app.example"})
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://app.example"
@@ -17,7 +17,7 @@ from fastapi import FastAPI
@asynccontextmanager
async def _noop_langgraph_runtime(_app):
async def _noop_langgraph_runtime(_app, _startup_config):
yield
+127
View File
@@ -0,0 +1,127 @@
"""Gateway startup recovery for stale persisted runs."""
from __future__ import annotations
from contextlib import asynccontextmanager
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
import deerflow.runtime as runtime_module
from app.gateway import deps as gateway_deps
from deerflow.persistence import engine as engine_module
from deerflow.persistence import thread_meta as thread_meta_module
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
from deerflow.runtime.events import store as event_store_module
@asynccontextmanager
async def _fake_context(value):
yield value
class _FakeRunManager:
"""RunManager double that records startup reconciliation calls."""
instances: list[_FakeRunManager] = []
recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
latest_by_thread: dict[str, list[SimpleNamespace]] = {}
def __init__(self, *, store):
self.store = store
self.reconcile_calls: list[dict] = []
self.list_by_thread_calls: list[dict] = []
_FakeRunManager.instances.append(self)
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
self.reconcile_calls.append({"error": error, "before": before})
return self.recovered_runs
async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
class _FakeThreadStore:
def __init__(self) -> None:
self.status_updates: list[tuple[str, str, str | None]] = []
async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
self.status_updates.append((thread_id, status, user_id))
@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
"""SQLite startup should recover stale active runs before serving requests."""
app = FastAPI()
config = SimpleNamespace(
database=SimpleNamespace(backend="sqlite"),
run_events=SimpleNamespace(backend="memory"),
)
thread_store = _FakeThreadStore()
_FakeRunManager.instances.clear()
_FakeRunManager.recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
_FakeRunManager.latest_by_thread = {}
async def fake_init_engine_from_config(_database):
return None
async def fake_close_engine():
return None
monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)
async with gateway_deps.langgraph_runtime(app, config):
pass
assert len(_FakeRunManager.instances) == 1
assert _FakeRunManager.instances[0].reconcile_calls
assert _FakeRunManager.instances[0].reconcile_calls[0]["error"]
assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
assert thread_store.status_updates == [("thread-1", "error", None)]
@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
"""Startup recovery should not let an old orphaned run overwrite a newer terminal thread state."""
app = FastAPI()
config = SimpleNamespace(
database=SimpleNamespace(backend="sqlite"),
run_events=SimpleNamespace(backend="memory"),
)
thread_store = _FakeThreadStore()
_FakeRunManager.instances.clear()
_FakeRunManager.recovered_runs = [SimpleNamespace(run_id="old-running", thread_id="thread-1")]
_FakeRunManager.latest_by_thread = {"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]}
async def fake_init_engine_from_config(_database):
return None
async def fake_close_engine():
return None
monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)
async with gateway_deps.langgraph_runtime(app, config):
pass
assert len(_FakeRunManager.instances) == 1
assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
assert thread_store.status_updates == []
@@ -53,6 +53,29 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api():
assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content
def test_nginx_defers_cors_to_gateway_allowlist():
for path in ("docker/nginx/nginx.local.conf", "docker/nginx/nginx.conf"):
content = _read(path)
assert "Access-Control-Allow-Origin" not in content
assert "Access-Control-Allow-Methods" not in content
assert "Access-Control-Allow-Headers" not in content
assert "Access-Control-Allow-Credentials" not in content
assert "proxy_hide_header 'Access-Control-Allow-" not in content
assert "if ($request_method = 'OPTIONS')" not in content
def test_gateway_cors_configuration_uses_gateway_allowlist():
gateway_config = _read("backend/app/gateway/config.py")
gateway_app = _read("backend/app/gateway/app.py")
csrf_middleware = _read("backend/app/gateway/csrf_middleware.py")
assert not re.search(r"(?<!GATEWAY_)[\"']CORS_ORIGINS[\"']", gateway_config)
assert "cors_origins" not in gateway_config
assert "get_configured_cors_origins" in gateway_app
assert "GATEWAY_CORS_ORIGINS" in csrf_middleware
def test_frontend_rewrites_langgraph_prefix_to_gateway():
next_config = _read("frontend/next.config.js")
api_client = _read("frontend/src/core/api/api-client.ts")
+92
View File
@@ -81,6 +81,94 @@ def test_normalize_input_passthrough():
assert result == {"custom_key": "value"}
def test_normalize_input_preserves_additional_kwargs_and_id():
"""Regression: gh #3132 — frontend ships uploaded-file metadata in
additional_kwargs.files (and a client-side message id). The gateway must
not strip them before the graph runs, otherwise UploadsMiddleware reports
"(empty)" for new uploads and the frontend message loses its file chip.
"""
from langchain_core.messages import HumanMessage
from app.gateway.services import normalize_input
files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}]
result = normalize_input(
{
"messages": [
{
"type": "human",
"id": "client-msg-1",
"name": "user-input",
"content": [{"type": "text", "text": "clean it"}],
"additional_kwargs": {"files": files, "custom": "keep-me"},
}
]
}
)
assert len(result["messages"]) == 1
msg = result["messages"][0]
assert isinstance(msg, HumanMessage)
assert msg.id == "client-msg-1"
assert msg.name == "user-input"
assert msg.content == [{"type": "text", "text": "clean it"}]
assert msg.additional_kwargs == {"files": files, "custom": "keep-me"}
def test_normalize_input_passes_through_basemessage_instances():
from langchain_core.messages import HumanMessage
from app.gateway.services import normalize_input
msg = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]})
result = normalize_input({"messages": [msg]})
assert result["messages"][0] is msg
def test_normalize_input_rejects_malformed_message_with_400():
"""Boundary validation: ``convert_to_messages`` raises ``ValueError`` when a
message dict is missing ``role``/``type``/``content``. ``normalize_input``
runs inside the gateway HTTP boundary, so a malformed payload should surface
as a 400 referencing the offending entry not bubble up as a 500.
Raised after the Copilot review on PR #3136.
"""
import pytest
from fastapi import HTTPException
from app.gateway.services import normalize_input
with pytest.raises(HTTPException) as excinfo:
normalize_input({"messages": [{"role": "human", "content": "ok"}, {"oops": "no role here"}]})
assert excinfo.value.status_code == 400
assert "input.messages[1]" in excinfo.value.detail
def test_normalize_input_handles_non_human_roles():
"""The previous implementation collapsed every role to HumanMessage with a
`# TODO: handle other message types` comment. Resuming a thread with prior
AI/tool messages would silently rewrite them as human turns corrupting
the conversation. Use langchain's standard conversion so ai/system/tool
roles round-trip correctly.
"""
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from app.gateway.services import normalize_input
result = normalize_input(
{
"messages": [
{"role": "system", "content": "sys"},
{"role": "ai", "content": "hi", "id": "ai-1"},
{"role": "tool", "content": "result", "tool_call_id": "call-1"},
]
}
)
types = [type(m) for m in result["messages"]]
assert types == [SystemMessage, AIMessage, ToolMessage]
assert result["messages"][1].id == "ai-1"
assert result["messages"][2].tool_call_id == "call-1"
def test_build_run_config_basic():
from app.gateway.services import build_run_config
@@ -114,6 +202,7 @@ def test_build_run_config_custom_agent_injects_agent_name():
config = build_run_config("thread-1", None, None, assistant_id="finalis")
assert config["configurable"]["agent_name"] == "finalis"
assert config["run_name"] == "finalis"
def test_build_run_config_lead_agent_no_agent_name():
@@ -122,6 +211,7 @@ def test_build_run_config_lead_agent_no_agent_name():
config = build_run_config("thread-1", None, None, assistant_id="lead_agent")
assert "agent_name" not in config["configurable"]
assert "run_name" not in config
def test_build_run_config_none_assistant_id_no_agent_name():
@@ -130,6 +220,7 @@ def test_build_run_config_none_assistant_id_no_agent_name():
config = build_run_config("thread-1", None, None, assistant_id=None)
assert "agent_name" not in config["configurable"]
assert "run_name" not in config
def test_build_run_config_explicit_agent_name_not_overwritten():
@@ -143,6 +234,7 @@ def test_build_run_config_explicit_agent_name_not_overwritten():
assistant_id="other-agent",
)
assert config["configurable"]["agent_name"] == "explicit-agent"
assert config["run_name"] == "explicit-agent"
def test_build_run_config_context_custom_agent_injects_agent_name():
+74 -11
View File
@@ -22,7 +22,7 @@ _TEST_SECRET = "test-secret-key-initialize-admin-min-32"
def _setup_auth(tmp_path):
"""Fresh SQLite engine + auth config per test."""
from app.gateway import deps
from app.gateway.routers.auth import _SETUP_STATUS_COOLDOWN
from app.gateway.routers.auth import _SETUP_STATUS_CACHE, _SETUP_STATUS_INFLIGHT
from deerflow.persistence.engine import close_engine, init_engine
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
@@ -30,13 +30,15 @@ def _setup_auth(tmp_path):
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
deps._cached_local_provider = None
deps._cached_repo = None
_SETUP_STATUS_COOLDOWN.clear()
_SETUP_STATUS_CACHE.clear()
_SETUP_STATUS_INFLIGHT.clear()
try:
yield
finally:
deps._cached_local_provider = None
deps._cached_repo = None
_SETUP_STATUS_COOLDOWN.clear()
_SETUP_STATUS_CACHE.clear()
_SETUP_STATUS_INFLIGHT.clear()
asyncio.run(close_engine())
@@ -168,15 +170,76 @@ def test_setup_status_false_when_only_regular_user_exists(client):
assert resp.json()["needs_setup"] is True
def test_setup_status_rate_limited_on_second_call(client):
"""Second /setup-status call within the cooldown window returns 429 with Retry-After."""
# First call succeeds.
def test_setup_status_returns_cached_result_on_rapid_calls(client):
"""Rapid /setup-status calls return the cached result (200) instead of 429."""
client.post("/api/v1/auth/initialize", json=_init_payload())
# First call succeeds and computes the result.
resp1 = client.get("/api/v1/auth/setup-status")
assert resp1.status_code == 200
# Immediate second call is rate-limited.
# Immediate second call returns cached result, not 429.
resp2 = client.get("/api/v1/auth/setup-status")
assert resp2.status_code == 429
assert "Retry-After" in resp2.headers
retry_after = int(resp2.headers["Retry-After"])
assert 1 <= retry_after <= 60
assert resp2.status_code == 200
assert resp2.json() == resp1.json()
assert resp2.json()["needs_setup"] is False
def test_setup_status_does_not_return_stale_true_after_initialize(client):
"""A pre-initialize setup-status response should not stay cached as True."""
before = client.get("/api/v1/auth/setup-status")
assert before.status_code == 200
assert before.json()["needs_setup"] is True
init = client.post("/api/v1/auth/initialize", json=_init_payload())
assert init.status_code == 201
after = client.get("/api/v1/auth/setup-status")
assert after.status_code == 200
assert after.json()["needs_setup"] is False
@pytest.mark.asyncio
async def test_setup_status_single_flight_per_ip(monkeypatch):
"""Concurrent requests from same IP share one in-flight DB query."""
from starlette.requests import Request
from app.gateway.routers.auth import (
_SETUP_STATUS_CACHE,
_SETUP_STATUS_INFLIGHT,
setup_status,
)
class _Provider:
def __init__(self):
self.calls = 0
async def count_admin_users(self):
self.calls += 1
await asyncio.sleep(0.05)
return 0
provider = _Provider()
monkeypatch.setattr("app.gateway.routers.auth.get_local_provider", lambda: provider)
_SETUP_STATUS_CACHE.clear()
_SETUP_STATUS_INFLIGHT.clear()
def _request() -> Request:
return Request(
{
"type": "http",
"method": "GET",
"path": "/api/v1/auth/setup-status",
"headers": [],
"client": ("127.0.0.1", 12345),
}
)
results = await asyncio.gather(
setup_status(_request()),
setup_status(_request()),
setup_status(_request()),
)
assert all(result["needs_setup"] is True for result in results)
assert provider.calls == 1
+35
View File
@@ -0,0 +1,35 @@
"""Tests for Gateway internal auth token handling."""
from __future__ import annotations
import importlib
def test_internal_auth_uses_shared_env_token(monkeypatch):
import app.gateway.internal_auth as internal_auth
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
reloaded = importlib.reload(internal_auth)
try:
headers = reloaded.create_internal_auth_headers()
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
assert reloaded.is_valid_internal_auth_token("shared-token") is True
assert reloaded.is_valid_internal_auth_token("other-token") is False
finally:
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
importlib.reload(reloaded)
def test_internal_auth_generates_process_local_fallback(monkeypatch):
import app.gateway.internal_auth as internal_auth
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
reloaded = importlib.reload(internal_auth)
try:
token = reloaded.create_internal_auth_headers()[reloaded.INTERNAL_AUTH_HEADER_NAME]
assert token
assert reloaded.is_valid_internal_auth_token(token) is True
finally:
importlib.reload(reloaded)
@@ -699,6 +699,92 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
load_acp_config_from_dict({})
def test_get_available_tools_sync_invoke_acp_agent_preserves_thread_workspace(monkeypatch, tmp_path):
from deerflow.config import paths as paths_module
from deerflow.runtime import user_context as uc_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
)
monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True)
captured: dict[str, object] = {}
class DummyClient:
@property
def collected_text(self) -> str:
return "ok"
async def session_update(self, session_id, update, **kwargs):
pass
async def request_permission(self, options, session_id, tool_call, **kwargs):
raise AssertionError("should not be called")
class DummyConn:
async def initialize(self, **kwargs):
pass
async def new_session(self, **kwargs):
return SimpleNamespace(session_id="s1")
async def prompt(self, **kwargs):
pass
class DummyProcessContext:
def __init__(self, client, cmd, *args, env=None, cwd):
captured["cwd"] = cwd
async def __aenter__(self):
return DummyConn(), object()
async def __aexit__(self, exc_type, exc, tb):
return False
monkeypatch.setitem(
sys.modules,
"acp",
SimpleNamespace(
PROTOCOL_VERSION="2026-03-24",
Client=DummyClient,
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
text_block=lambda text: {"type": "text", "text": text},
),
)
monkeypatch.setitem(
sys.modules,
"acp.schema",
SimpleNamespace(
ClientCapabilities=lambda: {},
Implementation=lambda **kwargs: kwargs,
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
),
)
explicit_config = SimpleNamespace(
tools=[],
models=[],
tool_search=SimpleNamespace(enabled=False),
skill_evolution=SimpleNamespace(enabled=False),
sandbox=SimpleNamespace(),
get_model_config=lambda name: None,
acp_agents={"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")},
)
tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config)
tool = next(tool for tool in tools if tool.name == "invoke_acp_agent")
thread_id = "thread-sync-123"
tool.invoke(
{"agent": "codex", "prompt": "Do something"},
config={"configurable": {"thread_id": thread_id}},
)
assert captured["cwd"] == str(tmp_path / "threads" / thread_id / "acp-workspace")
def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
explicit_config = SimpleNamespace(
@@ -41,6 +41,49 @@ def test_make_lead_agent_signature_matches_langgraph_server_factory_abi():
assert list(inspect.signature(lead_agent_module.make_lead_agent).parameters) == ["config"]
def test_make_lead_agent_attaches_tracing_callbacks_at_graph_root(monkeypatch):
"""Regression guard: tracing handlers must be appended to
``config["callbacks"]`` (graph invocation root), and every in-graph
``create_chat_model`` call must pass ``attach_tracing=False``.
Catches future contributors who forget the flag when adding new
in-graph model creation, which would silently produce duplicate
spans and break Langfuse session/user propagation.
"""
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
import deerflow.tools as tools_module
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
sentinel_handler = object()
monkeypatch.setattr(lead_agent_module, "build_tracing_callbacks", lambda: [sentinel_handler])
seen_attach_tracing: list[bool] = []
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
seen_attach_tracing.append(attach_tracing)
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
config: dict = {"configurable": {"model_name": "safe-model"}}
lead_agent_module._make_lead_agent(config, app_config=app_config)
# Handler must land on the graph invocation config so the Langfuse
# CallbackHandler fires ``on_chain_start(parent_run_id=None)`` and
# propagates ``session_id`` / ``user_id`` onto the trace.
assert sentinel_handler in (config.get("callbacks") or []), "build_tracing_callbacks output must be appended to config['callbacks']"
# Every in-graph create_chat_model call must opt out of model-level
# tracing to avoid duplicate spans.
assert seen_attach_tracing, "_make_lead_agent did not call create_chat_model"
assert all(flag is False for flag in seen_attach_tracing), f"in-graph create_chat_model must pass attach_tracing=False; got {seen_attach_tracing}"
def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch):
app_config = _make_app_config([_make_model("explicit-model", supports_thinking=False)])
@@ -55,7 +98,7 @@ def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch):
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
captured["name"] = name
captured["app_config"] = app_config
return object()
@@ -89,7 +132,7 @@ def test_make_lead_agent_uses_runtime_app_config_from_context_without_global_rea
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
captured["name"] = name
captured["app_config"] = app_config
return object()
@@ -168,7 +211,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
@@ -212,7 +255,7 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
@@ -293,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
# verify the custom middleware is injected correctly.
# Chain tail order after the custom middleware is:
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
# so the custom mock sits at index [-3].
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
@@ -407,7 +453,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
@@ -441,7 +487,7 @@ def test_create_summarization_middleware_threads_resolved_app_config_to_model(mo
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True):
captured["app_config"] = app_config
return fake_model
@@ -204,6 +204,26 @@ class TestSymlinkEscapes:
assert exc_info.value.errno == errno.EACCES
def test_download_file_blocks_symlink_escape_from_mount(self, tmp_path):
mount_dir = tmp_path / "mount"
mount_dir.mkdir()
outside_dir = tmp_path / "outside"
outside_dir.mkdir()
(outside_dir / "secret.bin").write_bytes(b"\x00secret")
_symlink_to(outside_dir, mount_dir / "escape", target_is_directory=True)
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/user-data", local_path=str(mount_dir), read_only=False),
],
)
with pytest.raises(PermissionError) as exc_info:
sandbox.download_file("/mnt/user-data/escape/secret.bin")
assert exc_info.value.errno == errno.EACCES
def test_write_file_blocks_symlink_escape_from_mount(self, tmp_path):
mount_dir = tmp_path / "mount"
mount_dir.mkdir()
@@ -334,6 +354,74 @@ class TestSymlinkEscapes:
assert existing.read_bytes() == b"original"
class TestDownloadFileMappings:
"""download_file must use _resolve_path_with_mapping so path resolution, symlink
containment, and read-only awareness are consistent with read_file."""
def test_resolves_container_path_via_mapping(self, tmp_path):
"""download_file should resolve container paths through path mappings."""
data_dir = tmp_path / "data"
data_dir.mkdir()
(data_dir / "asset.bin").write_bytes(b"\x01\x02\x03")
sandbox = LocalSandbox(
"test",
[PathMapping(container_path="/mnt/user-data", local_path=str(data_dir))],
)
result = sandbox.download_file("/mnt/user-data/asset.bin")
assert result == b"\x01\x02\x03"
def test_raises_oserror_with_original_path_when_missing(self, tmp_path):
"""OSError filename should show the container path, not the resolved host path."""
data_dir = tmp_path / "data"
data_dir.mkdir()
sandbox = LocalSandbox(
"test",
[PathMapping(container_path="/mnt/user-data", local_path=str(data_dir))],
)
with pytest.raises(OSError) as exc_info:
sandbox.download_file("/mnt/user-data/missing.bin")
assert exc_info.value.filename == "/mnt/user-data/missing.bin"
def test_rejects_path_outside_virtual_prefix_and_logs_error(self, tmp_path, caplog):
"""download_file must reject paths outside /mnt/user-data and log the reason."""
data_dir = tmp_path / "data"
data_dir.mkdir()
(data_dir / "model.bin").write_bytes(b"weights")
sandbox = LocalSandbox(
"test",
[PathMapping(container_path="/mnt/user-data", local_path=str(data_dir), read_only=True)],
)
with caplog.at_level("ERROR"):
with pytest.raises(PermissionError) as exc_info:
sandbox.download_file("/mnt/skills/model.bin")
assert exc_info.value.errno == errno.EACCES
assert "outside allowed directory" in caplog.text
def test_readable_from_read_only_mount(self, tmp_path):
"""Read-only mounts must not block download_file — read-only only restricts writes."""
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
(skills_dir / "model.bin").write_bytes(b"weights")
sandbox = LocalSandbox(
"test",
[PathMapping(container_path="/mnt/user-data", local_path=str(skills_dir), read_only=True)],
)
result = sandbox.download_file("/mnt/user-data/model.bin")
assert result == b"weights"
class TestMultipleMounts:
def test_multiple_read_write_mounts(self, tmp_path):
skills_dir = tmp_path / "skills"
@@ -639,3 +727,148 @@ class TestLocalSandboxProviderMounts:
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
class TestLocalSandboxProviderResetClearsSingleton:
"""Regression coverage for issue #2815.
The module-level LocalSandbox singleton must be cleared whenever the
provider is reset or shut down otherwise stale path mappings and
mount policy survive config reloads and test teardown.
"""
def _build_config(self, skills_dir, mounts):
from deerflow.config.sandbox_config import SandboxConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=mounts,
)
return SimpleNamespace(
skills=SimpleNamespace(
container_path="/mnt/skills",
get_skills_path=lambda: skills_dir,
use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage",
),
sandbox=sandbox_config,
)
def test_reset_sandbox_provider_clears_local_singleton(self, tmp_path):
from deerflow.config.sandbox_config import VolumeMountConfig
from deerflow.sandbox import local as local_module
from deerflow.sandbox.local import local_sandbox_provider as lsp_module
from deerflow.sandbox.sandbox_provider import (
get_sandbox_provider,
reset_sandbox_provider,
)
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
first_dir = tmp_path / "first"
first_dir.mkdir()
second_dir = tmp_path / "second"
second_dir.mkdir()
first_cfg = self._build_config(
skills_dir,
[VolumeMountConfig(host_path=str(first_dir), container_path="/mnt/first", read_only=False)],
)
second_cfg = self._build_config(
skills_dir,
[VolumeMountConfig(host_path=str(second_dir), container_path="/mnt/second", read_only=False)],
)
# Make sure no leftover singleton from a prior test interferes.
lsp_module._singleton = None
reset_sandbox_provider()
try:
with patch("deerflow.sandbox.sandbox_provider.get_app_config", return_value=first_cfg), patch("deerflow.config.get_app_config", return_value=first_cfg):
provider = get_sandbox_provider()
provider.acquire()
assert lsp_module._singleton is not None
first_container_paths = {m.container_path for m in lsp_module._singleton.path_mappings}
assert "/mnt/first" in first_container_paths
reset_sandbox_provider()
# The whole point of the regression: reset must drop the cached LocalSandbox.
assert lsp_module._singleton is None
with patch("deerflow.sandbox.sandbox_provider.get_app_config", return_value=second_cfg), patch("deerflow.config.get_app_config", return_value=second_cfg):
provider2 = get_sandbox_provider()
provider2.acquire()
assert provider2 is not provider
second_container_paths = {m.container_path for m in lsp_module._singleton.path_mappings}
assert "/mnt/second" in second_container_paths
assert "/mnt/first" not in second_container_paths
finally:
lsp_module._singleton = None
reset_sandbox_provider()
# Sanity: the local sandbox module still exposes the singleton symbol
# at the same module path (guards against accidental rename).
assert hasattr(local_module.local_sandbox_provider, "_singleton")
def test_shutdown_sandbox_provider_clears_local_singleton(self, tmp_path):
from deerflow.config.sandbox_config import VolumeMountConfig
from deerflow.sandbox.local import local_sandbox_provider as lsp_module
from deerflow.sandbox.sandbox_provider import (
get_sandbox_provider,
reset_sandbox_provider,
shutdown_sandbox_provider,
)
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
mount_dir = tmp_path / "mount"
mount_dir.mkdir()
cfg = self._build_config(
skills_dir,
[VolumeMountConfig(host_path=str(mount_dir), container_path="/mnt/data", read_only=False)],
)
lsp_module._singleton = None
reset_sandbox_provider()
try:
with patch("deerflow.sandbox.sandbox_provider.get_app_config", return_value=cfg), patch("deerflow.config.get_app_config", return_value=cfg):
provider = get_sandbox_provider()
provider.acquire()
assert lsp_module._singleton is not None
shutdown_sandbox_provider()
assert lsp_module._singleton is None
finally:
lsp_module._singleton = None
reset_sandbox_provider()
def test_provider_reset_method_is_idempotent(self, tmp_path):
from deerflow.sandbox.local import local_sandbox_provider as lsp_module
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
cfg = self._build_config(skills_dir, [])
lsp_module._singleton = None
try:
with patch("deerflow.config.get_app_config", return_value=cfg):
provider = LocalSandboxProvider()
provider.acquire()
assert lsp_module._singleton is not None
provider.reset()
assert lsp_module._singleton is None
# Calling reset again on an already-cleared singleton is safe.
provider.reset()
assert lsp_module._singleton is None
finally:
lsp_module._singleton = None
@@ -0,0 +1,366 @@
"""Issue #2873 regression — the public Sandbox API must honor the documented
/mnt/user-data contract uniformly across implementations.
Today AIO sandbox already accepts /mnt/user-data/... paths directly because the
container has those paths bind-mounted per-thread. LocalSandbox, however,
externalises that translation to ``deerflow.sandbox.tools`` via ``thread_data``,
so any caller that bypasses tools.py (e.g. ``uploads.py`` syncing files into a
remote sandbox via ``sandbox.update_file(virtual_path, ...)``) sees inconsistent
behaviour.
These tests pin down the **public Sandbox API boundary**: when a caller obtains
a ``LocalSandbox`` from ``LocalSandboxProvider.acquire(thread_id)`` and invokes
its abstract methods with documented virtual paths, those paths must resolve to
the thread's user-data directory automatically — no tools.py / thread_data
shim required.
"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
def _build_config(skills_dir: Path) -> SimpleNamespace:
"""Minimal app config covering what ``LocalSandboxProvider`` reads at init."""
return SimpleNamespace(
skills=SimpleNamespace(
container_path="/mnt/skills",
get_skills_path=lambda: skills_dir,
use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage",
),
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=[]),
)
@pytest.fixture
def isolated_paths(monkeypatch, tmp_path):
"""Redirect ``get_paths().base_dir`` to ``tmp_path`` and reset its singleton.
Without this, per-thread directories would be created under the developer's
real ``.deer-flow/`` tree.
"""
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
from deerflow.config import paths as paths_module
monkeypatch.setattr(paths_module, "_paths", None)
yield tmp_path
monkeypatch.setattr(paths_module, "_paths", None)
@pytest.fixture
def provider(isolated_paths, tmp_path):
"""Provider with a real skills dir and no custom mounts."""
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
cfg = _build_config(skills_dir)
with patch("deerflow.config.get_app_config", return_value=cfg):
yield LocalSandboxProvider()
# ──────────────────────────────────────────────────────────────────────────
# 1. Direct Sandbox API accepts the virtual path contract for ``acquire(tid)``
# ──────────────────────────────────────────────────────────────────────────
def test_acquire_with_thread_id_returns_per_thread_id(provider):
sandbox_id = provider.acquire("alpha")
assert sandbox_id == "local:alpha"
def test_acquire_without_thread_id_remains_legacy_local_id(provider):
"""Backward-compat: ``acquire()`` with no thread keeps the singleton id."""
assert provider.acquire() == "local"
assert provider.acquire(None) == "local"
def test_write_then_read_via_public_api_with_virtual_path(provider):
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
assert sbx is not None
virtual = "/mnt/user-data/workspace/hello.txt"
sbx.write_file(virtual, "hi there")
assert sbx.read_file(virtual) == "hi there"
def test_list_dir_via_public_api_with_virtual_path(provider):
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
sbx.write_file("/mnt/user-data/workspace/foo.txt", "x")
entries = sbx.list_dir("/mnt/user-data/workspace")
# entries should be reverse-resolved back to the virtual prefix
assert any("/mnt/user-data/workspace/foo.txt" in e for e in entries)
def test_execute_command_with_virtual_path(provider):
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
sbx.write_file("/mnt/user-data/uploads/note.txt", "payload")
output = sbx.execute_command("ls /mnt/user-data/uploads")
assert "note.txt" in output
def test_glob_with_virtual_path(provider):
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
sbx.write_file("/mnt/user-data/outputs/report.md", "# r")
matches, _ = sbx.glob("/mnt/user-data/outputs", "*.md")
assert any(m.endswith("/mnt/user-data/outputs/report.md") for m in matches)
def test_grep_with_virtual_path(provider):
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
sbx.write_file("/mnt/user-data/workspace/findme.txt", "needle line\nother line")
matches, _ = sbx.grep("/mnt/user-data/workspace", "needle", literal=True)
assert matches
assert matches[0].path.endswith("/mnt/user-data/workspace/findme.txt")
def test_execute_command_lists_aggregate_user_data_root(provider):
"""``ls /mnt/user-data`` (the parent prefix itself) must list the three
subdirs matching the AIO container's natural filesystem view."""
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
# Touch all three subdirs so they materialise on disk
sbx.write_file("/mnt/user-data/workspace/.keep", "")
sbx.write_file("/mnt/user-data/uploads/.keep", "")
sbx.write_file("/mnt/user-data/outputs/.keep", "")
output = sbx.execute_command("ls /mnt/user-data")
assert "workspace" in output
assert "uploads" in output
assert "outputs" in output
def test_update_file_with_virtual_path_for_remote_sync_scenario(provider):
"""This is the exact code path used by ``uploads.py:282`` and ``feishu.py:389``.
They build a ``virtual_path`` like ``/mnt/user-data/uploads/foo.pdf`` and hand
raw bytes to the sandbox. Before this fix LocalSandbox would try to write to
the literal host path ``/mnt/user-data/uploads/foo.pdf`` and fail.
"""
sandbox_id = provider.acquire("alpha")
sbx = provider.get(sandbox_id)
sbx.update_file("/mnt/user-data/uploads/blob.bin", b"\x00\x01\x02binary")
assert sbx.read_file("/mnt/user-data/uploads/blob.bin").startswith("\x00\x01\x02")
# ──────────────────────────────────────────────────────────────────────────
# 2. Per-thread isolation (no cross-thread state leaks)
# ──────────────────────────────────────────────────────────────────────────
def test_two_threads_get_distinct_sandboxes(provider):
sid_a = provider.acquire("alpha")
sid_b = provider.acquire("beta")
assert sid_a != sid_b
sbx_a = provider.get(sid_a)
sbx_b = provider.get(sid_b)
assert sbx_a is not sbx_b
def test_per_thread_user_data_mapping_isolated(provider, isolated_paths):
"""Files written via one thread's sandbox must not be visible through another."""
sid_a = provider.acquire("alpha")
sid_b = provider.acquire("beta")
sbx_a = provider.get(sid_a)
sbx_b = provider.get(sid_b)
sbx_a.write_file("/mnt/user-data/workspace/secret.txt", "alpha-only")
# The same virtual path resolves to a different host path in thread "beta"
with pytest.raises(FileNotFoundError):
sbx_b.read_file("/mnt/user-data/workspace/secret.txt")
def test_agent_written_paths_per_thread_isolation(provider):
"""``_agent_written_paths`` tracks files this sandbox wrote so reverse-resolve
runs on read. The set must not leak across threads."""
sid_a = provider.acquire("alpha")
sid_b = provider.acquire("beta")
sbx_a = provider.get(sid_a)
sbx_b = provider.get(sid_b)
sbx_a.write_file("/mnt/user-data/workspace/in-a.txt", "marker")
assert sbx_a._agent_written_paths
assert not sbx_b._agent_written_paths
# ──────────────────────────────────────────────────────────────────────────
# 3. Lifecycle: get / release / reset
# ──────────────────────────────────────────────────────────────────────────
def test_get_returns_cached_instance_for_known_id(provider):
sid = provider.acquire("alpha")
assert provider.get(sid) is provider.get(sid)
def test_get_unknown_id_returns_none(provider):
assert provider.get("local:nonexistent") is None
def test_release_is_noop_keeps_instance_available(provider):
"""Local has no resources to release; the cached instance stays alive across
turns so ``_agent_written_paths`` persists for reverse-resolve on later reads."""
sid = provider.acquire("alpha")
sbx_before = provider.get(sid)
provider.release(sid)
sbx_after = provider.get(sid)
assert sbx_before is sbx_after
def test_reset_clears_both_generic_and_per_thread_caches(provider):
provider.acquire() # populate generic
provider.acquire("alpha") # populate per-thread
assert provider._generic_sandbox is not None
assert provider._thread_sandboxes
provider.reset()
assert provider._generic_sandbox is None
assert not provider._thread_sandboxes
# ──────────────────────────────────────────────────────────────────────────
# 4. is_local_sandbox detects both legacy and per-thread ids
# ──────────────────────────────────────────────────────────────────────────
def test_is_local_sandbox_accepts_both_id_formats():
from deerflow.sandbox.tools import is_local_sandbox
legacy = SimpleNamespace(state={"sandbox": {"sandbox_id": "local"}}, context={})
per_thread = SimpleNamespace(state={"sandbox": {"sandbox_id": "local:alpha"}}, context={})
foreign = SimpleNamespace(state={"sandbox": {"sandbox_id": "aio-12345"}}, context={})
unset = SimpleNamespace(state={}, context={})
assert is_local_sandbox(legacy) is True
assert is_local_sandbox(per_thread) is True
assert is_local_sandbox(foreign) is False
assert is_local_sandbox(unset) is False
# ──────────────────────────────────────────────────────────────────────────
# 5. Concurrency safety (Copilot review feedback)
# ──────────────────────────────────────────────────────────────────────────
def test_concurrent_acquire_same_thread_yields_single_instance(provider):
"""Two threads racing on ``acquire("alpha")`` must share one LocalSandbox.
Without the provider lock the check-then-act in ``acquire`` is non-atomic:
both racers would see an empty cache, both would build their own
LocalSandbox, and one would overwrite the other losing the loser's
``_agent_written_paths`` and any in-flight state on it.
"""
import threading
import time
from deerflow.sandbox.local import local_sandbox as local_sandbox_module
# Force a wide race window by slowing the LocalSandbox constructor down.
original_init = local_sandbox_module.LocalSandbox.__init__
def slow_init(self, *args, **kwargs):
time.sleep(0.05)
original_init(self, *args, **kwargs)
barrier = threading.Barrier(8)
results: list[str] = []
results_lock = threading.Lock()
def racer():
barrier.wait()
sid = provider.acquire("alpha")
with results_lock:
results.append(sid)
with patch.object(local_sandbox_module.LocalSandbox, "__init__", slow_init):
threads = [threading.Thread(target=racer) for _ in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()
# Every racer must observe the same ``sandbox_id``…
assert len(set(results)) == 1, f"Racers saw different ids: {results}"
# …and the cache must hold exactly one instance for ``alpha``.
assert len(provider._thread_sandboxes) == 1
assert "alpha" in provider._thread_sandboxes
def test_concurrent_acquire_distinct_threads_yields_distinct_instances(provider):
"""Different thread_ids race-acquired in parallel each get their own sandbox."""
import threading
barrier = threading.Barrier(6)
sids: dict[str, str] = {}
lock = threading.Lock()
def racer(name: str):
barrier.wait()
sid = provider.acquire(name)
with lock:
sids[name] = sid
threads = [threading.Thread(target=racer, args=(f"t{i}",)) for i in range(6)]
for t in threads:
t.start()
for t in threads:
t.join()
assert set(sids.values()) == {f"local:t{i}" for i in range(6)}
assert set(provider._thread_sandboxes.keys()) == {f"t{i}" for i in range(6)}
# ──────────────────────────────────────────────────────────────────────────
# 6. Bounded memory growth (Copilot review feedback)
# ──────────────────────────────────────────────────────────────────────────
def test_thread_sandbox_cache_is_bounded(isolated_paths, tmp_path):
"""The LRU cap must evict the least-recently-used thread sandboxes once
exceeded otherwise long-running gateways would accumulate cache entries
for every distinct ``thread_id`` ever served."""
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
cfg = _build_config(skills_dir)
with patch("deerflow.config.get_app_config", return_value=cfg):
provider = LocalSandboxProvider(max_cached_threads=3)
for i in range(5):
provider.acquire(f"t{i}")
# Only the 3 most-recent thread_ids should be retained.
assert set(provider._thread_sandboxes.keys()) == {"t2", "t3", "t4"}
assert provider.get("local:t0") is None
assert provider.get("local:t4") is not None
def test_lru_promotes_recently_used_thread(isolated_paths, tmp_path):
"""``get`` on a cached thread should mark it as most-recently used so a
later acquire-storm doesn't evict an active thread that is being polled."""
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
cfg = _build_config(skills_dir)
with patch("deerflow.config.get_app_config", return_value=cfg):
provider = LocalSandboxProvider(max_cached_threads=3)
for name in ["a", "b", "c"]:
provider.acquire(name)
# Touch "a" via ``get`` so it becomes most-recently used.
provider.get("local:a")
# Adding a fourth thread should evict "b" (the new LRU), not "a".
provider.acquire("d")
assert "a" in provider._thread_sandboxes
assert "b" not in provider._thread_sandboxes
assert {"a", "c", "d"} == set(provider._thread_sandboxes.keys())
+412 -82
View File
@@ -1,24 +1,94 @@
"""Tests for LoopDetectionMiddleware."""
import copy
from collections import OrderedDict
from typing import Any
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, SystemMessage
import pytest
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
from pydantic import PrivateAttr
from deerflow.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
_MAX_PENDING_WARNINGS_PER_RUN,
LoopDetectionMiddleware,
_hash_tool_calls,
)
def _make_runtime(thread_id="test-thread"):
def _make_runtime(thread_id="test-thread", run_id="test-run"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
runtime.context = {"thread_id": thread_id, "run_id": run_id}
return runtime
def _pending_key(thread_id="test-thread", run_id="test-run"):
return (thread_id, run_id)
def _make_request(messages, runtime):
"""Build a minimal ModelRequest stand-in for wrap_model_call tests."""
request = MagicMock()
request.messages = list(messages)
request.runtime = runtime
request.override = lambda **updates: _override_request(request, updates)
return request
def _override_request(request, updates):
"""Mimic ModelRequest.override(): return a copy with fields replaced."""
new = MagicMock()
new.messages = updates.get("messages", request.messages)
new.runtime = updates.get("runtime", request.runtime)
new.override = lambda **u: _override_request(new, u)
return new
def _capture_handler():
"""Build a sync handler that records the request it was called with."""
captured: list = []
def handler(req):
captured.append(req)
return MagicMock()
return captured, handler
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
"""Fake chat model that records each model request's messages."""
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage.
@@ -138,7 +208,15 @@ class TestLoopDetection:
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_warn_at_threshold(self):
def test_warn_at_threshold_queues_but_does_not_mutate_state(self):
"""At warn threshold, ``after_model`` enqueues but returns None.
Detection observes the just-emitted AIMessage(tool_calls=...). The
tools node hasn't run yet, so injecting any non-tool message here
would split the assistant's tool_calls from their ToolMessage
responses and break OpenAI/Moonshot pairing. The warning is
delivered later from ``wrap_model_call``.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
runtime = _make_runtime()
call = [_bash_call("ls")]
@@ -146,44 +224,150 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning. The warning is appended to
# the AIMessage content (tool_calls preserved) — never inserted as a
# separate HumanMessage between the AIMessage(tool_calls) and its
# ToolMessage responses, which would break OpenAI/Moonshot strict
# tool-call pairing validation.
# Third identical call triggers warning detection.
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
assert "LOOP DETECTED" in msgs[0].content
# Detection must not mutate state — the AIMessage with tool_calls is
# left untouched so the tools node runs normally.
assert result is None
# ...but a warning is queued for the next model call.
assert mw._pending_warnings[_pending_key()]
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0]
def test_warn_does_not_break_tool_call_pairing(self):
"""Regression: the warn branch must NOT inject a non-tool message
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
request with 'tool_call_ids did not have response messages' if any
non-tool message is wedged between the AIMessage and its ToolMessage
responses. See #2029.
def test_warn_injected_at_next_model_call(self):
"""``wrap_model_call`` appends a HumanMessage(loop_warning) to the
outgoing messages *after* every existing message so that the
AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
# Build the messages the agent runtime would assemble for the next
# turn: prior AIMessage(tool_calls), its ToolMessage responses, ...
ai_msg = AIMessage(content="", tool_calls=call)
tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash")
request = _make_request([ai_msg, tool_msg], runtime)
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
sent = captured[0].messages
# AIMessage and ToolMessage stay in order, untouched.
assert sent[0] is ai_msg
assert sent[1] is tool_msg
# HumanMessage(warning) appears AFTER the ToolMessage — pairing intact.
assert isinstance(sent[2], HumanMessage)
assert sent[2].name == "loop_warning"
assert "LOOP DETECTED" in sent[2].content
def test_warn_queue_drained_after_injection(self):
"""A queued warning must be emitted exactly once per detection event."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
# First call: warning is appended.
mw.wrap_model_call(request, handler)
first = captured[0].messages
assert any(isinstance(m, HumanMessage) for m in first)
# Subsequent call without new detection: no warning re-emitted.
request2 = _make_request([AIMessage(content="hi")], runtime)
mw.wrap_model_call(request2, handler)
second = captured[1].messages
assert not any(isinstance(m, HumanMessage) for m in second)
def test_warn_queue_scoped_by_run_id(self):
"""A warning queued for one run must not be injected into another run."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime_a = _make_runtime(run_id="run-A")
runtime_b = _make_runtime(run_id="run-B")
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime_a)
request_b = _make_request([AIMessage(content="hi")], runtime_b)
captured, handler = _capture_handler()
mw.wrap_model_call(request_b, handler)
assert not any(isinstance(m, HumanMessage) for m in captured[0].messages)
assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
request_a = _make_request([AIMessage(content="hi")], runtime_a)
mw.wrap_model_call(request_a, handler)
assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages)
def test_missing_run_id_uses_default_pending_scope(self):
"""When runtime has no run_id, warning handling falls back to the default run scope."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
assert mw._pending_warnings.get(_pending_key(run_id="default"))
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert len(loop_warnings) == 1
assert "LOOP DETECTED" in loop_warnings[0].content
assert not mw._pending_warnings.get(_pending_key(run_id="default"))
def test_before_agent_clears_stale_pending_warnings_for_thread(self):
"""Starting a new run drops stale warnings from prior runs in the same thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime_a = _make_runtime(run_id="run-A")
runtime_b = _make_runtime(run_id="run-B")
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime_a)
assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
mw.before_agent({"messages": []}, runtime_b)
assert not mw._pending_warnings.get(_pending_key(run_id="run-A"))
def test_after_agent_clears_current_run_pending_warnings(self):
"""Run cleanup should drop warnings that never reached wrap_model_call."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
assert mw._pending_warnings.get(_pending_key())
mw.after_agent({"messages": []}, runtime)
assert not mw._pending_warnings.get(_pending_key())
def test_multiple_pending_warnings_are_merged_into_one_message(self):
"""Edge-case drains should produce one loop_warning prompt message."""
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"]
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert len(loop_warnings) == 1
assert loop_warnings[0].content == "first warning\n\nsecond warning"
def test_warn_only_queued_once_per_hash(self):
"""Same hash repeated past the threshold should warn only once."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
@@ -192,14 +376,13 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third — warning injected
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Third — warning queued
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
# Fourth — warning already injected, should return None
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
# Fourth — already warned for this hash, no additional enqueue.
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
def test_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
@@ -257,6 +440,7 @@ class TestLoopDetection:
mw.reset()
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
assert not mw._pending_warnings.get(_pending_key())
def test_non_ai_message_ignored(self):
mw = LoopDetectionMiddleware()
@@ -283,15 +467,16 @@ class TestLoopDetection:
# One call on thread B
mw._apply(_make_state(tool_calls=call), runtime_b)
# Second call on thread A — triggers warning (2 >= warn_threshold)
result = mw._apply(_make_state(tool_calls=call), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread A — queues warning under thread-A only.
mw._apply(_make_state(tool_calls=call), runtime_a)
assert mw._pending_warnings.get(_pending_key("thread-A"))
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
assert not mw._pending_warnings.get(_pending_key("thread-B"))
# Second call on thread B — also triggers (independent tracking)
result = mw._apply(_make_state(tool_calls=call), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread B — independent queue.
mw._apply(_make_state(tool_calls=call), runtime_b)
assert mw._pending_warnings.get(_pending_key("thread-B"))
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
def test_lru_eviction(self):
"""Old threads should be evicted when max_tracked_threads is exceeded."""
@@ -313,6 +498,55 @@ class TestLoopDetection:
assert "thread-new" in mw._history
assert len(mw._history) == 3
def test_warned_hashes_are_pruned_to_sliding_window(self):
"""A long-lived thread should not keep every historical warned hash."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=100, window_size=4)
runtime = _make_runtime()
for i in range(12):
call = [_bash_call(f"cmd_{i}")]
mw._apply(_make_state(tool_calls=call), runtime)
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._history["test-thread"]) <= 4
assert set(mw._warned["test-thread"]).issubset(set(mw._history["test-thread"]))
assert len(mw._warned["test-thread"]) <= 4
def test_pending_warning_keys_are_capped(self):
"""Abnormal same-thread runs cannot grow pending-warning keys forever."""
mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=2)
for i in range(10):
runtime = _make_runtime(thread_id="same-thread", run_id=f"run-{i}")
mw._queue_pending_warning(runtime, f"warning-{i}")
assert len(mw._pending_warnings) == mw._max_pending_warning_keys
assert len(mw._pending_warning_touch_order) == mw._max_pending_warning_keys
assert _pending_key("same-thread", "run-9") in mw._pending_warnings
def test_pending_warning_list_is_capped_and_deduped(self):
"""One run cannot accumulate an unbounded warning list."""
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
for i in range(_MAX_PENDING_WARNINGS_PER_RUN + 4):
mw._queue_pending_warning(runtime, f"warning-{i}")
mw._queue_pending_warning(runtime, f"warning-{_MAX_PENDING_WARNINGS_PER_RUN + 3}")
warnings = mw._pending_warnings[_pending_key()]
assert len(warnings) == _MAX_PENDING_WARNINGS_PER_RUN
assert warnings == [f"warning-{i}" for i in range(4, _MAX_PENDING_WARNINGS_PER_RUN + 4)]
def test_pending_warning_touch_order_cleared_with_pending_key(self):
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
mw._queue_pending_warning(runtime, "warning")
mw.after_agent({"messages": []}, runtime)
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
def test_thread_safe_mutations(self):
"""Verify lock is used for mutations (basic structural test)."""
mw = LoopDetectionMiddleware()
@@ -331,6 +565,99 @@ class TestLoopDetection:
assert "default" in mw._history
class TestLoopDetectionAgentGraphIntegration:
def test_loop_warning_is_transient_in_real_agent_graph(self):
"""after_model queues the warning; wrap_model_call injects it request-only."""
@as_tool
def bash(command: str) -> str:
"""Run a fake shell command."""
return f"ran: {command}"
repeated_calls = [[{"name": "bash", "id": f"call_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(content="", tool_calls=repeated_calls[0]),
AIMessage(content="", tool_calls=repeated_calls[1]),
AIMessage(content="", tool_calls=repeated_calls[2]),
AIMessage(content="final answer"),
],
)
graph = create_agent(model=model, tools=[bash], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "inspect the directory")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
config={"recursion_limit": 20},
)
assert len(model.seen_messages) == 4
loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages]
assert loop_warnings_by_call[0] == []
assert loop_warnings_by_call[1] == []
assert loop_warnings_by_call[2] == []
assert len(loop_warnings_by_call[3]) == 1
assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content
fourth_request = model.seen_messages[3]
assert isinstance(fourth_request[-2], ToolMessage)
assert fourth_request[-2].tool_call_id == "call_ls_2"
assert fourth_request[-1] is loop_warnings_by_call[3][0]
persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert persisted_loop_warnings == []
assert result["messages"][-1].content == "final answer"
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
@pytest.mark.asyncio
async def test_loop_warning_is_transient_in_async_agent_graph(self):
"""awrap_model_call injects loop_warning request-only in async graph runs."""
@as_tool
async def bash(command: str) -> str:
"""Run a fake shell command."""
return f"ran: {command}"
repeated_calls = [[{"name": "bash", "id": f"call_async_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(content="", tool_calls=repeated_calls[0]),
AIMessage(content="", tool_calls=repeated_calls[1]),
AIMessage(content="", tool_calls=repeated_calls[2]),
AIMessage(content="async final answer"),
],
)
graph = create_agent(model=model, tools=[bash], middleware=[mw])
result = await graph.ainvoke(
{"messages": [("user", "inspect the directory asynchronously")]},
context={"thread_id": "async-integration-thread", "run_id": "async-integration-run"},
config={"recursion_limit": 20},
)
assert len(model.seen_messages) == 4
loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages]
assert loop_warnings_by_call[0] == []
assert loop_warnings_by_call[1] == []
assert loop_warnings_by_call[2] == []
assert len(loop_warnings_by_call[3]) == 1
assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content
fourth_request = model.seen_messages[3]
assert isinstance(fourth_request[-2], ToolMessage)
assert fourth_request[-2].tool_call_id == "call_async_ls_2"
assert fourth_request[-1] is loop_warnings_by_call[3][0]
persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert persisted_loop_warnings == []
assert result["messages"][-1].content == "async final answer"
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
class TestAppendText:
"""Unit tests for LoopDetectionMiddleware._append_text."""
@@ -507,33 +834,29 @@ class TestToolFrequencyDetection:
for i in range(4):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 5th call to read_file (different file each time) triggers freq warning
# 5th call queues a per-tool-type frequency warning; state untouched.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
assert result is not None
msg = result["messages"][0]
# Warning is appended to the AIMessage content; tool_calls preserved
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
# validation does not break.
assert isinstance(msg, AIMessage)
assert msg.tool_calls
assert "read_file" in msg.content
assert "LOOP DETECTED" in msg.content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "read_file" in queued[0]
assert "LOOP DETECTED" in queued[0]
def test_freq_warn_only_injected_once(self):
def test_freq_warn_only_queued_once(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd triggers warning
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# 3rd queues a frequency warning.
mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
# 4th should not re-warn (already warned for read_file)
# 4th: same tool name, no additional enqueue.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime)
assert result is None
assert len(mw._pending_warnings[_pending_key()]) == 1
def test_freq_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6)
@@ -565,10 +888,10 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None
# 3rd read_file triggers (read_file count = 3)
# 3rd read_file triggers — warning is queued (state unchanged).
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
assert "read_file" in mw._pending_warnings[_pending_key()][0]
def test_freq_reset_clears_state(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
@@ -600,10 +923,10 @@ class TestToolFrequencyDetection:
assert "thread-A" not in mw._tool_freq
assert "thread-A" not in mw._tool_freq_warned
# thread-B state should still be intact — 3rd call triggers warn
# thread-B state should still be intact — 3rd call queues a warn.
result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
# thread-A restarted from 0 — should not trigger
result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a)
@@ -623,10 +946,11 @@ class TestToolFrequencyDetection:
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b)
# 3rd call on thread A — triggers (count=3 for thread A only)
# 3rd call on thread A — queues a warning (count=3 for thread A only).
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
assert not mw._pending_warnings.get(_pending_key("thread-B"))
def test_multi_tool_single_response_counted(self):
"""When a single response has multiple tool calls, each is counted."""
@@ -643,10 +967,10 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
# Response 3: 1 more → count = 5 → triggers warn
# Response 3: 1 more → count = 5 → queues warn.
result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
assert "read_file" in mw._pending_warnings[_pending_key()][0]
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
@@ -674,10 +998,14 @@ class TestToolFrequencyDetection:
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
# 3rd read_file call hits global warn=3 (read_file has no override).
# Warning delivery is deferred to wrap_model_call so the just-emitted
# AIMessage(tool_calls=...) is not mutated before ToolMessages exist.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "read_file" in queued[0]
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
@@ -736,11 +1064,13 @@ class TestFromConfig:
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
def test_constructed_middleware_queues_loop_warning(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "LOOP DETECTED" in queued[0]
+20
View File
@@ -24,6 +24,26 @@ def test_build_server_params_stdio_success():
}
def test_extensions_config_resolves_env_variables_inside_nested_collections(monkeypatch):
monkeypatch.setenv("MCP_TOKEN", "secret")
monkeypatch.delenv("MISSING_TOKEN", raising=False)
raw_config = {
"args": ["--token", "$MCP_TOKEN", {"nested": ["$MCP_TOKEN", "$MISSING_TOKEN"]}],
"tuple_args": ("$MCP_TOKEN", "$MISSING_TOKEN"),
"env": {"API_KEY": "$MCP_TOKEN"},
"enabled": True,
"timeout": 30,
}
resolved = ExtensionsConfig.resolve_env_variables(raw_config)
assert resolved["args"] == ["--token", "secret", {"nested": ["secret", ""]}]
assert resolved["tuple_args"] == ("secret", "")
assert resolved["env"] == {"API_KEY": "secret"}
assert resolved["enabled"] is True
assert resolved["timeout"] == 30
def test_build_server_params_stdio_requires_command():
config = McpServerConfig(type="stdio", command=None)
+305
View File
@@ -0,0 +1,305 @@
"""Tests for MCP config secret masking and preservation.
Verifies that GET /api/mcp/config masks sensitive fields (env values,
header values, OAuth secrets) and that PUT /api/mcp/config correctly
preserves existing secrets when the frontend round-trips masked values.
"""
from __future__ import annotations
import pytest
from app.gateway.routers.mcp import (
McpOAuthConfigResponse,
McpServerConfigResponse,
_mask_server_config,
_merge_preserving_secrets,
)
# ---------------------------------------------------------------------------
# _mask_server_config
# ---------------------------------------------------------------------------
def test_mask_replaces_env_values_with_asterisks():
"""Env dict values should be replaced with '***'."""
server = McpServerConfigResponse(
env={"GITHUB_TOKEN": "ghp_real_secret_123", "API_KEY": "sk-abc"},
)
masked = _mask_server_config(server)
assert masked.env == {"GITHUB_TOKEN": "***", "API_KEY": "***"}
def test_mask_replaces_header_values_with_asterisks():
"""Header dict values should be replaced with '***'."""
server = McpServerConfigResponse(
headers={"Authorization": "Bearer tok_123", "X-API-Key": "key_456"},
)
masked = _mask_server_config(server)
assert masked.headers == {"Authorization": "***", "X-API-Key": "***"}
def test_mask_removes_oauth_secrets():
"""OAuth client_secret and refresh_token should be set to None."""
server = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_id="my-client",
client_secret="super-secret",
refresh_token="refresh-token-abc",
token_url="https://auth.example.com/token",
),
)
masked = _mask_server_config(server)
assert masked.oauth is not None
assert masked.oauth.client_secret is None
assert masked.oauth.refresh_token is None
# Non-secret fields preserved
assert masked.oauth.client_id == "my-client"
assert masked.oauth.token_url == "https://auth.example.com/token"
def test_mask_preserves_non_secret_fields():
"""Non-sensitive fields should pass through unchanged."""
server = McpServerConfigResponse(
enabled=True,
type="stdio",
command="npx",
args=["-y", "@modelcontextprotocol/server-github"],
env={"KEY": "val"},
description="GitHub MCP server",
)
masked = _mask_server_config(server)
assert masked.enabled is True
assert masked.type == "stdio"
assert masked.command == "npx"
assert masked.args == ["-y", "@modelcontextprotocol/server-github"]
assert masked.description == "GitHub MCP server"
def test_mask_handles_empty_env_and_headers():
"""Empty env/headers dicts should remain empty."""
server = McpServerConfigResponse()
masked = _mask_server_config(server)
assert masked.env == {}
assert masked.headers == {}
def test_mask_handles_no_oauth():
"""Server without OAuth should remain None."""
server = McpServerConfigResponse(oauth=None)
masked = _mask_server_config(server)
assert masked.oauth is None
def test_mask_does_not_mutate_original():
"""Masking should return a new object, not modify the original."""
server = McpServerConfigResponse(env={"KEY": "secret"})
masked = _mask_server_config(server)
assert server.env["KEY"] == "secret"
assert masked.env["KEY"] == "***"
# ---------------------------------------------------------------------------
# _merge_preserving_secrets
# ---------------------------------------------------------------------------
def test_merge_preserves_masked_env_values():
"""Incoming '***' env values should be replaced with existing secrets."""
incoming = McpServerConfigResponse(env={"KEY": "***"})
existing = McpServerConfigResponse(env={"KEY": "real_secret"})
merged = _merge_preserving_secrets(incoming, existing)
assert merged.env["KEY"] == "real_secret"
def test_merge_preserves_masked_header_values():
"""Incoming '***' header values should be replaced with existing secrets."""
incoming = McpServerConfigResponse(headers={"Authorization": "***"})
existing = McpServerConfigResponse(headers={"Authorization": "Bearer real"})
merged = _merge_preserving_secrets(incoming, existing)
assert merged.headers["Authorization"] == "Bearer real"
def test_merge_preserves_oauth_secrets_when_none():
"""Incoming None oauth secrets should preserve existing values."""
incoming = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret=None,
refresh_token=None,
token_url="https://auth.example.com/token",
),
)
existing = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret="existing-secret",
refresh_token="existing-refresh",
token_url="https://auth.example.com/token",
),
)
merged = _merge_preserving_secrets(incoming, existing)
assert merged.oauth is not None
assert merged.oauth.client_secret == "existing-secret"
assert merged.oauth.refresh_token == "existing-refresh"
def test_merge_accepts_new_secret_values():
"""Incoming real secret values should replace existing ones."""
incoming = McpServerConfigResponse(
env={"KEY": "new_secret"},
oauth=McpOAuthConfigResponse(
client_secret="new-client-secret",
refresh_token="new-refresh-token",
token_url="https://auth.example.com/token",
),
)
existing = McpServerConfigResponse(
env={"KEY": "old_secret"},
oauth=McpOAuthConfigResponse(
client_secret="old-secret",
refresh_token="old-refresh",
token_url="https://auth.example.com/token",
),
)
merged = _merge_preserving_secrets(incoming, existing)
assert merged.env["KEY"] == "new_secret"
assert merged.oauth.client_secret == "new-client-secret"
assert merged.oauth.refresh_token == "new-refresh-token"
def test_merge_handles_no_existing_oauth():
"""When existing has no oauth but incoming does, keep incoming."""
incoming = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret="new-secret",
token_url="https://auth.example.com/token",
),
)
existing = McpServerConfigResponse(oauth=None)
merged = _merge_preserving_secrets(incoming, existing)
assert merged.oauth is not None
assert merged.oauth.client_secret == "new-secret"
def test_merge_does_not_mutate_original():
"""Merge should return a new object, not modify the original."""
incoming = McpServerConfigResponse(env={"KEY": "***"})
existing = McpServerConfigResponse(env={"KEY": "secret"})
merged = _merge_preserving_secrets(incoming, existing)
assert incoming.env["KEY"] == "***"
assert existing.env["KEY"] == "secret"
assert merged.env["KEY"] == "secret"
# ---------------------------------------------------------------------------
# Comment 2 fix: masked value for new key is rejected
# ---------------------------------------------------------------------------
def test_merge_rejects_masked_value_for_new_env_key():
"""Sending '***' for a key that doesn't exist in existing should raise 400."""
from fastapi import HTTPException
incoming = McpServerConfigResponse(env={"NEW_KEY": "***"})
existing = McpServerConfigResponse(env={})
with pytest.raises(HTTPException) as exc_info:
_merge_preserving_secrets(incoming, existing)
assert exc_info.value.status_code == 400
assert "NEW_KEY" in exc_info.value.detail
def test_merge_rejects_masked_value_for_new_header_key():
"""Sending '***' for a header key that doesn't exist should raise 400."""
from fastapi import HTTPException
incoming = McpServerConfigResponse(headers={"X-New-Auth": "***"})
existing = McpServerConfigResponse(headers={})
with pytest.raises(HTTPException) as exc_info:
_merge_preserving_secrets(incoming, existing)
assert exc_info.value.status_code == 400
assert "X-New-Auth" in exc_info.value.detail
# ---------------------------------------------------------------------------
# Comment 4 fix: empty string clears OAuth secrets
# ---------------------------------------------------------------------------
def test_merge_empty_string_clears_oauth_client_secret():
"""Sending '' for client_secret should clear the stored value."""
incoming = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret="",
refresh_token=None,
token_url="https://auth.example.com/token",
),
)
existing = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret="existing-secret",
refresh_token="existing-refresh",
token_url="https://auth.example.com/token",
),
)
merged = _merge_preserving_secrets(incoming, existing)
assert merged.oauth.client_secret is None
assert merged.oauth.refresh_token == "existing-refresh"
def test_merge_empty_string_clears_oauth_refresh_token():
"""Sending '' for refresh_token should clear the stored value."""
incoming = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret=None,
refresh_token="",
token_url="https://auth.example.com/token",
),
)
existing = McpServerConfigResponse(
oauth=McpOAuthConfigResponse(
client_secret="existing-secret",
refresh_token="existing-refresh",
token_url="https://auth.example.com/token",
),
)
merged = _merge_preserving_secrets(incoming, existing)
assert merged.oauth.client_secret == "existing-secret"
assert merged.oauth.refresh_token is None
# ---------------------------------------------------------------------------
# Round-trip integration: mask → merge should preserve original secrets
# ---------------------------------------------------------------------------
def test_roundtrip_mask_then_merge_preserves_original_secrets():
"""Simulates the full frontend round-trip: GET (masked) → toggle → PUT."""
original = McpServerConfigResponse(
enabled=True,
env={"GITHUB_TOKEN": "ghp_real_secret"},
headers={"Authorization": "Bearer real_token"},
oauth=McpOAuthConfigResponse(
client_id="client-123",
client_secret="oauth-secret",
refresh_token="refresh-abc",
token_url="https://auth.example.com/token",
),
description="GitHub MCP server",
)
# Step 1: Server returns masked config (simulates GET response)
masked = _mask_server_config(original)
assert masked.env["GITHUB_TOKEN"] == "***"
assert masked.oauth.client_secret is None
# Step 2: Frontend toggles enabled and sends back (simulates PUT request)
from_frontend = masked.model_copy(update={"enabled": False})
# Step 3: Server merges with existing secrets (simulates PUT handler)
restored = _merge_preserving_secrets(from_frontend, original)
assert restored.enabled is False
assert restored.env["GITHUB_TOKEN"] == "ghp_real_secret"
assert restored.headers["Authorization"] == "Bearer real_token"
assert restored.oauth.client_secret == "oauth-secret"
assert restored.oauth.refresh_token == "refresh-abc"
# Non-secret fields from the update are preserved
assert restored.description == "GitHub MCP server"
+409
View File
@@ -0,0 +1,409 @@
"""Tests for the MCP persistent-session pool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
@pytest.fixture(autouse=True)
def _reset_pool():
reset_session_pool()
yield
reset_session_pool()
# ---------------------------------------------------------------------------
# MCPSessionPool unit tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_session_creates_new():
"""First call for a key creates a new session."""
pool = MCPSessionPool()
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
assert session is mock_session
mock_session.initialize.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_session_reuses_existing():
"""Second call for the same key returns the cached session."""
pool = MCPSessionPool()
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
assert s1 is s2
# Only one session should have been created.
assert mock_cm.__aenter__.await_count == 1
@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
"""Different scope keys get different sessions."""
pool = MCPSessionPool()
sessions = [AsyncMock(), AsyncMock()]
idx = 0
class CmFactory:
def __init__(self):
self.enter_count = 0
async def __aenter__(self):
nonlocal idx
s = sessions[idx]
idx += 1
self.enter_count += 1
return s
async def __aexit__(self, *args):
return False
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()):
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []})
assert s1 is not s2
assert s1 is sessions[0]
assert s2 is sessions[1]
@pytest.mark.asyncio
async def test_lru_eviction():
"""Oldest entries are evicted when the pool is full."""
pool = MCPSessionPool()
pool.MAX_SESSIONS = 2
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
# Pool is full (2). Adding t3 should evict t1.
await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []})
assert cms[0].closed is True
assert cms[1].closed is False
assert cms[2].closed is False
@pytest.mark.asyncio
async def test_close_scope():
"""close_scope shuts down sessions for a specific scope key."""
pool = MCPSessionPool()
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
await pool.close_scope("t1")
assert cms[0].closed is True
assert cms[1].closed is False
# t2 session still exists.
assert ("s", "t2") in pool._entries
@pytest.mark.asyncio
async def test_close_all():
"""close_all shuts down every session."""
pool = MCPSessionPool()
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []})
await pool.close_all()
assert all(cm.closed for cm in cms)
assert len(pool._entries) == 0
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
def test_get_session_pool_singleton():
"""get_session_pool returns the same instance."""
p1 = get_session_pool()
p2 = get_session_pool()
assert p1 is p2
def test_reset_session_pool():
"""reset_session_pool clears the singleton."""
p1 = get_session_pool()
reset_session_pool()
p2 = get_session_pool()
assert p1 is not p2
# ---------------------------------------------------------------------------
# Integration: _make_session_pool_tool uses the pool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
"""The wrapper tool delegates to a pool-managed session."""
# Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
url: str = Field(..., description="url")
original_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
connection = {"transport": "stdio", "command": "pw", "args": []}
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
# Simulate a tool call with a runtime context containing thread_id.
mock_runtime = MagicMock()
mock_runtime.context = {"thread_id": "thread-42"}
mock_runtime.config = {}
await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
"""Thread ID is extracted from runtime.config when not in context."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
mock_runtime = MagicMock()
mock_runtime.context = {}
mock_runtime.config = {"configurable": {"thread_id": "from-config"}}
await wrapped.coroutine(runtime=mock_runtime, x=1)
# Verify the session was created with the correct scope key.
pool = get_session_pool()
assert ("server", "from-config") in pool._entries
@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
"""When no thread_id is available, 'default' is used as scope key."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
# No thread_id in runtime at all.
await wrapped.coroutine(runtime=None, x=1)
pool = get_session_pool()
assert ("server", "default") in pool._entries
@pytest.mark.asyncio
async def test_session_pool_tool_get_config_fallback():
"""When runtime is None, get_config() provides thread_id as fallback."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
fake_config = {"configurable": {"thread_id": "from-langgraph-config"}}
with (
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
patch("deerflow.mcp.tools.get_config", return_value=fake_config),
):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
# runtime=None — get_config() fallback should provide thread_id
await wrapped.coroutine(runtime=None, x=1)
pool = get_session_pool()
assert ("server", "from-langgraph-config") in pool._entries
def test_session_pool_tool_sync_wrapper_path_is_safe():
"""Sync wrapper (tool.func) invocation doesn't crash on cross-loop access."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
from deerflow.tools.sync import make_sync_tool_wrapper
class Args(BaseModel):
url: str = Field(..., description="url")
original_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
connection = {"transport": "stdio", "command": "pw", "args": []}
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
# Attach the sync wrapper exactly as get_mcp_tools() does.
wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name)
# Call via the sync path (asyncio.run in a worker thread).
# runtime is not supplied so _extract_thread_id falls back to "default".
wrapped.func(url="https://example.com")
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
+62 -8
View File
@@ -1,11 +1,14 @@
import asyncio
import contextvars
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_sync_tool_wrapper, get_mcp_tools
from deerflow.mcp.tools import get_mcp_tools
from deerflow.tools.sync import make_sync_tool_wrapper
class MockArgs(BaseModel):
@@ -51,14 +54,13 @@ def test_mcp_tool_sync_wrapper_generation():
def test_mcp_tool_sync_wrapper_in_running_loop():
"""Test the actual helper function from production code (Fix for Comment 1 & 3)."""
"""Test the shared sync wrapper from production code."""
async def mock_coro(x: int):
await asyncio.sleep(0.01)
return f"async_result: {x}"
# Test the real helper function exported from deerflow.mcp.tools
sync_func = _make_sync_tool_wrapper(mock_coro, "test_tool")
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
async def run_in_loop():
# This call should succeed due to ThreadPoolExecutor in the real helper
@@ -69,17 +71,69 @@ def test_mcp_tool_sync_wrapper_in_running_loop():
assert result == "async_result: 100"
def test_sync_wrapper_preserves_contextvars_in_running_loop():
"""The executor branch preserves LangGraph-style contextvars."""
current_value: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_value", default=None)
async def mock_coro() -> str | None:
return current_value.get()
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
async def run_in_loop() -> str | None:
token = current_value.set("from-parent-context")
try:
return sync_func()
finally:
current_value.reset(token)
assert asyncio.run(run_in_loop()) == "from-parent-context"
def test_sync_wrapper_preserves_runnable_config_injection():
"""LangChain can still inject RunnableConfig after an async tool is wrapped."""
captured: dict[str, object] = {}
async def mock_coro(x: int, config: RunnableConfig = None):
captured["thread_id"] = ((config or {}).get("configurable") or {}).get("thread_id")
return f"result: {x}"
mock_tool = StructuredTool(
name="test_tool",
description="test description",
args_schema=MockArgs,
func=make_sync_tool_wrapper(mock_coro, "test_tool"),
coroutine=mock_coro,
)
result = mock_tool.invoke({"x": 42}, config={"configurable": {"thread_id": "thread-123"}})
assert result == "result: 42"
assert captured["thread_id"] == "thread-123"
def test_sync_wrapper_preserves_regular_config_argument():
"""Only RunnableConfig-annotated coroutine params get special config injection."""
async def mock_coro(config: str):
return config
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
assert sync_func(config="user-config") == "user-config"
def test_mcp_tool_sync_wrapper_exception_logging():
"""Test the actual helper's error logging (Fix for Comment 3)."""
"""Test the shared sync wrapper's error logging."""
async def error_coro():
raise ValueError("Tool failure")
sync_func = _make_sync_tool_wrapper(error_coro, "error_tool")
sync_func = make_sync_tool_wrapper(error_coro, "error_tool")
with patch("deerflow.mcp.tools.logger.error") as mock_log_error:
with patch("deerflow.tools.sync.logger.error") as mock_log_error:
with pytest.raises(ValueError, match="Tool failure"):
sync_func()
mock_log_error.assert_called_once()
# Verify the tool name is in the log message
assert "error_tool" in mock_log_error.call_args[0][0]
assert mock_log_error.call_args[0][1] == "error_tool"
+83 -1
View File
@@ -1,6 +1,6 @@
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
assert elapsed < 0.1
assert finished.is_set() is False
assert finished.wait(1.0) is True
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
assert queue.pending_count == 2
assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(
thread_id="thread-1",
messages=["first"],
agent_name="agent-a",
correction_detected=True,
)
queue.add(
thread_id="thread-1",
messages=["second"],
agent_name="agent-a",
correction_detected=False,
)
assert queue.pending_count == 1
assert queue._queue[0].agent_name == "agent-a"
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].correction_detected is True
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with (
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
patch("deerflow.agents.memory.queue.time.sleep"),
):
queue.flush()
assert mock_updater.update_memory.call_count == 2
mock_updater.update_memory.assert_has_calls(
[
call(
messages=["agent-a"],
thread_id="thread-1",
agent_name="agent-a",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
call(
messages=["agent-b"],
thread_id="thread-1",
agent_name="agent-b",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
]
)
@@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id():
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
assert q._queue[0].user_id == "alice"
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock()
@@ -37,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
mock_updater.update_memory.assert_called_once()
call_kwargs = mock_updater.update_memory.call_args.kwargs
assert call_kwargs["user_id"] == "alice"
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
assert q.pending_count == 1
assert q._queue[0].messages == ["second"]
assert q._queue[0].user_id == "alice"
assert q._queue[0].agent_name == "researcher"
def test_add_nowait_keeps_different_users_separate():
q = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
patch.object(q, "_schedule_timer"),
):
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
+35
View File
@@ -78,6 +78,41 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
assert all(fact["id"] != "fact_remove" for fact in result["facts"])
def test_prepare_update_prompt_preserves_non_ascii_memory_text() -> None:
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_cn",
"content": "Deer-flow是一个非常好的框架。",
"category": "context",
"confidence": 0.9,
"createdAt": "2026-05-20T00:00:00Z",
"source": "thread-cn",
},
]
)
with (
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
):
msg = MagicMock()
msg.type = "human"
msg.content = "你好"
prepared = updater._prepare_update_prompt(
[msg],
agent_name=None,
correction_detected=False,
reinforcement_detected=False,
)
assert prepared is not None
_, prompt = prepared
assert "Deer-flow是一个非常好的框架。" in prompt
assert "\\u" not in prompt
def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None:
updater = MemoryUpdater()
current_memory = _make_memory()
-1
View File
@@ -454,7 +454,6 @@ class TestAStream:
@pytest.mark.asyncio
async def test_with_tools_emits_tool_call_chunk(self):
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
+106
View File
@@ -0,0 +1,106 @@
"""Regression tests for #3120: SQLite-backed stores must emit tz-aware ISO timestamps.
SQLAlchemy's ``DateTime(timezone=True)`` is a no-op on SQLite because the
backend has no native timezone type, so values read back are naive
``datetime`` instances. The four SQL ``_row_to_dict`` helpers therefore
have to normalize through :func:`deerflow.utils.time.coerce_iso` instead
of calling ``.isoformat()`` directly; otherwise the API ships
timezone-less strings (e.g. ``"2026-05-20T06:10:22.970977"``) and the
frontend's ``new Date(...)`` parses them as local time, shifting recent
threads by the local UTC offset.
"""
import re
import pytest
_TZ_SUFFIX_RE = re.compile(r"(?:\+\d{2}:\d{2}|Z)$")
def _assert_tz_aware(value: str | None, *, context: str) -> None:
assert value, f"{context}: expected ISO string, got {value!r}"
assert _TZ_SUFFIX_RE.search(value), f"{context}: timestamp lacks tz suffix: {value!r}"
async def _init_sqlite(tmp_path):
from deerflow.persistence.engine import get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'tz.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
return get_session_factory()
async def _cleanup():
from deerflow.persistence.engine import close_engine
await close_engine()
@pytest.mark.anyio
async def test_thread_meta_emits_tz_aware_timestamps(tmp_path):
from deerflow.persistence.thread_meta import ThreadMetaRepository
repo = ThreadMetaRepository(await _init_sqlite(tmp_path))
try:
created = await repo.create("t-tz", user_id="u1", display_name="tz")
_assert_tz_aware(created["created_at"], context="thread_meta.create.created_at")
_assert_tz_aware(created["updated_at"], context="thread_meta.create.updated_at")
# Second read from DB exercises the same _row_to_dict path on a
# value that SQLite has round-tripped (where tzinfo is lost).
fetched = await repo.get("t-tz", user_id="u1")
_assert_tz_aware(fetched["created_at"], context="thread_meta.get.created_at")
_assert_tz_aware(fetched["updated_at"], context="thread_meta.get.updated_at")
listed = await repo.search(user_id="u1")
assert listed, "search must return the created row"
_assert_tz_aware(listed[0]["created_at"], context="thread_meta.search.created_at")
_assert_tz_aware(listed[0]["updated_at"], context="thread_meta.search.updated_at")
finally:
await _cleanup()
@pytest.mark.anyio
async def test_run_repository_emits_tz_aware_timestamps(tmp_path):
from deerflow.persistence.run import RunRepository
repo = RunRepository(await _init_sqlite(tmp_path))
try:
await repo.put("r-tz", thread_id="t-tz", user_id="u1")
row = await repo.get("r-tz", user_id="u1")
_assert_tz_aware(row["created_at"], context="run.get.created_at")
_assert_tz_aware(row["updated_at"], context="run.get.updated_at")
finally:
await _cleanup()
@pytest.mark.anyio
async def test_feedback_repository_emits_tz_aware_timestamps(tmp_path):
from deerflow.persistence.feedback import FeedbackRepository
repo = FeedbackRepository(await _init_sqlite(tmp_path))
try:
record = await repo.create(run_id="r-tz", thread_id="t-tz", rating=1, user_id="u1")
_assert_tz_aware(record["created_at"], context="feedback.create.created_at")
finally:
await _cleanup()
@pytest.mark.anyio
async def test_run_event_store_emits_tz_aware_timestamps(tmp_path):
from deerflow.runtime.events.store.db import DbRunEventStore
store = DbRunEventStore(await _init_sqlite(tmp_path))
try:
await store.put(
thread_id="t-tz",
run_id="r-tz",
event_type="log",
category="log",
content="hello",
)
events = await store.list_events("t-tz", "r-tz", user_id=None)
assert events, "expected at least one event"
_assert_tz_aware(events[0]["created_at"], context="run_event.list.created_at")
finally:
await _cleanup()
+14 -8
View File
@@ -92,12 +92,19 @@ class TestBuildVolumeMounts:
userdata_mount = mounts[1]
assert userdata_mount.sub_path is None
def test_pvc_sets_subpath(self, provisioner_module):
"""PVC mode should set sub_path to threads/{thread_id}/user-data."""
def test_pvc_sets_user_scoped_subpath(self, provisioner_module):
"""PVC mode should include user_id in the user-data subPath."""
provisioner_module.USERDATA_PVC_NAME = "my-pvc"
mounts = provisioner_module._build_volume_mounts("thread-42", user_id="user-7")
userdata_mount = mounts[1]
assert userdata_mount.sub_path == "deer-flow/users/user-7/threads/thread-42/user-data"
def test_pvc_defaults_to_default_user_subpath(self, provisioner_module):
"""Older callers should still land under a stable default user namespace."""
provisioner_module.USERDATA_PVC_NAME = "my-pvc"
mounts = provisioner_module._build_volume_mounts("thread-42")
userdata_mount = mounts[1]
assert userdata_mount.sub_path == "threads/thread-42/user-data"
assert userdata_mount.sub_path == "deer-flow/users/default/threads/thread-42/user-data"
def test_skills_mount_read_only(self, provisioner_module):
"""Skills mount should always be read-only."""
@@ -146,13 +153,12 @@ class TestBuildPodVolumes:
pod = provisioner_module._build_pod("sandbox-1", "thread-1")
assert len(pod.spec.containers[0].volume_mounts) == 2
def test_pod_pvc_mode(self, provisioner_module):
"""Pod should use PVC volumes when PVC names are configured."""
def test_pod_pvc_mode_uses_user_scoped_subpath(self, provisioner_module):
"""Pod should use a user-scoped subPath for PVC user-data."""
provisioner_module.SKILLS_PVC_NAME = "skills-pvc"
provisioner_module.USERDATA_PVC_NAME = "userdata-pvc"
pod = provisioner_module._build_pod("sandbox-1", "thread-1")
pod = provisioner_module._build_pod("sandbox-1", "thread-1", user_id="user-7")
assert pod.spec.volumes[0].persistent_volume_claim is not None
assert pod.spec.volumes[1].persistent_volume_claim is not None
# subPath should be set on user-data mount
userdata_mount = pod.spec.containers[0].volume_mounts[1]
assert userdata_mount.sub_path == "threads/thread-1/user-data"
assert userdata_mount.sub_path == "deer-flow/users/user-7/threads/thread-1/user-data"
+25 -1
View File
@@ -144,7 +144,11 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch):
def mock_post(url: str, json: dict, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"}
assert json == {
"sandbox_id": "abc123",
"thread_id": "thread-1",
"user_id": "test-user-autouse",
}
assert timeout == 30
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
@@ -155,6 +159,26 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch):
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_create_accepts_anonymous_thread_id(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert json == {
"sandbox_id": "anon123",
"thread_id": None,
"user_id": "test-user-autouse",
}
assert timeout == 30
return _StubResponse(payload={"sandbox_id": "anon123", "sandbox_url": "http://k3s:31002"})
monkeypatch.setattr(requests, "post", mock_post)
info = backend.create(None, "anon123")
assert info.sandbox_id == "anon123"
assert info.sandbox_url == "http://k3s:31002"
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
+33
View File
@@ -268,6 +268,39 @@ class TestEdgeCases:
class TestDbRunEventStore:
"""Tests for DbRunEventStore with temp SQLite."""
@pytest.mark.anyio
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
from sqlalchemy.dialects import postgresql
from deerflow.runtime.events.store.db import DbRunEventStore
class FakeSession:
def __init__(self):
self.dialect = postgresql.dialect()
self.execute_calls = []
self.scalar_stmt = None
def get_bind(self):
return self
async def execute(self, stmt, params=None):
self.execute_calls.append((stmt, params))
async def scalar(self, stmt):
self.scalar_stmt = stmt
return 41
session = FakeSession()
max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
assert max_seq == 41
assert session.execute_calls
assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
assert "FOR UPDATE" not in compiled
@pytest.mark.anyio
async def test_basic_crud(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+435
View File
@@ -339,6 +339,99 @@ class TestConvenienceFields:
data = j.get_completion_data()
assert data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_completion_data_counts_human_ai_and_tool_messages(self, journal_setup):
from langchain_core.messages import HumanMessage, ToolMessage
j, _ = journal_setup
j.on_chat_model_start({}, [[HumanMessage(content="Question")]], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.on_tool_end(ToolMessage(content="Tool result", tool_call_id="call_1", name="search"), run_id=uuid4())
data = j.get_completion_data()
assert data["message_count"] == 3
assert data["first_human_message"] == "Question"
assert data["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_tool_call_only_ai_does_not_clear_last_ai_message(self, journal_setup):
j, _ = journal_setup
j.on_llm_end(_make_llm_response("Useful answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(
_make_llm_response("", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
data = j.get_completion_data()
assert data["message_count"] == 2
assert data["last_ai_message"] == "Useful answer"
@pytest.mark.anyio
async def test_last_ai_message_extracts_mixed_content_without_extra_newlines(self, journal_setup):
j, _ = journal_setup
j.on_llm_end(
_make_llm_response(
[
{"type": "text", "text": "First "},
{"type": "text", "content": "second"},
" third",
{"type": "image", "url": "ignored"},
]
),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
data = j.get_completion_data()
assert data["message_count"] == 1
assert data["last_ai_message"] == "First second third"
@pytest.mark.anyio
async def test_last_ai_message_extracts_mapping_content(self, journal_setup):
j, _ = journal_setup
j.on_llm_end(_make_llm_response({"content": "Nested answer"}), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
data = j.get_completion_data()
assert data["message_count"] == 1
assert data["last_ai_message"] == "Nested answer"
@pytest.mark.anyio
async def test_duplicate_llm_run_id_does_not_double_count_message_summary(self, journal_setup):
j, _ = journal_setup
run_id = uuid4()
j.on_llm_end(_make_llm_response("Answer", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(
_make_llm_response("Answer", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=run_id,
parent_run_id=None,
tags=["lead_agent"],
)
data = j.get_completion_data()
assert data["message_count"] == 1
assert data["last_ai_message"] == "Answer"
assert data["total_tokens"] == 15
@pytest.mark.anyio
async def test_subagent_ai_does_not_overwrite_lead_last_ai_message(self, journal_setup):
j, _ = journal_setup
j.on_llm_end(_make_llm_response("Lead answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Subagent detail"), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
data = j.get_completion_data()
assert data["message_count"] == 2
assert data["last_ai_message"] == "Lead answer"
@pytest.mark.anyio
async def test_get_completion_data(self, journal_setup):
j, _ = journal_setup
@@ -383,6 +476,348 @@ class TestMiddlewareEvents:
assert "middleware:guardrail" in event_types
class TestCallerBucketing:
"""Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware)."""
def test_lead_agent_bucketing(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 0
assert j._middleware_tokens == 0
def test_subagent_bucketing(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
assert j._subagent_tokens == 30
assert j._lead_agent_tokens == 0
assert j._middleware_tokens == 0
def test_middleware_bucketing(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"])
assert j._middleware_tokens == 7
assert j._lead_agent_tokens == 0
assert j._subagent_tokens == 0
def test_mixed_callers_sum_independently(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"])
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 15
assert j._middleware_tokens == 15
assert j._total_tokens == 45
def test_get_completion_data_includes_buckets(self, journal_setup):
j, _ = journal_setup
j._lead_agent_tokens = 100
j._subagent_tokens = 200
j._middleware_tokens = 50
data = j.get_completion_data()
assert data["lead_agent_tokens"] == 100
assert data["subagent_tokens"] == 200
assert data["middleware_tokens"] == 50
def test_dedup_same_run_id(self, journal_setup):
"""Same langchain run_id in on_llm_end must not double-count."""
j, _ = journal_setup
run_id = uuid4()
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
assert j._total_tokens == 15
assert j._lead_agent_tokens == 15
assert j._llm_call_count == 1
def test_first_no_usage_second_with_usage(self, journal_setup):
"""First callback with no usage must not block second callback with usage for same run_id."""
j, _ = journal_setup
run_id = uuid4()
j.on_llm_end(_make_llm_response("A", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
assert str(run_id) not in j._counted_llm_run_ids
# Second callback for the same run_id with actual usage must still count
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
assert j._total_tokens == 15
assert j._lead_agent_tokens == 15
def test_track_token_usage_false_skips_buckets(self):
"""When token tracking is disabled, caller buckets stay at 0."""
store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("X", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
assert j._subagent_tokens == 0
assert j._lead_agent_tokens == 0
def test_default_no_tags_buckets_as_lead_agent(self, journal_setup):
"""LLM calls without explicit tags default to lead_agent bucket."""
j, _ = journal_setup
usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None)
assert j._lead_agent_tokens == 10
assert j._subagent_tokens == 0
assert j._middleware_tokens == 0
def test_unknown_tag_buckets_as_lead_agent(self, journal_setup):
"""Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent."""
j, _ = journal_setup
usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"])
assert j._lead_agent_tokens == 10
class TestExternalUsageRecords:
"""Tests for record_external_llm_usage_records."""
def test_records_added_to_subagent_bucket(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-1",
"caller": "subagent:general-purpose",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
}
]
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 150
assert j._total_tokens == 150
assert j._total_input_tokens == 100
assert j._total_output_tokens == 50
def test_records_added_to_middleware_bucket(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-2",
"caller": "middleware:summarize",
"input_tokens": 30,
"output_tokens": 10,
"total_tokens": 40,
}
]
j.record_external_llm_usage_records(records)
assert j._middleware_tokens == 40
assert j._lead_agent_tokens == 0
assert j._subagent_tokens == 0
def test_records_added_to_lead_agent_bucket(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-3",
"caller": "lead_agent",
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
}
]
j.record_external_llm_usage_records(records)
assert j._lead_agent_tokens == 15
def test_dedup_same_source_run_id(self, journal_setup):
"""Same source_run_id must not be double-counted."""
j, _ = journal_setup
records = [
{
"source_run_id": "dup-1",
"caller": "subagent:research",
"input_tokens": 50,
"output_tokens": 25,
"total_tokens": 75,
}
]
j.record_external_llm_usage_records(records)
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 75
assert j._total_tokens == 75
def test_total_tokens_missing_computed_from_input_output(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-4",
"caller": "subagent:bash",
"input_tokens": 200,
"output_tokens": 100,
"total_tokens": 0,
}
]
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 300
assert j._total_tokens == 300
def test_total_tokens_zero_no_count(self, journal_setup):
"""Records with zero total and zero input+output must not be counted."""
j, _ = journal_setup
records = [
{
"source_run_id": "ext-5",
"caller": "subagent:research",
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
}
]
j.record_external_llm_usage_records(records)
assert j._total_tokens == 0
assert j._subagent_tokens == 0
def test_empty_source_run_id_skipped(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "",
"caller": "subagent:research",
"input_tokens": 50,
"output_tokens": 25,
"total_tokens": 75,
}
]
j.record_external_llm_usage_records(records)
assert j._total_tokens == 0
def test_multiple_records_in_single_call(self, journal_setup):
j, _ = journal_setup
records = [
{"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
{"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
]
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 45
assert j._total_tokens == 45
def test_external_records_coexist_with_inline_callbacks(self, journal_setup):
"""External records and inline on_llm_end must not interfere."""
j, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 150
assert j._total_tokens == 165
def test_track_token_usage_false_skips_external_records(self):
"""When token tracking is disabled, external records must not accumulate."""
store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
assert j._total_tokens == 0
assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message."""
+736 -6
View File
@@ -1,10 +1,17 @@
"""Tests for RunManager."""
import asyncio
import logging
import re
import sqlite3
from typing import Any
import pytest
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime import DisconnectMode, RunManager, RunStatus
from deerflow.runtime.runs.manager import PersistenceRetryPolicy
from deerflow.runtime.runs.store.memory import MemoryRunStore
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -14,6 +21,92 @@ def manager() -> RunManager:
return RunManager()
class FlakyStatusRunStore(MemoryRunStore):
"""Memory run store that simulates transient SQLite status-write failures."""
def __init__(self, *, status_failures: int) -> None:
super().__init__()
self.status_failures = status_failures
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
if self.status_failures > 0:
self.status_failures -= 1
raise sqlite3.OperationalError("database is locked")
return await super().update_status(run_id, status, error=error)
class MissingRowStatusRunStore(MemoryRunStore):
"""Memory run store that reports a missing row for status updates."""
async def update_status(self, run_id, status, *, error=None):
await super().update_status(run_id, status, error=error)
return False
class PermanentStatusRunStore(MemoryRunStore):
"""Memory run store that simulates a permanent SQLAlchemy write failure."""
def __init__(self) -> None:
super().__init__()
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
raise SQLAlchemyDatabaseError(
"UPDATE runs SET status = :status WHERE run_id = :run_id",
{"status": status, "run_id": run_id},
sqlite3.DatabaseError("no such table: runs"),
)
class FailingStatusRunStore(MemoryRunStore):
"""Memory run store that always fails status updates."""
def __init__(self) -> None:
super().__init__()
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
raise sqlite3.OperationalError("database is locked")
class MissingCompletionRunStore(MemoryRunStore):
"""Memory run store that reports one missing row for completion updates."""
def __init__(self) -> None:
super().__init__()
self.completion_update_attempts = 0
async def update_run_completion(self, run_id, *, status, **kwargs):
self.completion_update_attempts += 1
if self.completion_update_attempts == 1:
return False
return await super().update_run_completion(run_id, status=status, **kwargs)
class AlwaysMissingCompletionRunStore(MemoryRunStore):
"""Memory run store that keeps reporting missing rows for completion updates."""
def __init__(self) -> None:
super().__init__()
self.completion_update_attempts = 0
async def update_run_completion(self, run_id, *, status, **kwargs):
self.completion_update_attempts += 1
return False
async def _stored_statuses(store: MemoryRunStore, *run_ids: str) -> dict[str, Any]:
rows = {}
for run_id in run_ids:
row = await store.get(run_id)
rows[run_id] = row["status"] if row else None
return rows
@pytest.mark.anyio
async def test_create_and_get(manager: RunManager):
"""Created run should be retrievable with new fields."""
@@ -33,7 +126,7 @@ async def test_create_and_get(manager: RunManager):
assert ISO_RE.match(record.created_at)
assert ISO_RE.match(record.updated_at)
fetched = manager.get(record.run_id)
fetched = await manager.get(record.run_id)
assert fetched is record
@@ -63,6 +156,171 @@ async def test_cancel(manager: RunManager):
assert record.status == RunStatus.interrupted
@pytest.mark.anyio
async def test_cancel_persists_interrupted_status_to_store():
"""Cancel should persist interrupted status to the backing store."""
store = MemoryRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
cancelled = await manager.cancel(record.run_id)
stored = await store.get(record.run_id)
assert cancelled is True
assert stored is not None
assert stored["status"] == "interrupted"
@pytest.mark.anyio
async def test_status_persistence_retries_transient_sqlite_lock():
"""Transient SQLite lock errors should not leave a final status stale."""
store = FlakyStatusRunStore(status_failures=2)
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
await manager.set_status(record.run_id, RunStatus.success)
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "success"
assert store.status_update_attempts >= 4
@pytest.mark.anyio
async def test_status_persistence_recreates_missing_store_row():
"""A final status update should recreate a run row if initial persistence was lost."""
store = MissingRowStatusRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await store.delete(record.run_id)
await manager.set_status(record.run_id, RunStatus.error, error="boom")
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "error"
assert stored["error"] == "boom"
@pytest.mark.anyio
async def test_status_persistence_does_not_retry_permanent_sqlalchemy_errors():
"""Permanent SQLAlchemy failures should not be retried as SQLite pressure."""
store = PermanentStatusRunStore()
manager = RunManager(
store=store,
persistence_retry_policy=PersistenceRetryPolicy(max_attempts=5, initial_delay=0),
)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.error, error="boom")
assert store.status_update_attempts == 1
@pytest.mark.anyio
async def test_completion_persistence_recreates_missing_store_row():
"""Completion updates should recreate a missing row and persist final counters."""
store = MissingCompletionRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
await manager.set_status(record.run_id, RunStatus.success)
await store.delete(record.run_id)
await manager.update_run_completion(
record.run_id,
status="success",
total_tokens=42,
llm_call_count=2,
last_ai_message="done",
)
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "success"
assert stored["total_tokens"] == 42
assert stored["llm_call_count"] == 2
assert stored["last_ai_message"] == "done"
assert store.completion_update_attempts == 2
@pytest.mark.anyio
async def test_completion_persistence_warns_when_recreated_row_still_missing(caplog):
"""A second zero-row completion update after recreation should not be silent."""
store = AlwaysMissingCompletionRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.success)
caplog.set_level(logging.WARNING, logger="deerflow.runtime.runs.manager")
await manager.update_run_completion(record.run_id, status="success", total_tokens=42)
assert store.completion_update_attempts == 2
assert "affected no rows after row recreation" in caplog.text
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_marks_stale_rows_error():
"""Startup recovery should turn persisted active rows into explicit errors."""
store = MemoryRunStore()
await store.put("pending-run", thread_id="thread-1", status="pending", created_at="2026-01-01T00:00:00+00:00")
await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:01+00:00")
await store.put("success-run", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:02+00:00")
manager = RunManager(store=store)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before="2026-01-01T00:00:02+00:00",
)
assert {record.run_id for record in recovered} == {"pending-run", "running-run"}
assert await _stored_statuses(store, "pending-run", "running-run", "success-run") == {
"pending-run": "error",
"running-run": "error",
"success-run": "success",
}
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_live_local_run():
"""Startup recovery should not mark an active row orphaned when this worker owns it."""
store = MemoryRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
)
stored = await store.get(record.run_id)
assert recovered == []
assert stored["status"] == "running"
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_rows_when_error_status_is_not_persisted():
"""Startup recovery must not report a row as recovered if the error update failed."""
store = FailingStatusRunStore()
await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:00+00:00")
manager = RunManager(
store=store,
persistence_retry_policy=PersistenceRetryPolicy(max_attempts=2, initial_delay=0),
)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before="2026-01-01T00:00:01+00:00",
)
stored = await store.get("running-run")
assert recovered == []
assert stored["status"] == "running"
assert store.status_update_attempts == 2
@pytest.mark.anyio
async def test_cancel_not_inflight(manager: RunManager):
"""Cancelling a completed run should return False."""
@@ -82,8 +340,9 @@ async def test_list_by_thread(manager: RunManager):
runs = await manager.list_by_thread("thread-1")
assert len(runs) == 2
assert runs[0].run_id == r1.run_id
assert runs[1].run_id == r2.run_id
# Newest first: r2 was created after r1.
assert runs[0].run_id == r2.run_id
assert runs[1].run_id == r1.run_id
@pytest.mark.anyio
@@ -115,7 +374,7 @@ async def test_cleanup(manager: RunManager):
run_id = record.run_id
await manager.cleanup(run_id, delay=0)
assert manager.get(run_id) is None
assert await manager.get(run_id) is None
@pytest.mark.anyio
@@ -130,7 +389,191 @@ async def test_set_status_with_error(manager: RunManager):
@pytest.mark.anyio
async def test_get_nonexistent(manager: RunManager):
"""Getting a nonexistent run should return None."""
assert manager.get("does-not-exist") is None
assert await manager.get("does-not-exist") is None
@pytest.mark.anyio
async def test_get_hydrates_store_only_run():
"""Store-only runs should be readable after process restart."""
store = MemoryRunStore()
await store.put(
"run-store-only",
thread_id="thread-1",
assistant_id="lead_agent",
status="success",
multitask_strategy="reject",
metadata={"source": "store"},
kwargs={"input": "value"},
created_at="2026-01-01T00:00:00+00:00",
model_name="model-a",
)
manager = RunManager(store=store)
record = await manager.get("run-store-only")
assert record is not None
assert record.run_id == "run-store-only"
assert record.thread_id == "thread-1"
assert record.assistant_id == "lead_agent"
assert record.status == RunStatus.success
assert record.on_disconnect == DisconnectMode.cancel
assert record.metadata == {"source": "store"}
assert record.kwargs == {"input": "value"}
assert record.model_name == "model-a"
assert record.task is None
assert record.store_only is True
@pytest.mark.anyio
async def test_get_hydrates_run_with_null_enum_fields():
"""Rows with NULL status/on_disconnect must hydrate with safe defaults, not raise."""
store = MemoryRunStore()
# Simulate a SQL row where the nullable status column is NULL
await store.put(
"run-null-status",
thread_id="thread-1",
status=None,
created_at="2026-01-01T00:00:00+00:00",
)
manager = RunManager(store=store)
record = await manager.get("run-null-status")
assert record is not None
assert record.status == RunStatus.pending
assert record.on_disconnect == DisconnectMode.cancel
assert record.store_only is True
@pytest.mark.anyio
async def test_list_by_thread_hydrates_run_with_null_enum_fields():
"""list_by_thread must not skip rows with NULL status; applies safe defaults."""
store = MemoryRunStore()
await store.put(
"run-null-status-list",
thread_id="thread-null",
status=None,
created_at="2026-01-01T00:00:00+00:00",
)
manager = RunManager(store=store)
runs = await manager.list_by_thread("thread-null")
assert len(runs) == 1
assert runs[0].run_id == "run-null-status-list"
assert runs[0].status == RunStatus.pending
assert runs[0].on_disconnect == DisconnectMode.cancel
@pytest.mark.anyio
async def test_create_record_is_not_store_only(manager: RunManager):
"""In-memory records created via create() must have store_only=False."""
record = await manager.create("thread-1")
assert record.store_only is False
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_failure():
"""create() must fail and hide the run when the initial store write fails."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
store.put = AsyncMock(side_effect=RuntimeError("db down"))
manager = RunManager(store=store)
with pytest.raises(RuntimeError, match="db down"):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
"""create() must also roll back when cancelled during the initial store write."""
store = MemoryRunStore()
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
manager = RunManager(store=store)
with pytest.raises(asyncio.CancelledError):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_does_not_expose_run_until_store_persist_completes():
"""Concurrent readers must wait until the new run has been persisted."""
store = MemoryRunStore()
manager = RunManager(store=store)
original_put = store.put
put_started = asyncio.Event()
allow_put = asyncio.Event()
async def blocking_put(run_id, **kwargs):
put_started.set()
await allow_put.wait()
return await original_put(run_id, **kwargs)
store.put = blocking_put
create_task = asyncio.create_task(manager.create("thread-1"))
list_task = None
try:
await put_started.wait()
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
await asyncio.sleep(0)
assert not list_task.done()
allow_put.set()
record = await create_task
runs = await list_task
assert [run.run_id for run in runs] == [record.run_id]
finally:
allow_put.set()
cleanup_tasks = []
for task in (list_task, create_task):
if task is None:
continue
if not task.done():
task.cancel()
cleanup_tasks.append(task)
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@pytest.mark.anyio
async def test_get_prefers_in_memory_record_over_store():
"""In-memory records retain task/control state when store has same run."""
store = MemoryRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await store.update_status(record.run_id, "success")
fetched = await manager.get(record.run_id)
assert fetched is record
assert fetched.status == RunStatus.pending
@pytest.mark.anyio
async def test_list_by_thread_merges_store_runs_newest_first():
"""list_by_thread should merge memory and store rows with memory precedence."""
store = MemoryRunStore()
await store.put("old-store", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:00+00:00")
await store.put("other-thread", thread_id="thread-2", status="success", created_at="2026-01-03T00:00:00+00:00")
manager = RunManager(store=store)
memory_record = await manager.create("thread-1")
runs = await manager.list_by_thread("thread-1")
assert [run.run_id for run in runs] == [memory_record.run_id, "old-store"]
assert runs[0] is memory_record
@pytest.mark.anyio
@@ -141,3 +584,290 @@ async def test_create_defaults(manager: RunManager):
assert record.kwargs == {}
assert record.multitask_strategy == "reject"
assert record.assistant_id is None
@pytest.mark.anyio
async def test_model_name_create_or_reject():
"""create_or_reject should accept and persist model_name."""
from deerflow.runtime.runs.schemas import DisconnectMode
store = MemoryRunStore()
mgr = RunManager(store=store)
record = await mgr.create_or_reject(
"thread-1",
assistant_id="lead_agent",
on_disconnect=DisconnectMode.cancel,
metadata={"key": "val"},
kwargs={"input": {}},
multitask_strategy="reject",
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
)
assert record.model_name == "anthropic.claude-sonnet-4-20250514-v1:0"
assert record.status == RunStatus.pending
# Verify model_name was persisted to store
stored = await store.get(record.run_id)
assert stored is not None
assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0"
# Verify retrieval returns the model_name via in-memory record
fetched = await mgr.get(record.run_id)
assert fetched is not None
assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0"
@pytest.mark.anyio
async def test_create_or_reject_interrupt_persists_interrupted_status_to_store():
"""interrupt strategy should persist interrupted status for old runs."""
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
new = await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert new.run_id != old.run_id
assert old.status == RunStatus.interrupted
assert stored_old is not None
assert stored_old["status"] == "interrupted"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
"""A failed new-run persist must not cancel the existing inflight run."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
store.put = AsyncMock(side_effect=RuntimeError("db down"))
with pytest.raises(RuntimeError, match="db down"):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
"""Cancellation during new-run persist must not cancel the existing run."""
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
with pytest.raises(asyncio.CancelledError):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
"""rollback strategy should persist interrupted status for old runs."""
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
new = await manager.create_or_reject("thread-1", multitask_strategy="rollback")
stored_old = await store.get(old.run_id)
assert new.run_id != old.run_id
assert old.status == RunStatus.interrupted
assert stored_old is not None
assert stored_old["status"] == "interrupted"
@pytest.mark.anyio
async def test_model_name_default_is_none():
"""create_or_reject without model_name should default to None."""
from deerflow.runtime.runs.schemas import DisconnectMode
store = MemoryRunStore()
mgr = RunManager(store=store)
record = await mgr.create_or_reject(
"thread-1",
on_disconnect=DisconnectMode.cancel,
model_name=None,
)
assert record.model_name is None
stored = await store.get(record.run_id)
assert stored["model_name"] is None
# ---------------------------------------------------------------------------
# Store fallback tests (simulates gateway restart scenario)
# ---------------------------------------------------------------------------
@pytest.fixture
def manager_with_store() -> RunManager:
"""RunManager backed by a MemoryRunStore."""
return RunManager(store=MemoryRunStore())
@pytest.mark.anyio
async def test_list_by_thread_returns_store_records_after_restart(manager_with_store: RunManager):
"""After in-memory state is cleared (simulating restart), list_by_thread
should still return runs from the persistent store."""
mgr = manager_with_store
r1 = await mgr.create("thread-1", "agent-1")
await mgr.set_status(r1.run_id, RunStatus.success)
r2 = await mgr.create("thread-1", "agent-2")
await mgr.set_status(r2.run_id, RunStatus.error, error="boom")
# Clear in-memory dict to simulate a restart
mgr._runs.clear()
runs = await mgr.list_by_thread("thread-1")
assert len(runs) == 2
statuses = {r.run_id: r.status for r in runs}
assert statuses[r1.run_id] == RunStatus.success
assert statuses[r2.run_id] == RunStatus.error
# Verify other fields survive the round-trip
for r in runs:
assert r.thread_id == "thread-1"
assert ISO_RE.match(r.created_at)
@pytest.mark.anyio
async def test_list_by_thread_merges_in_memory_and_store(manager_with_store: RunManager):
"""In-memory runs should be included alongside store-only records."""
mgr = manager_with_store
# Create a run and let it complete (will be in both memory and store)
r1 = await mgr.create("thread-1")
await mgr.set_status(r1.run_id, RunStatus.success)
# Simulate restart: clear memory, then create a new in-memory run
mgr._runs.clear()
r2 = await mgr.create("thread-1")
runs = await mgr.list_by_thread("thread-1")
assert len(runs) == 2
run_ids = {r.run_id for r in runs}
assert r1.run_id in run_ids
assert r2.run_id in run_ids
# r2 should be the in-memory record (has live state)
r2_record = next(r for r in runs if r.run_id == r2.run_id)
assert r2_record is r2 # same object reference
@pytest.mark.anyio
async def test_list_by_thread_no_store():
"""Without a store, list_by_thread should only return in-memory runs."""
mgr = RunManager()
await mgr.create("thread-1")
mgr._runs.clear()
runs = await mgr.list_by_thread("thread-1")
assert runs == []
@pytest.mark.anyio
async def test_aget_returns_in_memory_record(manager_with_store: RunManager):
"""aget should return the in-memory record when available."""
mgr = manager_with_store
r1 = await mgr.create("thread-1", "agent-1")
result = await mgr.aget(r1.run_id)
assert result is r1 # same object
@pytest.mark.anyio
async def test_aget_falls_back_to_store(manager_with_store: RunManager):
"""aget should return a record from the store when not in memory."""
mgr = manager_with_store
r1 = await mgr.create("thread-1", "agent-1")
await mgr.set_status(r1.run_id, RunStatus.success)
mgr._runs.clear()
result = await mgr.aget(r1.run_id)
assert result is not None
assert result.run_id == r1.run_id
assert result.status == RunStatus.success
assert result.thread_id == "thread-1"
assert result.assistant_id == "agent-1"
@pytest.mark.anyio
async def test_aget_falls_back_to_store_with_user_filter():
"""aget should honor user_id when reading store-only records."""
store = MemoryRunStore()
await store.put("run-1", thread_id="thread-1", user_id="user-1", status="success")
mgr = RunManager(store=store)
allowed = await mgr.aget("run-1", user_id="user-1")
denied = await mgr.aget("run-1", user_id="user-2")
assert allowed is not None
assert denied is None
@pytest.mark.anyio
async def test_aget_returns_none_for_unknown(manager_with_store: RunManager):
"""aget should return None for a run ID that doesn't exist anywhere."""
result = await manager_with_store.aget("nonexistent-run-id")
assert result is None
@pytest.mark.anyio
async def test_aget_store_failure_is_graceful():
"""If the store raises, aget should return None instead of propagating."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
store.get = AsyncMock(side_effect=RuntimeError("db down"))
mgr = RunManager(store=store)
result = await mgr.aget("some-id")
assert result is None
@pytest.mark.anyio
async def test_list_by_thread_store_failure_is_graceful():
"""If the store raises, list_by_thread should return only in-memory runs."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
store.list_by_thread = AsyncMock(side_effect=RuntimeError("db down"))
mgr = RunManager(store=store)
r1 = await mgr.create("thread-1")
runs = await mgr.list_by_thread("thread-1")
assert len(runs) == 1
assert runs[0].run_id == r1.run_id
@pytest.mark.anyio
async def test_list_by_thread_falls_back_to_store_with_user_filter():
"""list_by_thread should return only the requesting user's store records."""
store = MemoryRunStore()
await store.put("run-1", thread_id="thread-1", user_id="user-1", status="success")
await store.put("run-2", thread_id="thread-1", user_id="user-2", status="success")
mgr = RunManager(store=store)
runs = await mgr.list_by_thread("thread-1", user_id="user-1")
assert [r.run_id for r in runs] == ["run-1"]
+34
View File
@@ -0,0 +1,34 @@
from deerflow.runtime.runs.naming import resolve_root_run_name
def test_resolve_root_run_name_from_context_agent_name():
assert resolve_root_run_name({"context": {"agent_name": "finalis"}}, "lead_agent") == "finalis"
def test_resolve_root_run_name_from_configurable_agent_name():
assert resolve_root_run_name({"configurable": {"agent_name": "finalis"}}, "lead_agent") == "finalis"
def test_resolve_root_run_name_falls_back_to_assistant_id():
assert resolve_root_run_name({}, "my-agent") == "my-agent"
def test_resolve_root_run_name_falls_back_to_lead_agent():
assert resolve_root_run_name({}, None) == "lead_agent"
def test_resolve_root_run_name_prefers_context_over_configurable():
config = {
"context": {"agent_name": "ctx-agent"},
"configurable": {"agent_name": "cfg-agent"},
}
assert resolve_root_run_name(config, "lead_agent") == "ctx-agent"
def test_resolve_root_run_name_ignores_blank_agent_name():
assert resolve_root_run_name({"context": {"agent_name": " "}}, "my-agent") == "my-agent"
def test_resolve_root_run_name_ignores_non_string_agent_name():
assert resolve_root_run_name({"context": {"agent_name": None}}, "my-agent") == "my-agent"
+349 -2
View File
@@ -3,9 +3,14 @@
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import re
import pytest
from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path):
@@ -22,6 +27,45 @@ async def _cleanup():
await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def list_inflight(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository:
@pytest.mark.anyio
async def test_put_and_get(self, tmp_path):
@@ -34,6 +78,19 @@ class TestRunRepository:
assert row["status"] == "pending"
await _cleanup()
@pytest.mark.anyio
async def test_put_is_idempotent_for_retried_writes(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", assistant_id="old-agent", status="pending")
await repo.put("r1", thread_id="t1", assistant_id="new-agent", status="running", error="retry")
row = await repo.get("r1")
assert row["assistant_id"] == "new-agent"
assert row["status"] == "running"
assert row["error"] == "retry"
await _cleanup()
@pytest.mark.anyio
async def test_get_missing_returns_none(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -44,11 +101,19 @@ class TestRunRepository:
async def test_update_status(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
await repo.update_status("r1", "running")
updated = await repo.update_status("r1", "running")
row = await repo.get("r1")
assert updated is True
assert row["status"] == "running"
await _cleanup()
@pytest.mark.anyio
async def test_update_status_returns_false_for_missing_row(self, tmp_path):
repo = await _make_repo(tmp_path)
updated = await repo.update_status("missing", "error", error="lost")
assert updated is False
await _cleanup()
@pytest.mark.anyio
async def test_update_status_with_error(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -105,11 +170,24 @@ class TestRunRepository:
assert all(r["status"] == "pending" for r in pending)
await _cleanup()
@pytest.mark.anyio
async def test_list_inflight_returns_pending_and_running_before_cutoff(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("pending-old", thread_id="t1", status="pending", created_at="2026-01-01T00:00:00+00:00")
await repo.put("running-old", thread_id="t1", status="running", created_at="2026-01-01T00:00:01+00:00")
await repo.put("success-old", thread_id="t1", status="success", created_at="2026-01-01T00:00:02+00:00")
await repo.put("pending-new", thread_id="t1", status="pending", created_at="2026-01-01T00:00:03+00:00")
inflight = await repo.list_inflight(before="2026-01-01T00:00:02+00:00")
assert [row["run_id"] for row in inflight] == ["pending-old", "running-old"]
await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion(
updated = await repo.update_run_completion(
"r1",
status="success",
total_input_tokens=100,
@@ -124,6 +202,7 @@ class TestRunRepository:
first_human_message="What is the meaning?",
)
row = await repo.get("r1")
assert updated is True
assert row["status"] == "success"
assert row["total_tokens"] == 150
assert row["llm_call_count"] == 2
@@ -133,6 +212,13 @@ class TestRunRepository:
assert row["first_human_message"] == "What is the meaning?"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion_returns_false_for_missing_row(self, tmp_path):
repo = await _make_repo(tmp_path)
updated = await repo.update_run_completion("missing", status="error", total_tokens=1)
assert updated is False
await _cleanup()
@pytest.mark.anyio
async def test_metadata_preserved(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -166,6 +252,69 @@ class TestRunRepository:
assert row["total_tokens"] == 100
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
lead_agent_tokens=30,
subagent_tokens=20,
message_count=2,
)
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
row = await repo.get("r1")
assert row["total_input_tokens"] == 40
assert row["total_output_tokens"] == 10
assert row["total_tokens"] == 60
assert row["llm_call_count"] == 1
assert row["lead_agent_tokens"] == 30
assert row["subagent_tokens"] == 20
assert row["message_count"] == 2
assert row["last_ai_message"] == "updated"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 100
assert row["llm_call_count"] == 1
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -221,6 +370,28 @@ class TestRunRepository:
}
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
without_active = await repo.aggregate_tokens_by_thread("t1")
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
assert without_active["total_tokens"] == 100
assert without_active["total_runs"] == 1
assert with_active["total_tokens"] == 125
assert with_active["total_runs"] == 2
assert with_active["by_caller"] == {
"lead_agent": 120,
"subagent": 5,
"middleware": 0,
}
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first."""
@@ -249,3 +420,179 @@ class TestRunRepository:
rows = await repo.list_by_thread("t1", user_id=None)
assert len(rows) == 2
await _cleanup()
@pytest.mark.anyio
async def test_model_name_persistence(self, tmp_path):
"""RunRepository should persist, normalize, and truncate model_name correctly via SQL."""
from deerflow.persistence.engine import get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
repo = RunRepository(get_session_factory())
await repo.put("run-1", thread_id="thread-1", model_name="gpt-4o")
row = await repo.get("run-1")
assert row is not None
assert row["model_name"] == "gpt-4o"
long_name = "a" * 200
await repo.put("run-2", thread_id="thread-1", model_name=long_name)
row2 = await repo.get("run-2")
assert row2["model_name"] == "a" * 128
await repo.put("run-3", thread_id="thread-1", model_name=123)
row3 = await repo.get("run-3")
assert row3["model_name"] == "123"
await repo.put("run-4", thread_id="thread-1", model_name=None)
row4 = await repo.get("run-4")
assert row4["model_name"] is None
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
captured = []
class FakeResult:
def all(self):
return []
class FakeSession:
async def execute(self, stmt):
captured.append(stmt)
return FakeResult()
class FakeSessionContext:
async def __aenter__(self):
return FakeSession()
async def __aexit__(self, exc_type, exc, tb):
return None
repo = RunRepository(lambda: FakeSessionContext())
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg == {
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_runs": 0,
"by_model": {},
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
}
assert len(captured) == 1
stmt = captured[0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
assert select_match is not None
assert group_by_match is not None
assert select_match.group(1) == group_by_match.group(1)
@pytest.mark.anyio
async def test_run_manager_hydrates_store_only_run_from_sql(self, tmp_path):
"""RunManager should hydrate historical runs from SQL-backed store."""
repo = await _make_repo(tmp_path)
await repo.put(
"sql-store-only",
thread_id="thread-1",
assistant_id="lead_agent",
status="success",
metadata={"source": "sql"},
kwargs={"input": "value"},
model_name="model-a",
)
manager = RunManager(store=repo)
record = await manager.get("sql-store-only")
rows = await manager.list_by_thread("thread-1")
assert record is not None
assert record.run_id == "sql-store-only"
assert record.status == RunStatus.success
assert record.metadata == {"source": "sql"}
assert record.kwargs == {"input": "value"}
assert record.model_name == "model-a"
assert [run.run_id for run in rows] == ["sql-store-only"]
await _cleanup()
@pytest.mark.anyio
async def test_run_manager_cancel_persists_interrupted_status_to_sql(self, tmp_path):
"""RunManager.cancel should write interrupted status to SQL-backed store."""
repo = await _make_repo(tmp_path)
manager = RunManager(store=repo)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
cancelled = await manager.cancel(record.run_id)
row = await repo.get(record.run_id)
assert cancelled is True
assert row is not None
assert row["status"] == "interrupted"
await _cleanup()
@pytest.mark.anyio
async def test_update_model_name(self, tmp_path):
"""RunRepository.update_model_name should update model_name for existing run."""
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", model_name="initial-model")
await repo.update_model_name("r1", "updated-model")
row = await repo.get("r1")
assert row["model_name"] == "updated-model"
await _cleanup()
@pytest.mark.anyio
async def test_update_model_name_normalizes_value(self, tmp_path):
"""RunRepository.update_model_name should normalize and truncate model_name."""
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
long_name = "a" * 200
await repo.update_model_name("r1", long_name)
row = await repo.get("r1")
assert row["model_name"] == "a" * 128
await _cleanup()
@pytest.mark.anyio
async def test_update_model_name_to_none(self, tmp_path):
"""RunRepository.update_model_name should allow setting model_name to None."""
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", model_name="initial-model")
await repo.update_model_name("r1", None)
row = await repo.get("r1")
assert row["model_name"] is None
await _cleanup()
@pytest.mark.anyio
async def test_run_manager_update_model_name_persists_to_sql(self, tmp_path):
"""RunManager.update_model_name should persist to SQL-backed store without integrity error."""
repo = await _make_repo(tmp_path)
manager = RunManager(store=repo)
record = await manager.create("thread-1")
await manager.update_model_name(record.run_id, "gpt-4o")
row = await repo.get(record.run_id)
assert row is not None
assert row["model_name"] == "gpt-4o"
await _cleanup()
@pytest.mark.anyio
async def test_run_manager_update_model_name_twice(self, tmp_path):
"""RunManager.update_model_name should support multiple updates."""
repo = await _make_repo(tmp_path)
manager = RunManager(store=repo)
record = await manager.create("thread-1")
await manager.update_model_name(record.run_id, "model-1")
await manager.update_model_name(record.run_id, "model-2")
row = await repo.get(record.run_id)
assert row["model_name"] == "model-2"
await _cleanup()
+105 -1
View File
@@ -88,11 +88,115 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
assert captured["factory_context"]["app_config"] is app_config
assert captured["astream_context"]["app_config"] is app_config
assert run_manager.get(record.run_id).status == RunStatus.success
fetched = await run_manager.get(record.run_id)
assert fetched is not None
assert fetched.status == RunStatus.success
bridge.publish_end.assert_awaited_once_with(record.run_id)
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
@pytest.mark.anyio
async def test_run_agent_defaults_root_run_name_from_assistant_id():
run_manager = RunManager()
record = await run_manager.create("thread-1", assistant_id="lead_agent")
bridge = SimpleNamespace(
publish=AsyncMock(),
publish_end=AsyncMock(),
cleanup=AsyncMock(),
)
captured: dict[str, object] = {}
class DummyAgent:
async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
captured["astream_run_name"] = config["run_name"]
yield {"messages": []}
def factory(*, config):
captured["factory_run_name"] = config["run_name"]
return DummyAgent()
await run_agent(
bridge,
run_manager,
record,
ctx=RunContext(checkpointer=None),
agent_factory=factory,
graph_input={},
config={},
)
assert captured["factory_run_name"] == "lead_agent"
assert captured["astream_run_name"] == "lead_agent"
@pytest.mark.anyio
async def test_run_agent_defaults_root_run_name_from_context_agent_name():
run_manager = RunManager()
record = await run_manager.create("thread-1", assistant_id="lead_agent")
bridge = SimpleNamespace(
publish=AsyncMock(),
publish_end=AsyncMock(),
cleanup=AsyncMock(),
)
captured: dict[str, object] = {}
class DummyAgent:
async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
captured["astream_run_name"] = config["run_name"]
yield {"messages": []}
def factory(*, config):
captured["factory_run_name"] = config["run_name"]
return DummyAgent()
await run_agent(
bridge,
run_manager,
record,
ctx=RunContext(checkpointer=None),
agent_factory=factory,
graph_input={},
config={"context": {"agent_name": "finalis"}},
)
assert captured["factory_run_name"] == "finalis"
assert captured["astream_run_name"] == "finalis"
@pytest.mark.anyio
async def test_run_agent_defaults_root_run_name_from_configurable_agent_name():
run_manager = RunManager()
record = await run_manager.create("thread-1", assistant_id="lead_agent")
bridge = SimpleNamespace(
publish=AsyncMock(),
publish_end=AsyncMock(),
cleanup=AsyncMock(),
)
captured: dict[str, object] = {}
class DummyAgent:
async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
captured["astream_run_name"] = config["run_name"]
yield {"messages": []}
def factory(*, config):
captured["factory_run_name"] = config["run_name"]
return DummyAgent()
await run_agent(
bridge,
run_manager,
record,
ctx=RunContext(checkpointer=None),
agent_factory=factory,
graph_input={},
config={"configurable": {"agent_name": "finalis"}},
)
assert captured["factory_run_name"] == "finalis"
assert captured["astream_run_name"] == "finalis"
@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
+686
View File
@@ -0,0 +1,686 @@
"""HTTP/runtime lifecycle E2E tests for the Gateway-owned runs API.
These tests keep the external model out of scope while exercising the real
FastAPI app, auth middleware, lifespan-created runtime dependencies,
``start_run()``, ``run_agent()``, StreamBridge, checkpointer, run store, and
thread metadata store.
"""
from __future__ import annotations
import asyncio
import inspect
import json
import queue
import threading
import time
import uuid
from contextlib import suppress
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model
from langchain_core.messages import AIMessage, HumanMessage
pytestmark = pytest.mark.no_auto_user
_MINIMAL_CONFIG_YAML = """\
log_level: info
models:
- name: fake-test-model
display_name: Fake Test Model
use: langchain_openai:ChatOpenAI
model: gpt-4o-mini
api_key: $OPENAI_API_KEY
base_url: $OPENAI_API_BASE
sandbox:
use: deerflow.sandbox.local:LocalSandboxProvider
agents_api:
enabled: true
title:
enabled: false
memory:
enabled: false
database:
backend: sqlite
run_events:
backend: memory
"""
class _RunController:
"""Cross-thread controls for the fake async agent."""
def __init__(self) -> None:
self.started = threading.Event()
self.checkpoint_written = threading.Event()
self.cancelled = threading.Event()
self.release = threading.Event()
self.instances: list[_ScriptedAgent] = []
class _ScriptedAgent:
"""Deterministic runtime double for lifecycle-only tests.
This is intentionally not a full LangGraph graph. Tests that need
controllable blocking, cancellation, and rollback checkpoints use the small
``run_agent`` surface they exercise: ``astream()``, checkpointer/store
attachment, metadata, and interrupt node attributes. The real lead-agent
graph/tool dispatch path is covered separately by
``test_stream_run_executes_real_lead_agent_setup_agent_business_path``.
"""
def __init__(
self,
controller: _RunController,
*,
title: str,
answer: str,
block_after_first_chunk: bool = False,
) -> None:
self.controller = controller
self.title = title
self.answer = answer
self.block_after_first_chunk = block_after_first_chunk
self.checkpointer: Any | None = None
self.store: Any | None = None
self.metadata = {"model_name": "fake-test-model"}
self.interrupt_before_nodes = None
self.interrupt_after_nodes = None
self.model = FakeToolCallingModel(responses=[AIMessage(content=self.answer)])
async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
del subgraphs
self.controller.started.set()
thread_id = _thread_id_from_config(config)
human_text = _last_human_text(graph_input)
human = HumanMessage(content=human_text)
ai = await self.model.ainvoke([human], config=config)
state = {"messages": [human.model_dump(), ai.model_dump()], "title": self.title}
if self.checkpointer is not None:
await _write_checkpoint(self.checkpointer, thread_id=thread_id, state=state)
self.controller.checkpoint_written.set()
yield _stream_item_for_mode(stream_mode, state)
if self.block_after_first_chunk:
try:
while not self.controller.release.is_set():
await asyncio.sleep(0.05)
except asyncio.CancelledError:
self.controller.cancelled.set()
raise
def _make_agent_factory(controller: _RunController, **agent_kwargs):
def factory(*, config):
del config
agent = _ScriptedAgent(controller, **agent_kwargs)
controller.instances.append(agent)
return agent
return factory
def _build_fake_setup_agent_model(agent_name: str):
"""Patch target for lead_agent.agent.create_chat_model.
The graph, tool registry, ToolNode dispatch, and setup_agent implementation
remain production code; this fake only replaces the external LLM call.
"""
def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel:
del args, kwargs
return build_single_tool_call_model(
tool_name="setup_agent",
tool_args={
"soul": f"# Runtime Business E2E\n\nAgent name: {agent_name}",
"description": "runtime lifecycle business path",
},
tool_call_id="call_runtime_business_1",
final_text=f"Created {agent_name} through the real setup_agent tool.",
)
return fake_create_chat_model
@pytest.fixture
def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
home = tmp_path / "deer-flow-home"
home.mkdir()
monkeypatch.setenv("DEER_FLOW_HOME", str(home))
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
staged_config = tmp_path / "config.yaml"
staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config))
staged_extensions_config = tmp_path / "extensions_config.json"
staged_extensions_config.write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(staged_extensions_config))
return home
def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear runtime singletons that depend on this test's temporary config.
The Gateway app/lifespan path reads process-wide caches before wiring
request-scoped dependencies. These E2E tests stage a temporary
``config.yaml``/``extensions_config.json`` and ``DEER_FLOW_HOME``, so the
caches below must be reset before app creation:
- app_config / extensions_config: parsed config file caches.
- paths: ``DEER_FLOW_HOME``-derived filesystem paths.
- persistence.engine: SQLAlchemy engine/session factory for the sqlite dir.
- app.gateway.deps: cached local auth provider/repository.
A shared public reset helper would be cleaner long-term; this test keeps
the reset boundary explicit because the PR is focused on runtime lifecycle
coverage rather than config-cache API cleanup.
"""
from app.gateway import deps as deps_module
from deerflow.config import app_config as app_config_module
from deerflow.config import extensions_config as extensions_config_module
from deerflow.config import paths as paths_module
from deerflow.persistence import engine as engine_module
for module, attr, value in (
(app_config_module, "_app_config", None),
(app_config_module, "_app_config_path", None),
(app_config_module, "_app_config_mtime", None),
(app_config_module, "_app_config_is_custom", False),
(extensions_config_module, "_extensions_config", None),
(paths_module, "_paths_singleton", None),
(paths_module, "_paths", None),
(engine_module, "_engine", None),
(engine_module, "_session_factory", None),
(deps_module, "_cached_local_provider", None),
(deps_module, "_cached_repo", None),
):
monkeypatch.setattr(module, attr, value, raising=False)
def _preserve_process_config_singletons(monkeypatch: pytest.MonkeyPatch) -> None:
"""Restore config singletons mutated as a side effect of AppConfig loading.
``AppConfig.from_file()`` calls ``_apply_singleton_configs()``, which pushes
nested config sections into module-level caches used by middlewares, tool
selection, and runtime providers. Snapshotting those attributes with
``monkeypatch`` lets pytest restore the pre-test values during teardown, so
loading the isolated test config does not leak into later tests.
"""
from deerflow.config import (
acp_config,
agents_api_config,
checkpointer_config,
guardrails_config,
memory_config,
stream_bridge_config,
subagents_config,
summarization_config,
title_config,
tool_search_config,
)
for module, attr in (
(title_config, "_title_config"),
(summarization_config, "_summarization_config"),
(memory_config, "_memory_config"),
(agents_api_config, "_agents_api_config"),
(subagents_config, "_subagents_config"),
(tool_search_config, "_tool_search_config"),
(guardrails_config, "_guardrails_config"),
(checkpointer_config, "_checkpointer_config"),
(stream_bridge_config, "_stream_bridge_config"),
(acp_config, "_acp_agents"),
):
monkeypatch.setattr(module, attr, getattr(module, attr), raising=False)
@pytest.fixture
def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch):
_preserve_process_config_singletons(monkeypatch)
_reset_process_singletons(monkeypatch)
from deerflow.config import app_config as app_config_module
cfg = app_config_module.get_app_config()
cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db")
from app.gateway.app import create_app
return create_app()
def _register_user(client, *, email: str = "runtime-e2e@example.com") -> str:
response = client.post(
"/api/v1/auth/register",
json={"email": email, "password": "very-strong-password-123"},
)
assert response.status_code == 201, response.text
csrf_token = client.cookies.get("csrf_token")
assert csrf_token
return csrf_token
def _create_thread(client, csrf_token: str) -> str:
thread_id = str(uuid.uuid4())
response = client.post(
"/api/threads",
json={"thread_id": thread_id, "metadata": {"purpose": "runtime-lifecycle-e2e"}},
headers={"X-CSRF-Token": csrf_token},
)
assert response.status_code == 200, response.text
return thread_id
def _run_body(**overrides) -> dict[str, Any]:
body: dict[str, Any] = {
"assistant_id": "lead_agent",
"input": {"messages": [{"role": "user", "content": "Run lifecycle E2E prompt"}]},
"config": {"recursion_limit": 50},
"stream_mode": ["values"],
}
body.update(overrides)
return body
def _drain_stream(response, *, timeout: float = 10.0, max_bytes: int = 1024 * 1024) -> str:
chunks: queue.Queue[bytes | BaseException | object] = queue.Queue()
sentinel = object()
def read_stream() -> None:
try:
for chunk in response.iter_bytes():
chunks.put(chunk)
if b"event: end" in chunk:
break
except BaseException as exc: # pragma: no cover - reported in the main test thread
chunks.put(exc)
finally:
chunks.put(sentinel)
reader = threading.Thread(target=read_stream, daemon=True)
reader.start()
deadline = time.monotonic() + timeout
body = b""
while True:
remaining = deadline - time.monotonic()
if remaining <= 0:
raise AssertionError(f"SSE stream did not finish within {timeout}s; transcript tail={body[-4000:].decode('utf-8', errors='replace')}")
try:
chunk = chunks.get(timeout=remaining)
except queue.Empty as exc:
raise AssertionError(f"SSE stream did not produce data within {timeout}s; transcript tail={body[-4000:].decode('utf-8', errors='replace')}") from exc
if chunk is sentinel:
break
if isinstance(chunk, BaseException):
raise AssertionError("SSE reader failed") from chunk
body += chunk
if b"event: end" in body:
break
if len(body) >= max_bytes:
raise AssertionError(f"SSE stream exceeded {max_bytes} bytes without event: end")
if b"event: end" not in body:
raise AssertionError(f"SSE stream closed before event: end; transcript tail={body[-4000:].decode('utf-8', errors='replace')}")
return body.decode("utf-8", errors="replace")
def _parse_sse(transcript: str) -> list[dict[str, Any]]:
events: list[dict[str, Any]] = []
for raw_frame in transcript.split("\n\n"):
frame = raw_frame.strip()
if not frame or frame.startswith(":"):
continue
parsed: dict[str, Any] = {}
for line in frame.splitlines():
if line.startswith("event: "):
parsed["event"] = line.removeprefix("event: ")
elif line.startswith("data: "):
payload = line.removeprefix("data: ")
parsed["data"] = json.loads(payload)
elif line.startswith("id: "):
parsed["id"] = line.removeprefix("id: ")
if parsed:
events.append(parsed)
return events
def _run_id_from_response(response) -> str:
location = response.headers.get("content-location", "")
assert location, "run stream response must include Content-Location"
return location.rstrip("/").split("/")[-1]
def _wait_for_status(client, thread_id: str, run_id: str, status: str, *, timeout: float = 5.0) -> dict:
deadline = time.monotonic() + timeout
last: dict | None = None
while time.monotonic() < deadline:
response = client.get(f"/api/threads/{thread_id}/runs/{run_id}")
assert response.status_code == 200, response.text
last = response.json()
if last["status"] == status:
return last
time.sleep(0.05)
raise AssertionError(f"Run {run_id} did not reach {status!r}; last={last!r}")
def _thread_id_from_config(config: dict | None) -> str:
config = config or {}
context = config.get("context") if isinstance(config.get("context"), dict) else {}
configurable = config.get("configurable") if isinstance(config.get("configurable"), dict) else {}
thread_id = context.get("thread_id") or configurable.get("thread_id")
assert thread_id, f"runtime config did not contain thread_id: {config!r}"
return str(thread_id)
def _last_human_text(graph_input: dict) -> str:
messages = graph_input.get("messages") or []
if not messages:
return ""
last = messages[-1]
content = getattr(last, "content", last)
if isinstance(content, str):
return content
return str(content)
async def _write_checkpoint(checkpointer: Any, *, thread_id: str, state: dict[str, Any]) -> None:
from langgraph.checkpoint.base import empty_checkpoint
checkpoint = empty_checkpoint()
checkpoint["channel_values"] = dict(state)
checkpoint["channel_versions"] = {key: 1 for key in state}
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
metadata = {
"source": "loop",
"step": 1,
"writes": {"scripted_agent": {"title": state.get("title"), "message_count": len(state.get("messages", []))}},
"parents": {},
}
result = checkpointer.aput(config, checkpoint, metadata, {})
if inspect.isawaitable(result):
await result
def _stream_item_for_mode(stream_mode: Any, state: dict[str, Any]) -> Any:
if isinstance(stream_mode, list):
# ``run_agent`` passes a list when multiple modes/subgraphs are active.
return stream_mode[0], state
return state
def test_stream_run_completes_and_persists_runtime_state(isolated_app):
"""A streaming run should traverse the real runtime and leave state behind."""
from starlette.testclient import TestClient
controller = _RunController()
factory = _make_agent_factory(
controller,
title="Lifecycle E2E",
answer="Lifecycle complete.",
)
with (
patch("app.gateway.services.resolve_agent_factory", return_value=factory),
TestClient(isolated_app) as client,
):
csrf_token = _register_user(client)
thread_id = _create_thread(client, csrf_token)
with client.stream(
"POST",
f"/api/threads/{thread_id}/runs/stream",
json=_run_body(),
headers={"X-CSRF-Token": csrf_token},
) as response:
assert response.status_code == 200, response.read().decode()
run_id = _run_id_from_response(response)
transcript = _drain_stream(response)
events = _parse_sse(transcript)
assert [event["event"] for event in events] == ["metadata", "values", "end"]
assert events[0]["data"] == {"run_id": run_id, "thread_id": thread_id}
assert events[1]["data"]["title"] == "Lifecycle E2E"
assert events[1]["data"]["messages"][-1]["content"] == "Lifecycle complete."
run = client.get(f"/api/threads/{thread_id}/runs/{run_id}")
assert run.status_code == 200, run.text
assert run.json()["status"] == "success"
thread = client.get(f"/api/threads/{thread_id}")
assert thread.status_code == 200, thread.text
assert thread.json()["status"] == "idle"
assert thread.json()["values"]["title"] == "Lifecycle E2E"
messages = client.get(f"/api/threads/{thread_id}/runs/{run_id}/messages")
assert messages.status_code == 200, messages.text
message_events = messages.json()["data"]
event_types = [row["event_type"] for row in message_events]
assert "llm.human.input" in event_types
assert "llm.ai.response" in event_types
assert any(row["content"]["content"] == "Run lifecycle E2E prompt" for row in message_events if row["event_type"] == "llm.human.input")
assert any(row["content"]["content"] == "Lifecycle complete." for row in message_events if row["event_type"] == "llm.ai.response")
def test_stream_run_executes_real_lead_agent_setup_agent_business_path(isolated_app, isolated_deer_flow_home: Path):
"""A runtime stream should execute real lead-agent business code and tools."""
from starlette.testclient import TestClient
agent_name = "runtime-business-agent"
with (
patch(
"deerflow.agents.lead_agent.agent.create_chat_model",
new=_build_fake_setup_agent_model(agent_name),
),
TestClient(isolated_app) as client,
):
csrf_token = _register_user(client, email="business-e2e@example.com")
auth_user_id = client.get("/api/v1/auth/me").json()["id"]
thread_id = _create_thread(client, csrf_token)
body = _run_body(
input={
"messages": [
{
"role": "user",
"content": f"Create a custom agent named {agent_name}.",
}
]
},
context={
"agent_name": agent_name,
"is_bootstrap": True,
"thinking_enabled": False,
"is_plan_mode": False,
"subagent_enabled": False,
},
)
with client.stream(
"POST",
f"/api/threads/{thread_id}/runs/stream",
json=body,
headers={"X-CSRF-Token": csrf_token},
) as response:
assert response.status_code == 200, response.read().decode()
run_id = _run_id_from_response(response)
transcript = _drain_stream(response, timeout=20.0)
events = _parse_sse(transcript)
event_names = [event["event"] for event in events]
assert "metadata" in event_names
assert "error" not in event_names, transcript
assert event_names[-1] == "end"
run = _wait_for_status(client, thread_id, run_id, "success", timeout=10.0)
assert run["assistant_id"] == "lead_agent"
expected_soul = isolated_deer_flow_home / "users" / auth_user_id / "agents" / agent_name / "SOUL.md"
assert expected_soul.exists(), f"setup_agent did not write SOUL.md. tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}"
assert f"Agent name: {agent_name}" in expected_soul.read_text(encoding="utf-8")
assert not (isolated_deer_flow_home / "users" / "default" / "agents" / agent_name).exists()
def test_cancel_interrupt_stops_running_background_run(isolated_app):
"""HTTP cancel?action=interrupt should stop the worker and persist interruption."""
from starlette.testclient import TestClient
controller = _RunController()
factory = _make_agent_factory(
controller,
title="Interrupt candidate",
answer="This run should be interrupted.",
block_after_first_chunk=True,
)
with (
patch("app.gateway.services.resolve_agent_factory", return_value=factory),
TestClient(isolated_app) as client,
):
csrf_token = _register_user(client, email="interrupt-e2e@example.com")
thread_id = _create_thread(client, csrf_token)
created = client.post(
f"/api/threads/{thread_id}/runs",
json=_run_body(),
headers={"X-CSRF-Token": csrf_token},
)
assert created.status_code == 200, created.text
run_id = created.json()["run_id"]
assert controller.started.wait(5), "fake agent never started"
cancelled = client.post(
f"/api/threads/{thread_id}/runs/{run_id}/cancel?wait=true&action=interrupt",
headers={"X-CSRF-Token": csrf_token},
)
assert cancelled.status_code == 204, cancelled.text
assert controller.cancelled.wait(5), "fake agent task was not cancelled"
run = _wait_for_status(client, thread_id, run_id, "interrupted")
assert run["status"] == "interrupted"
thread = client.get(f"/api/threads/{thread_id}")
assert thread.status_code == 200, thread.text
assert thread.json()["status"] == "idle"
@pytest.mark.anyio
async def test_sse_consumer_disconnect_cancels_inflight_run():
"""A disconnected SSE request should cancel an in-flight run when configured."""
from app.gateway.services import sse_consumer
from deerflow.runtime import DisconnectMode, MemoryStreamBridge, RunManager, RunStatus
bridge = MemoryStreamBridge()
run_manager = RunManager()
record = await run_manager.create("thread-disconnect", on_disconnect=DisconnectMode.cancel)
await run_manager.set_status(record.run_id, RunStatus.running)
await bridge.publish(record.run_id, "metadata", {"run_id": record.run_id, "thread_id": record.thread_id})
worker_started = asyncio.Event()
worker_cancelled = asyncio.Event()
async def _pending_worker() -> None:
try:
worker_started.set()
await asyncio.Event().wait()
except asyncio.CancelledError:
worker_cancelled.set()
raise
record.task = asyncio.create_task(_pending_worker())
await asyncio.wait_for(worker_started.wait(), timeout=1.0)
class _DisconnectedRequest:
headers: dict[str, str] = {}
async def is_disconnected(self) -> bool:
return True
try:
frames = []
async for frame in sse_consumer(bridge, record, _DisconnectedRequest(), run_manager):
frames.append(frame)
assert frames == []
assert record.abort_event.is_set()
assert record.status == RunStatus.interrupted
await asyncio.wait_for(worker_cancelled.wait(), timeout=1.0)
assert record.task.cancelled()
finally:
if record.task is not None and not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
def test_cancel_rollback_restores_pre_run_checkpoint(isolated_app):
"""HTTP cancel?action=rollback should restore the checkpoint captured before run start."""
from starlette.testclient import TestClient
controller = _RunController()
factory = _make_agent_factory(
controller,
title="During rollback run",
answer="This answer should be rolled back.",
block_after_first_chunk=True,
)
with (
patch("app.gateway.services.resolve_agent_factory", return_value=factory),
TestClient(isolated_app) as client,
):
csrf_token = _register_user(client, email="rollback-e2e@example.com")
thread_id = _create_thread(client, csrf_token)
before = client.post(
f"/api/threads/{thread_id}/state",
json={
"values": {
"title": "Before rollback",
"messages": [{"type": "human", "content": "before"}],
},
"as_node": "test_seed",
},
headers={"X-CSRF-Token": csrf_token},
)
assert before.status_code == 200, before.text
assert before.json()["values"]["title"] == "Before rollback"
created = client.post(
f"/api/threads/{thread_id}/runs",
json=_run_body(),
headers={"X-CSRF-Token": csrf_token},
)
assert created.status_code == 200, created.text
run_id = created.json()["run_id"]
assert controller.checkpoint_written.wait(5), "fake agent did not write in-run checkpoint"
during = client.get(f"/api/threads/{thread_id}/state")
assert during.status_code == 200, during.text
assert during.json()["values"]["title"] == "During rollback run"
rolled_back = client.post(
f"/api/threads/{thread_id}/runs/{run_id}/cancel?wait=true&action=rollback",
headers={"X-CSRF-Token": csrf_token},
)
assert rolled_back.status_code == 204, rolled_back.text
assert controller.cancelled.wait(5), "rollback did not cancel the worker task"
run = _wait_for_status(client, thread_id, run_id, "error")
assert run["status"] == "error"
after = client.get(f"/api/threads/{thread_id}/state")
assert after.status_code == 200, after.text
assert after.json()["values"]["title"] == "Before rollback"
assert after.json()["values"]["messages"] == [{"type": "human", "content": "before"}]
@@ -0,0 +1,225 @@
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
Unit tests prove ``_apply`` does the right thing on a synthetic state.
This test does one level up: builds a real ``langchain.agents.create_agent``
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
1. The tool node is **not** invoked (the dangerous truncated tool call
is suppressed).
2. The final AIMessage in graph state has ``tool_calls == []``.
3. The observability ``safety_termination`` record is attached.
4. The user-facing explanation is appended to the message content.
This is the closest we can get to the issue's failure mode without a live
Moonshot key, and it proves the middleware actually gates LangChain's
tool router not just rewrites state in isolation.
"""
from __future__ import annotations
from typing import Any
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
@tool
def write_file(path: str, content: str) -> str:
"""Pretend to write *content* to *path*. Records the call for assertion."""
_TOOL_INVOCATIONS.append({"path": path, "content": content})
return f"wrote {len(content)} bytes to {path}"
class _ContentFilteredModel(BaseChatModel):
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
First call returns finish_reason='content_filter' + a tool_call whose
arguments are visibly truncated. Second call (if reached) returns a
normal text completion so the agent can terminate cleanly.
"""
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-content-filtered"
def bind_tools(self, tools, **kwargs):
# create_agent binds tools onto the model; we don't actually need
# to bind anything since responses are hard-coded, but the method
# must not raise.
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
message = AIMessage(
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
tool_calls=[
{
"id": "call_truncated_1",
"name": "write_file",
"args": {
"path": "/mnt/user-data/outputs/report.md",
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
},
}
],
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
)
else:
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
class _InspectMiddleware(AgentMiddleware):
"""Captures the messages list at every model entry so we can assert
no synthetic tool result was injected back into the conversation."""
def __init__(self) -> None:
super().__init__()
self.observed: list[list[Any]] = []
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
self.observed.append(list(request.messages))
return handler(request)
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
_TOOL_INVOCATIONS.clear()
inspector = _InspectMiddleware()
agent = create_agent(
model=_ContentFilteredModel(),
tools=[write_file],
# Inspector first so its after_model is registered; Safety last in
# the list so it executes first under LIFO (matches production wiring).
middleware=[inspector, SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
# Critical assertion: the dangerous truncated tool call must NOT have
# been executed. This is the entire point of the middleware.
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
# Final AIMessage has no tool calls left.
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
assert final_ai.tool_calls == []
# Observability stamp is present.
record = final_ai.additional_kwargs.get("safety_termination")
assert record is not None
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 1
assert record["suppressed_tool_call_names"] == ["write_file"]
# User-facing explanation is appended.
assert "safety-related signal" in final_ai.content
# Original partial text preserved (we don't throw away what the user
# already saw in the stream — see middleware docstring).
assert "Weekly Politics" in final_ai.content
# finish_reason on response_metadata is preserved (so SSE / converters
# downstream still see the real provider reason).
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
def test_content_filter_without_tool_calls_passes_through_unchanged():
"""No tool calls => issue scope says don't intervene; the partial
response should be delivered as-is so the user sees what they got."""
_TOOL_INVOCATIONS.clear()
class _NoToolModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-no-tool"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
msg = AIMessage(
content="Partial answer truncated by safety filter",
response_metadata={"finish_reason": "content_filter"},
)
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_NoToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
# Content untouched.
assert final_ai.content == "Partial answer truncated by safety filter"
# No safety_termination stamp because we didn't intervene.
assert "safety_termination" not in final_ai.additional_kwargs
# tool node never ran (there were no tool calls in the first place).
assert _TOOL_INVOCATIONS == []
def test_normal_tool_call_round_trip_is_not_affected():
"""Regression: a healthy finish_reason='tool_calls' response must still
execute the tool. The middleware must not over-fire."""
_TOOL_INVOCATIONS.clear()
class _HealthyToolModel(BaseChatModel):
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-healthy"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
msg = AIMessage(
content="",
tool_calls=[
{
"id": "call_ok",
"name": "write_file",
"args": {"path": "/tmp/ok", "content": "complete content"},
}
],
response_metadata={"finish_reason": "tool_calls"},
)
else:
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_HealthyToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
agent.invoke({"messages": [HumanMessage(content="write")]})
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
@@ -0,0 +1,651 @@
"""Unit tests for SafetyFinishReasonMiddleware."""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
)
from deerflow.config.safety_finish_reason_config import (
SafetyDetectorConfig,
SafetyFinishReasonConfig,
)
def _runtime(thread_id="t-1"):
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
return runtime
def _ai(
*,
content="",
tool_calls=None,
response_metadata=None,
additional_kwargs=None,
):
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
def _write_call(idx=1, content_text="半截"):
return {
"id": f"call_write_{idx}",
"name": "write_file",
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
}
class AlwaysHitDetector:
"""Test fixture: always reports the given termination."""
name = "always_hit"
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
self.reason_field = reason_field
self.reason_value = reason_value
self.extras = extras or {}
def detect(self, message):
return SafetyTermination(
detector=self.name,
reason_field=self.reason_field,
reason_value=self.reason_value,
extras=self.extras,
)
class NeverHitDetector:
name = "never_hit"
def detect(self, message):
return None
class RaisingDetector:
name = "raising"
def detect(self, message):
raise RuntimeError("boom")
# ---------------------------------------------------------------------------
# Core trigger behaviour
# ---------------------------------------------------------------------------
class TestTriggerCriteria:
def test_content_filter_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
patched = result["messages"][0]
assert patched.tool_calls == []
def test_content_filter_without_tool_calls_passes_through(self):
"""issue scope: when there are no tool calls the partial text is a
legitimate final response and should not be rewritten."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial response",
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_tool_calls_pass_through(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_stop_with_tool_calls_pass_through(self):
# Some providers report finish_reason='stop' for tool-call messages.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "stop"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_empty_message_list_passes_through(self):
mw = SafetyFinishReasonMiddleware()
assert mw._apply({"messages": []}, _runtime()) is None
def test_non_ai_last_message_passes_through(self):
mw = SafetyFinishReasonMiddleware()
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
assert mw._apply(state, _runtime()) is None
def test_anthropic_refusal_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"stop_reason": "refusal"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
def test_gemini_safety_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "SAFETY"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
# ---------------------------------------------------------------------------
# Message rewriting
# ---------------------------------------------------------------------------
class TestMessageRewrite:
def test_clears_structured_tool_calls(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert patched.tool_calls == []
def test_clears_raw_additional_kwargs_tool_calls(self):
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
tool calls from additional_kwargs.tool_calls if we forget them, which
would re-emit a synthetic ToolMessage downstream and confuse the
model. We must wipe both."""
mw = SafetyFinishReasonMiddleware()
raw_tool_calls = [
{
"id": "call_write_1",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
}
]
state = {
"messages": [
_ai(
tool_calls=[_write_call(1)],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"tool_calls": raw_tool_calls,
"function_call": {"name": "write_file", "arguments": "{}"},
},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert "tool_calls" not in patched.additional_kwargs
assert "function_call" not in patched.additional_kwargs
def test_preserves_other_additional_kwargs(self):
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
# may carry other provider-specific keys. They must not be wiped.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"reasoning": "thinking text",
"custom_provider_field": {"x": 1},
},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["reasoning"] == "thinking text"
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
def test_writes_observability_field(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
record = patched.additional_kwargs["safety_termination"]
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 2
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
def test_preserves_response_metadata_finish_reason(self):
"""Downstream SSE converters read response_metadata.finish_reason —
we want them to see the *real* provider reason, not 'stop'."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.response_metadata["finish_reason"] == "content_filter"
assert patched.response_metadata["model_name"] == "kimi-k2"
def test_appends_user_facing_explanation_to_str_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="some partial text",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert patched.content.startswith("some partial text")
assert "safety-related signal" in patched.content
def test_handles_empty_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert "safety-related signal" in patched.content
def test_handles_list_content_thinking_blocks(self):
"""Anthropic thinking / vLLM reasoning models emit content blocks.
Naively concatenating a string would raise TypeError."""
mw = SafetyFinishReasonMiddleware()
thinking_blocks = [
{"type": "thinking", "text": "let me consider..."},
{"type": "text", "text": "partial answer"},
]
state = {
"messages": [
_ai(
content=thinking_blocks,
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, list)
assert patched.content[:2] == thinking_blocks
assert patched.content[-1]["type"] == "text"
assert "safety-related signal" in patched.content[-1]["text"]
def test_idempotent_on_already_cleared_message(self):
# Re-running the middleware on a message we already cleared must not
# re-trigger (tool_calls is now empty → fast passthrough).
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
first = mw._apply(state, _runtime())
state2 = {"messages": [first["messages"][0]]}
second = mw._apply(state2, _runtime())
assert second is None
def test_preserves_message_id_for_add_messages_replacement(self):
"""LangGraph's add_messages reducer treats same-id messages as
replacements. model_copy keeps id by default."""
mw = SafetyFinishReasonMiddleware()
original = _ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
# AIMessage auto-generates id; capture it
original_id = original.id
state = {"messages": [original]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.id == original_id
# ---------------------------------------------------------------------------
# Detector wiring
# ---------------------------------------------------------------------------
class TestDetectorWiring:
def test_iterates_detectors_in_order(self):
first = AlwaysHitDetector(reason_value="first")
second = AlwaysHitDetector(reason_value="second")
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
def test_returns_none_when_no_detector_matches(self):
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_buggy_detector_does_not_break_run(self):
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
def test_constructor_copies_detectors(self):
"""Caller mutation after construction must not leak into us."""
detectors = [AlwaysHitDetector()]
mw = SafetyFinishReasonMiddleware(detectors=detectors)
detectors.clear()
state = {"messages": [_ai(tool_calls=[_write_call()])]}
assert mw._apply(state, _runtime()) is not None
# ---------------------------------------------------------------------------
# from_config
# ---------------------------------------------------------------------------
class TestFromConfig:
def test_default_config_uses_builtin_detectors(self):
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
assert len(mw._detectors) == 3
names = {d.name for d in mw._detectors}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_custom_detectors_loaded_via_reflection(self):
cfg = SafetyFinishReasonConfig(
detectors=[
SafetyDetectorConfig(
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
config={"finish_reasons": ["custom_filter"]},
),
]
)
mw = SafetyFinishReasonMiddleware.from_config(cfg)
assert len(mw._detectors) == 1
# Confirm the kwargs propagated.
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "custom_filter"},
)
]
}
assert mw._apply(state, _runtime()) is not None
# Default token no longer matches.
state2 = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state2, _runtime()) is None
def test_empty_detector_list_rejected(self):
cfg = SafetyFinishReasonConfig(detectors=[])
with pytest.raises(ValueError, match="enabled=false"):
SafetyFinishReasonMiddleware.from_config(cfg)
def test_non_detector_class_rejected(self):
cfg = SafetyFinishReasonConfig(
detectors=[SafetyDetectorConfig(use="builtins:dict")],
)
with pytest.raises(TypeError):
SafetyFinishReasonMiddleware.from_config(cfg)
# ---------------------------------------------------------------------------
# Stream event
# ---------------------------------------------------------------------------
class TestAuditEvent:
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
audit event via RunJournal.record_middleware when the run-scoped journal is
exposed under runtime.context["__run_journal"].
Background: review on PR #3035 — SSE custom event handles live consumers,
but post-run audit needs a row in run_events that can be queried with one
SQL statement (no JOIN against message body).
"""
def _runtime_with_journal(self, journal):
runtime = MagicMock()
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
return runtime
def test_records_audit_event_when_journal_present(self):
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
tc = _write_call(1)
state = {
"messages": [
_ai(
content="partial",
tool_calls=[tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
journal.record_middleware.assert_called_once()
call = journal.record_middleware.call_args
# tag is positional or kwarg depending on call style; we use kwargs.
assert call.kwargs["tag"] == "safety_termination"
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
assert call.kwargs["hook"] == "after_model"
assert call.kwargs["action"] == "suppress_tool_calls"
changes = call.kwargs["changes"]
assert changes["detector"] == "openai_compatible_content_filter"
assert changes["reason_field"] == "finish_reason"
assert changes["reason_value"] == "content_filter"
assert changes["suppressed_tool_call_count"] == 1
assert changes["suppressed_tool_call_names"] == ["write_file"]
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
assert "message_id" in changes
assert isinstance(changes["extras"], dict)
def test_audit_event_never_carries_tool_arguments(self):
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
and must NOT be persisted to run_events under any circumstance."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
sensitive_tc = {
"id": "call_x",
"name": "write_file",
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
}
state = {
"messages": [
_ai(
tool_calls=[sensitive_tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, self._runtime_with_journal(journal))
flat = repr(journal.record_middleware.call_args)
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
def test_no_journal_in_runtime_context_is_silently_skipped(self):
"""Subagent runtime / unit tests / no-event-store paths have no journal.
Middleware must still intervene and clear tool_calls only the audit
event is skipped."""
mw = SafetyFinishReasonMiddleware()
runtime = MagicMock()
runtime.context = {"thread_id": "t-noj"} # no __run_journal
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise; should still clear tool_calls.
result = mw._apply(state, runtime)
assert result is not None
assert result["messages"][0].tool_calls == []
def test_journal_record_exception_does_not_break_run(self):
"""Buggy journal must never propagate an exception into the agent loop."""
journal = MagicMock()
journal.record_middleware.side_effect = RuntimeError("db down")
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Must not raise.
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
assert result["messages"][0].tool_calls == []
def test_no_record_when_passthrough(self):
"""When the middleware does NOT intervene, no audit event is written."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"}, # healthy
)
]
}
assert mw._apply(state, self._runtime_with_journal(journal)) is None
journal.record_middleware.assert_not_called()
class TestStreamEvent:
def test_emits_event_when_writer_available(self, monkeypatch):
captured: list = []
def fake_writer(payload):
captured.append(payload)
# Patch get_stream_writer at the symbol-resolution site.
import langgraph.config
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, _runtime("t-stream"))
assert len(captured) == 1
payload = captured[0]
assert payload["type"] == "safety_termination"
assert payload["detector"] == "openai_compatible_content_filter"
assert payload["reason_field"] == "finish_reason"
assert payload["reason_value"] == "content_filter"
assert payload["suppressed_tool_call_count"] == 1
assert payload["suppressed_tool_call_names"] == ["write_file"]
assert payload["thread_id"] == "t-stream"
def test_writer_unavailable_does_not_break(self, monkeypatch):
import langgraph.config
def boom():
raise LookupError("not in a stream context")
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise.
result = mw._apply(state, _runtime())
assert result is not None
@@ -0,0 +1,176 @@
"""Unit tests for SafetyTerminationDetector built-ins."""
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.safety_termination_detectors import (
AnthropicRefusalDetector,
GeminiSafetyDetector,
OpenAICompatibleContentFilterDetector,
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
class TestOpenAICompatibleContentFilterDetector:
def test_default_matches_content_filter(self):
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
assert hit is not None
assert hit.detector == "openai_compatible_content_filter"
assert hit.reason_field == "finish_reason"
assert hit.reason_value == "content_filter"
def test_case_insensitive_match(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
def test_other_finish_reasons_pass_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
def test_missing_metadata_passes_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai()) is None
def test_non_string_finish_reason_passes_through(self):
# Some adapters may stash an enum or dict — must not raise.
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
def test_falls_back_to_additional_kwargs(self):
# Legacy adapters surface finish_reason via additional_kwargs.
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
assert hit is not None
def test_configurable_extra_values(self):
# Chinese providers sometimes use bespoke tokens.
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
# Original token still matches.
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
def test_carries_azure_content_filter_results(self):
d = OpenAICompatibleContentFilterDetector()
filter_results = {"hate": {"filtered": True, "severity": "high"}}
hit = d.detect(
_ai(
response_metadata={
"finish_reason": "content_filter",
"content_filter_results": filter_results,
},
)
)
assert hit is not None
assert hit.extras["content_filter_results"] == filter_results
class TestAnthropicRefusalDetector:
def test_default_matches_refusal(self):
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
assert hit is not None
assert hit.reason_field == "stop_reason"
assert hit.reason_value == "refusal"
def test_other_stop_reasons_pass_through(self):
d = AnthropicRefusalDetector()
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
def test_anthropic_does_not_steal_finish_reason(self):
# An OpenAI message must not accidentally trip the Anthropic detector.
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
class TestGeminiSafetyDetector:
def test_default_set_covers_documented_reasons(self):
d = GeminiSafetyDetector()
for reason in (
# text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# image safety
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
def test_normal_termination_passes_through(self):
d = GeminiSafetyDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
# excluded from the default set — they are either normal termination,
# capability mismatches, too broad (OTHER), or tool-call protocol
# errors. See GeminiSafetyDetector docstring.
for reason in (
"MAX_TOKENS",
"LANGUAGE",
"NO_IMAGE",
"OTHER",
"IMAGE_OTHER",
"MALFORMED_FUNCTION_CALL",
"UNEXPECTED_TOOL_CALL",
"FINISH_REASON_UNSPECIFIED",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
def test_carries_safety_ratings(self):
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
hit = GeminiSafetyDetector().detect(
_ai(
response_metadata={
"finish_reason": "SAFETY",
"safety_ratings": ratings,
},
)
)
assert hit is not None
assert hit.extras["safety_ratings"] == ratings
class TestDefaultDetectorSet:
def test_default_set_returns_three_detectors(self):
dets = default_detectors()
names = {d.name for d in dets}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_default_set_returns_fresh_list(self):
# Caller mutation must not affect later calls.
first = default_detectors()
first.clear()
second = default_detectors()
assert len(second) == 3
class TestProtocolConformance:
def test_builtins_satisfy_protocol(self):
for d in default_detectors():
assert isinstance(d, SafetyTerminationDetector)
def test_safety_termination_is_frozen(self):
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
try:
t.detector = "y" # type: ignore[misc]
except Exception:
return
raise AssertionError("SafetyTermination should be frozen")
+225
View File
@@ -0,0 +1,225 @@
from __future__ import annotations
import asyncio
import pytest
from langchain.agents.middleware import AgentMiddleware
from langchain.tools import ToolRuntime
from langgraph.runtime import Runtime
from deerflow.sandbox.middleware import SandboxMiddleware
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider, reset_sandbox_provider, set_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.tools import ls_tool
class _SyncProvider(SandboxProvider):
def __init__(self) -> None:
self.thread_ids: list[str | None] = []
def acquire(self, thread_id: str | None = None) -> str:
self.thread_ids.append(thread_id)
return "sync-sandbox"
def get(self, sandbox_id: str) -> Sandbox | None:
return None
def release(self, sandbox_id: str) -> None:
return None
class _SandboxStub(Sandbox):
def execute_command(self, command: str) -> str:
return "OK"
def read_file(self, path: str) -> str:
return "content"
def download_file(self, path: str) -> bytes:
return b"content"
def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
return ["/mnt/user-data/workspace/file.txt"]
def write_file(self, path: str, content: str, append: bool = False) -> None:
return None
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
return [], False
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
return [], False
def update_file(self, path: str, content: bytes) -> None:
return None
class _AsyncOnlyProvider(SandboxProvider):
def __init__(self) -> None:
self.thread_ids: list[str | None] = []
self.released_ids: list[str] = []
self.sandbox = _SandboxStub("async-sandbox")
def acquire(self, thread_id: str | None = None) -> str:
raise AssertionError("async middleware should not call sync acquire")
async def acquire_async(self, thread_id: str | None = None) -> str:
self.thread_ids.append(thread_id)
return "async-sandbox"
def get(self, sandbox_id: str) -> Sandbox | None:
if sandbox_id == "async-sandbox":
return self.sandbox
return None
def release(self, sandbox_id: str) -> None:
self.released_ids.append(sandbox_id)
return None
@pytest.mark.anyio
async def test_provider_default_acquire_async_offloads_sync_acquire(monkeypatch: pytest.MonkeyPatch) -> None:
provider = _SyncProvider()
calls: list[tuple[object, tuple[object, ...]]] = []
async def fake_to_thread(func, /, *args):
calls.append((func, args))
return func(*args)
monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
sandbox_id = await provider.acquire_async("thread-1")
assert sandbox_id == "sync-sandbox"
assert provider.thread_ids == ["thread-1"]
assert calls == [(provider.acquire, ("thread-1",))]
@pytest.mark.anyio
async def test_abefore_agent_uses_async_provider_acquire() -> None:
provider = _AsyncOnlyProvider()
set_sandbox_provider(provider)
try:
middleware = SandboxMiddleware(lazy_init=False)
result = await middleware.abefore_agent({}, Runtime(context={"thread_id": "thread-2"}))
finally:
reset_sandbox_provider()
assert result == {"sandbox": {"sandbox_id": "async-sandbox"}}
assert provider.thread_ids == ["thread-2"]
@pytest.mark.anyio
@pytest.mark.parametrize(
("middleware", "state", "runtime"),
[
(SandboxMiddleware(lazy_init=True), {}, Runtime(context={"thread_id": "thread-lazy"})),
(SandboxMiddleware(lazy_init=False), {}, Runtime(context={})),
(SandboxMiddleware(lazy_init=False), {"sandbox": {"sandbox_id": "existing"}}, Runtime(context={"thread_id": "thread-existing"})),
],
)
async def test_abefore_agent_delegates_to_super_when_not_acquiring(
monkeypatch: pytest.MonkeyPatch,
middleware: SandboxMiddleware,
state: dict,
runtime: Runtime,
) -> None:
calls: list[tuple[dict, Runtime]] = []
async def fake_super_abefore_agent(self, state_arg, runtime_arg):
calls.append((state_arg, runtime_arg))
return {"delegated": True}
monkeypatch.setattr(AgentMiddleware, "abefore_agent", fake_super_abefore_agent)
result = await middleware.abefore_agent(state, runtime)
assert result == {"delegated": True}
assert calls == [(state, runtime)]
@pytest.mark.anyio
async def test_default_lazy_tool_acquisition_uses_async_provider() -> None:
provider = _AsyncOnlyProvider()
set_sandbox_provider(provider)
try:
runtime = ToolRuntime(
state={},
context={"thread_id": "thread-lazy"},
config={"configurable": {}},
stream_writer=lambda _: None,
tools=[],
tool_call_id="call-1",
store=None,
)
result = await ls_tool.ainvoke({"runtime": runtime, "description": "list workspace", "path": "/mnt/user-data/workspace"})
finally:
reset_sandbox_provider()
assert result == "/mnt/user-data/workspace/file.txt"
assert provider.thread_ids == ["thread-lazy"]
assert runtime.state["sandbox"] == {"sandbox_id": "async-sandbox"}
assert runtime.context["sandbox_id"] == "async-sandbox"
@pytest.mark.anyio
@pytest.mark.parametrize(
("state", "runtime", "expected_sandbox_id"),
[
({"sandbox": {"sandbox_id": "state-sandbox"}}, Runtime(context={}), "state-sandbox"),
({}, Runtime(context={"sandbox_id": "context-sandbox"}), "context-sandbox"),
],
)
async def test_aafter_agent_releases_sandbox_off_thread(
monkeypatch: pytest.MonkeyPatch,
state: dict,
runtime: Runtime,
expected_sandbox_id: str,
) -> None:
provider = _AsyncOnlyProvider()
to_thread_calls: list[tuple[object, tuple[object, ...]]] = []
async def fake_to_thread(func, /, *args):
to_thread_calls.append((func, args))
return func(*args)
monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
set_sandbox_provider(provider)
try:
result = await SandboxMiddleware().aafter_agent(state, runtime)
finally:
reset_sandbox_provider()
assert result is None
assert provider.released_ids == [expected_sandbox_id]
assert to_thread_calls == [(provider.release, (expected_sandbox_id,))]
@pytest.mark.anyio
async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[tuple[dict, Runtime]] = []
async def fake_super_aafter_agent(self, state_arg, runtime_arg):
calls.append((state_arg, runtime_arg))
return {"delegated": True}
monkeypatch.setattr(AgentMiddleware, "aafter_agent", fake_super_aafter_agent)
state = {}
runtime = Runtime(context={})
result = await SandboxMiddleware().aafter_agent(state, runtime)
assert result == {"delegated": True}
assert calls == [(state, runtime)]
@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from deerflow.sandbox.exceptions import SandboxError
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
@@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
assert sandbox.content == "ALPHA\ntail\n"
def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-large-oserror"
def write_file(self, path: str, content: str, append: bool = False) -> None:
host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt"
raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True)
monkeypatch.setattr("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA)
monkeypatch.setattr("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None)
monkeypatch.setattr(
"deerflow.sandbox.tools._resolve_and_validate_user_data_path",
lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt",
)
result = write_file_tool.func(
runtime=runtime,
description="写入大文件失败",
path="/mnt/user-data/workspace/output.txt",
content="report body",
)
assert len(result) <= 2000
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result
assert "/mnt/user-data/workspace/nested/output.txt" in result
assert "remote tail marker" in result
assert "[write_file error truncated:" in result
def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-short-oserror"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise OSError("disk quota exceeded")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="写入失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert result == "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded"
assert "[write_file error truncated:" not in result
def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-large-sandbox-error"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise SandboxError(f"remote write rejected {'B' * 12000} final detail")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="远端写入失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert len(result) <= 2000
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "SandboxError: remote write rejected" in result
assert "final detail" in result
assert "[write_file error truncated:" in result
@pytest.mark.parametrize(
("raised_error", "expected_fragment"),
[
pytest.param(
PermissionError("permission denied"),
"Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt",
id="permission",
),
pytest.param(
IsADirectoryError("target is a directory"),
"Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt",
id="directory",
),
pytest.param(
Exception("remote sandbox timeout"),
"Exception: remote sandbox timeout",
id="generic",
),
],
)
def test_write_file_tool_formats_all_other_failure_branches(
monkeypatch,
raised_error: Exception,
expected_fragment: str,
) -> None:
class FailingSandbox:
id = "sandbox-write-other-failure"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise raised_error
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="验证错误分支格式化",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert "/mnt/user-data/workspace/output.txt" in result
assert expected_fragment in result
assert "[write_file error truncated:" not in result
def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None:
"""Regression for #3133 review: SandboxError raised during sandbox
initialization (before the local `requested_path` assignment) must still
surface as a bounded tool error rather than an UnboundLocalError.
"""
def raise_sandbox_error(runtime):
raise SandboxError("sandbox missing")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", raise_sandbox_error)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="sandbox 初始化失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "SandboxError: sandbox missing" in result
assert "[write_file error truncated:" not in result
def test_file_operation_lock_memory_cleanup() -> None:
"""Verify that released locks are eventually cleaned up by WeakValueDictionary.
+110 -7
View File
@@ -2,13 +2,12 @@ from types import SimpleNamespace
import pytest
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.security_scanner import _extract_json_object, scan_skill_content
@pytest.mark.anyio
async def test_scan_skill_content_passes_run_name_to_model(monkeypatch):
def _make_env(monkeypatch, response_content):
config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None))
fake_response = SimpleNamespace(content='{"decision":"allow","reason":"ok"}')
fake_response = SimpleNamespace(content=response_content)
class FakeModel:
async def ainvoke(self, *args, **kwargs):
@@ -19,9 +18,59 @@ async def test_scan_skill_content_passes_run_name_to_model(monkeypatch):
model = FakeModel()
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: model)
return model
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
SKILL_CONTENT = "---\nname: demo-skill\ndescription: demo\n---\n"
# --- _extract_json_object unit tests ---
def test_extract_json_plain():
assert _extract_json_object('{"decision":"allow","reason":"ok"}') == {"decision": "allow", "reason": "ok"}
def test_extract_json_markdown_fence():
raw = '```json\n{"decision": "allow", "reason": "ok"}\n```'
assert _extract_json_object(raw) == {"decision": "allow", "reason": "ok"}
def test_extract_json_fence_no_language():
raw = '```\n{"decision": "allow", "reason": "ok"}\n```'
assert _extract_json_object(raw) == {"decision": "allow", "reason": "ok"}
def test_extract_json_prose_wrapped():
raw = 'Looking at this content I conclude: {"decision": "allow", "reason": "clean"} and that is final.'
assert _extract_json_object(raw) == {"decision": "allow", "reason": "clean"}
def test_extract_json_nested_braces_in_reason():
raw = '{"decision": "allow", "reason": "no issues with {placeholder} found"}'
assert _extract_json_object(raw) == {"decision": "allow", "reason": "no issues with {placeholder} found"}
def test_extract_json_nested_braces_code_snippet():
raw = 'Here is my review: {"decision": "block", "reason": "contains {\\"x\\": 1} code injection"}'
assert _extract_json_object(raw) == {"decision": "block", "reason": 'contains {"x": 1} code injection'}
def test_extract_json_returns_none_for_garbage():
assert _extract_json_object("no json here") is None
def test_extract_json_returns_none_for_unclosed_brace():
assert _extract_json_object('{"decision": "allow"') is None
# --- scan_skill_content integration tests ---
@pytest.mark.anyio
async def test_scan_skill_content_passes_run_name_to_model(monkeypatch):
model = _make_env(monkeypatch, '{"decision":"allow","reason":"ok"}')
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "allow"
assert model.kwargs["config"] == {"run_name": "security_agent"}
@@ -32,7 +81,61 @@ async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "block"
assert "manual review required" in result.reason
assert "unavailable" in result.reason
@pytest.mark.anyio
async def test_scan_allows_markdown_fenced_response(monkeypatch):
_make_env(monkeypatch, '```json\n{"decision": "allow", "reason": "clean"}\n```')
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "allow"
assert result.reason == "clean"
@pytest.mark.anyio
async def test_scan_normalizes_decision_case(monkeypatch):
_make_env(monkeypatch, '{"decision": "Allow", "reason": "looks fine"}')
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "allow"
@pytest.mark.anyio
async def test_scan_normalizes_uppercase_decision(monkeypatch):
_make_env(monkeypatch, '{"decision": "BLOCK", "reason": "dangerous"}')
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "block"
@pytest.mark.anyio
async def test_scan_handles_nested_braces_in_reason(monkeypatch):
_make_env(monkeypatch, '{"decision": "allow", "reason": "no issues with {placeholder}"}')
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "allow"
assert "{placeholder}" in result.reason
@pytest.mark.anyio
async def test_scan_handles_prose_wrapped_json(monkeypatch):
_make_env(monkeypatch, 'I reviewed the content: {"decision": "allow", "reason": "safe"}\nDone.')
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "allow"
@pytest.mark.anyio
async def test_scan_distinguishes_unparseable_from_unavailable(monkeypatch):
_make_env(monkeypatch, "I can't decide, this is just prose without any JSON at all.")
result = await scan_skill_content(SKILL_CONTENT, executable=False)
assert result.decision == "block"
assert "unparseable" in result.reason
@pytest.mark.anyio
async def test_scan_distinguishes_unparseable_executable(monkeypatch):
_make_env(monkeypatch, "no json here")
result = await scan_skill_content(SKILL_CONTENT, executable=True)
# Even for executable content, unparseable uses the unparseable message
assert result.decision == "block"
assert "unparseable" in result.reason
@@ -0,0 +1,429 @@
"""End-to-end verification for issue #2862 (and the regression of #2782).
Goal: prove without trusting any single layer's claim — that an authenticated
user creating a custom agent through the real ``setup_agent`` tool, driven by a
real LangGraph ``create_agent`` graph, ends up with files under
``users/<auth_uid>/agents/<name>`` and **not** under ``users/default/agents/...``.
We intentionally exercise the full pipeline:
HTTP body shape (mimics LangGraph SDK wire format)
-> app.gateway.services.start_run config-assembly chain
-> deerflow.runtime.runs.worker._build_runtime_context
-> langchain.agents.create_agent graph
-> ToolNode dispatch
-> setup_agent tool
The only thing we mock is the LLM (FakeMessagesListChatModel) every layer
that handles ``user_id`` is the real production code path. If the
``user_id`` propagation is broken anywhere in this chain, these tests will
fail.
These tests intentionally ``no_auto_user`` so that the ``contextvar``
fallback would put files into ``default/`` if propagation breaks.
"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
from uuid import UUID
import pytest
from _agent_e2e_helpers import FakeToolCallingModel
from langchain_core.messages import AIMessage, HumanMessage
from app.gateway.services import (
build_run_config,
inject_authenticated_user_context,
merge_run_context_overrides,
)
from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context
# ---------------------------------------------------------------------------
# Helpers — real production code paths
# ---------------------------------------------------------------------------
def _make_request(user_id_str: str | None) -> SimpleNamespace:
"""Build a fake FastAPI Request that carries an authenticated user."""
if user_id_str is None:
user = None
else:
# User.id is UUID in production; honour that
user = SimpleNamespace(id=UUID(user_id_str), email="alice@local")
return SimpleNamespace(state=SimpleNamespace(user=user))
def _assemble_config(
*,
body_config: dict | None,
body_context: dict | None,
request_user_id: str | None,
thread_id: str = "thread-e2e",
assistant_id: str = "lead_agent",
) -> dict:
"""Replay the **exact** start_run config-assembly sequence."""
config = build_run_config(thread_id, body_config, None, assistant_id=assistant_id)
merge_run_context_overrides(config, body_context)
inject_authenticated_user_context(config, _make_request(request_user_id))
return config
def _make_paths_mock(tmp_path: Path):
"""Mirror the production paths.user_agent_dir signature."""
from unittest.mock import MagicMock
paths = MagicMock()
paths.base_dir = tmp_path
paths.agent_dir = lambda name: tmp_path / "agents" / name
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
return paths
# ---------------------------------------------------------------------------
# L1-L3: HTTP wire format → start_run → worker._build_runtime_context
# ---------------------------------------------------------------------------
class TestConfigAssembly:
"""Covers L1-L3: validate that user_id reaches runtime_ctx for every wire shape."""
def test_typical_wire_format_user_id_in_runtime_ctx(self):
"""Real frontend: body.config={recursion_limit}, body.context={agent_name,...}."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={"agent_name": "myagent", "is_bootstrap": True, "mode": "flash"},
request_user_id="11111111-2222-3333-4444-555555555555",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555"
assert runtime_ctx["agent_name"] == "myagent"
def test_body_context_none_still_injects_user_id(self):
"""If frontend omits body.context entirely, inject must still create it."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context=None,
request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
def test_body_context_empty_dict_still_injects_user_id(self):
"""body.context={} (falsy) path: inject must still produce user_id."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={},
request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
def test_body_config_already_contains_context_field(self):
"""body.config={'context': {...}} (LG 0.6 alt wire): inject still wins."""
config = _assemble_config(
body_config={"context": {"agent_name": "myagent"}, "recursion_limit": 1000},
body_context=None,
request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
def test_client_supplied_user_id_is_overridden(self):
"""Spoofed client user_id must be overwritten by inject (auth-trusted source)."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={"agent_name": "myagent", "user_id": "spoofed"},
request_user_id="11111111-2222-3333-4444-555555555555",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555"
def test_unauthenticated_request_does_not_inject(self):
"""If request.state.user is missing (impossible under fail-closed auth, but
verify defensively), inject must not write user_id and runtime_ctx must
therefore lack it forcing the tool fallback path to reveal itself."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={"agent_name": "myagent"},
request_user_id=None,
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert "user_id" not in runtime_ctx
# ---------------------------------------------------------------------------
# L4-L7: Real LangGraph create_agent driving the real setup_agent tool
# ---------------------------------------------------------------------------
def _build_real_bootstrap_graph(authenticated_user_id: str):
"""Construct a real LangGraph using create_agent + the real setup_agent tool.
The LLM is faked (FakeMessagesListChatModel) so we don't need an API key.
Everything else ToolNode dispatch, runtime injection, middleware is
the real production code path.
"""
from langchain.agents import create_agent
from deerflow.tools.builtins.setup_agent_tool import setup_agent
# First model turn: emit a tool_call for setup_agent
# Second model turn (after tool result): final answer (terminates the loop)
fake_model = FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "setup_agent",
"args": {
"soul": "# My E2E Agent\n\nA SOUL written by the model.",
"description": "End-to-end test agent",
},
"id": "call_setup_1",
"type": "tool_call",
}
],
),
AIMessage(content=f"Done. Agent created for user {authenticated_user_id}."),
]
)
graph = create_agent(
model=fake_model,
tools=[setup_agent],
system_prompt="You are a bootstrap agent. Call setup_agent immediately.",
)
return graph
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_real_graph_real_setup_agent_writes_to_authenticated_user_dir(tmp_path: Path):
"""The smoking-gun test for issue #2862.
Under no_auto_user (contextvar = empty), if user_id propagation through
runtime.context is broken, setup_agent will fall back to DEFAULT_USER_ID
and write to users/default/agents/... The assertion that this directory
DOES NOT exist is what makes this test load-bearing.
"""
from langgraph.runtime import Runtime
auth_uid = "abcdef01-2345-6789-abcd-ef0123456789"
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "e2e-agent", "is_bootstrap": True},
request_user_id=auth_uid,
thread_id="thread-e2e-1",
)
# Replay worker.run_agent's runtime construction. This is the key step:
# it is what makes ToolRuntime.context contain user_id when the tool
# actually fires.
runtime_ctx = _build_runtime_context("thread-e2e-1", "run-1", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_real_bootstrap_graph(auth_uid)
# Patch get_paths only (the file-system rooting); everything else is real
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
# Drive the real graph. This goes through real ToolNode + real Runtime merge.
final_state = await graph.ainvoke(
{"messages": [HumanMessage(content="Create an agent named e2e-agent")]},
config=config,
)
expected_dir = tmp_path / "users" / auth_uid / "agents" / "e2e-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "e2e-agent"
# Load-bearing assertions:
assert expected_dir.exists(), f"Agent directory not found at the authenticated user's path. Expected: {expected_dir}. tmp_path tree: {[str(p) for p in tmp_path.rglob('*')]}"
assert (expected_dir / "SOUL.md").read_text() == "# My E2E Agent\n\nA SOUL written by the model."
assert (expected_dir / "config.yaml").exists()
assert not default_dir.exists(), "REGRESSION: agent landed under users/default/. user_id propagation broke somewhere between HTTP layer and ToolRuntime.context."
# And final state should reflect tool success
last = final_state["messages"][-1]
assert "Done" in (last.content if isinstance(last.content, str) else str(last.content))
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_inject_failure_falls_back_to_default_proving_test_is_load_bearing(tmp_path: Path):
"""Negative control: if inject does NOT happen (no user in request), and
contextvar is empty (no_auto_user), setup_agent must land in default/.
This proves the positive test is actually load-bearing i.e. it would
have failed before PR #2784, not passed accidentally.
"""
from langgraph.runtime import Runtime
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "fallback-agent", "is_bootstrap": True},
request_user_id=None, # no auth — inject is a no-op
thread_id="thread-e2e-2",
)
runtime_ctx = _build_runtime_context("thread-e2e-2", "run-2", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_real_bootstrap_graph("does-not-matter")
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
await graph.ainvoke(
{"messages": [HumanMessage(content="Create fallback-agent")]},
config=config,
)
default_dir = tmp_path / "users" / "default" / "agents" / "fallback-agent"
assert default_dir.exists(), "Negative control failed: even without inject + contextvar, agent did not land in default/. The test infrastructure may not be reproducing the bug condition."
# ---------------------------------------------------------------------------
# L5: Sub-graph runtime propagation (the task tool case)
# ---------------------------------------------------------------------------
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_subgraph_invocation_preserves_user_id_in_runtime(tmp_path: Path):
"""When a parent graph invokes a child graph (the pattern used by
subagents), parent_runtime.merge() must keep user_id intact.
We construct a child graph that contains setup_agent and call it from
a parent graph's tool. If LangGraph re-creates the Runtime and drops
user_id at the sub-graph boundary, this fails.
"""
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from deerflow.tools.builtins.setup_agent_tool import setup_agent
auth_uid = "deadbeef-0000-1111-2222-333344445555"
# Inner graph: same as the bootstrap flow
inner_model = FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "setup_agent",
"args": {"soul": "# Inner", "description": "subgraph"},
"id": "call_inner_1",
"type": "tool_call",
}
],
),
AIMessage(content="inner done"),
]
)
inner_graph = create_agent(
model=inner_model,
tools=[setup_agent],
system_prompt="inner",
)
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "subgraph-agent", "is_bootstrap": True},
request_user_id=auth_uid,
thread_id="thread-e2e-3",
)
runtime_ctx = _build_runtime_context("thread-e2e-3", "run-3", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
# Direct sub-graph invoke (mimics what a subagent invocation looks like
# — distinct ainvoke call, but parent config carries the same runtime).
await inner_graph.ainvoke(
{"messages": [HumanMessage(content="Create subgraph-agent")]},
config=config,
)
expected_dir = tmp_path / "users" / auth_uid / "agents" / "subgraph-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "subgraph-agent"
assert expected_dir.exists()
assert not default_dir.exists()
# ---------------------------------------------------------------------------
# L6: Sync tool path through ContextThreadPoolExecutor
# ---------------------------------------------------------------------------
def test_sync_tool_dispatch_through_thread_pool_uses_runtime_context(tmp_path: Path):
"""setup_agent is a sync function. When dispatched through ToolNode's
ContextThreadPoolExecutor, runtime.context must still carry user_id
not via thread-local copy_context (which only carries contextvars), but
because it was passed in as the ToolRuntime constructor argument.
"""
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from deerflow.tools.builtins.setup_agent_tool import setup_agent
auth_uid = "11112222-3333-4444-5555-666677778888"
fake_model = FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "setup_agent",
"args": {"soul": "# Sync", "description": "sync path"},
"id": "call_sync_1",
"type": "tool_call",
}
],
),
AIMessage(content="sync done"),
]
)
graph = create_agent(model=fake_model, tools=[setup_agent], system_prompt="sync")
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "sync-agent", "is_bootstrap": True},
request_user_id=auth_uid,
thread_id="thread-e2e-4",
)
runtime_ctx = _build_runtime_context("thread-e2e-4", "run-4", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
# Use SYNC invoke to hit the ContextThreadPoolExecutor path
graph.invoke(
{"messages": [HumanMessage(content="Create sync-agent")]},
config=config,
)
expected_dir = tmp_path / "users" / auth_uid / "agents" / "sync-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "sync-agent"
assert expected_dir.exists()
assert not default_dir.exists()
@@ -0,0 +1,326 @@
"""Real HTTP end-to-end verification for issue #2862's setup_agent path.
This test drives the **entire** FastAPI gateway through ``starlette.testclient.TestClient``:
starlette.testclient.TestClient (real ASGI stack)
-> AuthMiddleware (real cookie parsing, real JWT decode)
-> /api/v1/auth/register endpoint (real password hash + sqlite write)
-> /api/threads/{id}/runs/stream endpoint (real start_run config-assembly)
-> background asyncio.create_task(run_agent) (real worker, real Runtime)
-> langchain.agents.create_agent graph (real, with fake LLM)
-> ToolNode dispatch (real)
-> setup_agent tool (real file I/O)
The only mock is the LLM (no API key needed). Every layer that participates
in ``user_id`` propagation auth, ContextVar, ``inject_authenticated_user_context``,
``worker._build_runtime_context``, ``Runtime.merge`` is the real production
code path. If the chain is broken at any layer, this test fails.
This is what "真实验证" looks like for a server that lives behind authentication:
register a user, log in (cookie), POST to /runs/stream, wait for the run to
finish, then read the filesystem.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model
def _build_fake_create_chat_model(agent_name: str):
"""Return a callable matching the real ``create_chat_model`` signature.
Whenever the lead agent constructs a chat model during the bootstrap flow,
we hand it a fake that emits a single setup_agent tool_call on its first
turn, then a benign final answer on its second turn.
"""
def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel:
return build_single_tool_call_model(
tool_name="setup_agent",
tool_args={
"soul": f"# Real HTTP E2E SOUL for {agent_name}",
"description": "real-http-e2e agent",
},
tool_call_id="call_real_http_1",
final_text=f"Agent {agent_name} created via real HTTP e2e.",
)
return fake_create_chat_model
@pytest.fixture
def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Stand up an isolated DeerFlow data root + config under tmp_path.
- Sets ``DEER_FLOW_HOME`` so paths land under tmp_path, not the real
``.deer-flow`` directory.
- Stages a copy of the project's ``config.yaml`` (or ``config.example.yaml``
on a fresh CI checkout where ``config.yaml`` is gitignored) and pins
``DEER_FLOW_CONFIG_PATH`` to it, so lifespan boot doesn't depend on the
developer's local config layout.
- Sets a placeholder OPENAI_API_KEY because the config has
``$OPENAI_API_KEY`` that gets resolved at parse time; the LLM itself is
mocked, so any non-empty value works.
"""
home = tmp_path / "deer-flow-home"
home.mkdir()
monkeypatch.setenv("DEER_FLOW_HOME", str(home))
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used-because-llm-is-mocked")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
# Hermetic config: do not depend on whether the dev machine has a real
# ``config.yaml`` at the repo root. CI's ``actions/checkout`` only ships
# ``config.example.yaml`` (and its ``models:`` list is commented out, so
# AppConfig validation would reject it). Write a minimal, self-sufficient
# config to tmp_path and pin ``DEER_FLOW_CONFIG_PATH`` to it.
staged_config = tmp_path / "config.yaml"
staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config))
return home
# Minimal config that satisfies AppConfig + LeadAgent's _resolve_model_name.
# The model `use` path must resolve to a real class for config parsing to
# succeed; the test patches ``create_chat_model`` on the lead agent module,
# so the model is never actually instantiated. SandboxConfig.use is required
# at schema level; LocalSandboxProvider is the only sandbox that runs without
# Docker.
_MINIMAL_CONFIG_YAML = """\
log_level: info
models:
- name: fake-test-model
display_name: Fake Test Model
use: langchain_openai:ChatOpenAI
model: gpt-4o-mini
api_key: $OPENAI_API_KEY
base_url: $OPENAI_API_BASE
sandbox:
use: deerflow.sandbox.local:LocalSandboxProvider
agents_api:
enabled: true
database:
backend: sqlite
"""
def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None:
"""Reset every process-wide cache that would survive across tests.
This fixture stands up a full FastAPI app + sqlite DB + LangGraph runtime
inside ``tmp_path``. To get true per-test isolation we have to invalidate
a handful of module-level caches that production normally never resets,
so they pick up our test-only ``DEER_FLOW_HOME`` and sqlite path:
- ``deerflow.config.app_config`` caches the parsed ``config.yaml``.
- ``deerflow.config.paths`` caches the ``Paths`` singleton derived from
``DEER_FLOW_HOME`` at first access.
- ``deerflow.persistence.engine`` caches the SQLAlchemy engine and
session factory after the first call to ``init_engine_from_config``.
``raising=False`` keeps the fixture resilient if upstream renames or
drops one of these attributes the test will simply skip that reset
instead of failing with a confusing AttributeError, and the next test
to call ``get_app_config()``/``get_paths()`` will surface the real
incompatibility loudly.
"""
from deerflow.config import app_config as app_config_module
from deerflow.config import paths as paths_module
from deerflow.persistence import engine as engine_module
for module, attr in (
(app_config_module, "_app_config"),
(app_config_module, "_app_config_path"),
(app_config_module, "_app_config_mtime"),
(paths_module, "_paths_singleton"),
(engine_module, "_engine"),
(engine_module, "_session_factory"),
):
monkeypatch.setattr(module, attr, None, raising=False)
@pytest.fixture
def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch):
"""Build a fresh FastAPI app inside a clean DEER_FLOW_HOME.
Each test gets its own sqlite DB and checkpoint store under ``tmp_path``,
with no cross-test contamination.
"""
_reset_process_singletons(monkeypatch)
# Re-resolve the config from the test-only DEER_FLOW_HOME and pin its
# sqlite path into tmp_path so the lifespan-time engine init lands there.
from deerflow.config import app_config as app_config_module
cfg = app_config_module.get_app_config()
cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db")
from app.gateway.app import create_app
return create_app()
def _drain_stream(response, *, timeout: float = 30.0, max_bytes: int = 4 * 1024 * 1024) -> str:
"""Consume an SSE response body until the run terminates and return the text.
Bounded to keep the test fail-fast:
- Stops as soon as an ``event: end`` SSE frame is observed (the gateway
sends this when the background run finishes see ``services.format_sse``
and ``StreamBridge.publish_end``).
- Stops at ``timeout`` seconds wall-clock so a stuck run / runaway heartbeat
loop surfaces a real failure instead of hanging pytest.
- Stops at ``max_bytes`` so a runaway producer can't OOM the test process.
"""
import time as _time
deadline = _time.monotonic() + timeout
body = b""
for chunk in response.iter_bytes():
body += chunk
if b"event: end" in body:
break
if len(body) >= max_bytes:
break
if _time.monotonic() >= deadline:
break
return body.decode("utf-8", errors="replace")
def _wait_for_file(path: Path, *, timeout: float = 10.0) -> bool:
"""Block until *path* exists or *timeout* elapses.
The run completes inside ``asyncio.create_task`` after start_run returns,
so the test must wait for the background task to flush its writes.
"""
import time as _time
deadline = _time.monotonic() + timeout
while _time.monotonic() < deadline:
if path.exists():
return True
_time.sleep(0.05)
return False
@pytest.mark.no_auto_user
def test_real_http_create_agent_lands_in_authenticated_user_dir(
isolated_app: Any,
isolated_deer_flow_home: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""The full real-server contract test.
1. Register a real user via POST /api/v1/auth/register (also auto-logs in)
2. POST to /api/threads/{tid}/runs/stream with the **exact** body shape the
frontend (LangGraph SDK) sends during the bootstrap flow.
3. Wait for the background run to finish.
4. Assert SOUL.md exists under users/<authenticated_uid>/agents/<name>/.
5. Assert NOTHING exists under users/default/agents/<name>/.
"""
# ``deerflow.agents.lead_agent.agent`` imports ``create_chat_model`` with
# ``from deerflow.models import create_chat_model`` at module load time,
# rebinding the symbol into its own namespace. So the only patch that
# intercepts the call is the bound name on ``lead_agent.agent`` — patching
# ``deerflow.models.create_chat_model`` would be too late.
agent_name = "real-http-agent"
from starlette.testclient import TestClient
with (
patch(
"deerflow.agents.lead_agent.agent.create_chat_model",
new=_build_fake_create_chat_model(agent_name),
),
TestClient(isolated_app) as client,
):
# --- 1. Register & auto-login ---
register = client.post(
"/api/v1/auth/register",
json={"email": "e2e-user@example.com", "password": "very-strong-password-123"},
)
assert register.status_code == 201, register.text
registered = register.json()
auth_uid = registered["id"]
# The endpoint sets both access_token (auth) and csrf_token (CSRF Double
# Submit Cookie) cookies; the TestClient cookie jar propagates them.
assert client.cookies.get("access_token"), "register endpoint must set session cookie"
csrf_token = client.cookies.get("csrf_token")
assert csrf_token, "register endpoint must set csrf_token cookie"
# --- 2. Create a thread (require_existing=True on /runs/stream means
# we must call POST /api/threads first; the React frontend does the
# same via the LangGraph SDK's threads.create) ---
import uuid as _uuid
thread_id = str(_uuid.uuid4())
created = client.post(
"/api/threads",
json={"thread_id": thread_id, "metadata": {}},
headers={"X-CSRF-Token": csrf_token},
)
assert created.status_code == 200, created.text
# --- 3. POST /runs/stream with the bootstrap wire format ---
# This is the EXACT shape the React frontend sends after PR #2784:
# thread.submit(input, {config, context}) ->
# POST /api/threads/{id}/runs/stream body =
# {assistant_id, input, config, context}
body = {
"assistant_id": "lead_agent",
"input": {
"messages": [
{
"role": "user",
"content": (f"The new custom agent name is {agent_name}. Help me design its SOUL.md before saving it."),
}
]
},
"config": {"recursion_limit": 50},
"context": {
"agent_name": agent_name,
"is_bootstrap": True,
"mode": "flash",
"thinking_enabled": False,
"is_plan_mode": False,
"subagent_enabled": False,
},
"stream_mode": ["values"],
}
# The /stream endpoint returns SSE; we drain it so the server-side
# background task (run_agent) gets to completion before we look at disk.
with client.stream(
"POST",
f"/api/threads/{thread_id}/runs/stream",
json=body,
headers={"X-CSRF-Token": csrf_token},
) as resp:
assert resp.status_code == 200, resp.read().decode()
transcript = _drain_stream(resp)
# Sanity: the stream should have produced at least one event
assert "event:" in transcript, f"no SSE events in response: {transcript[:500]!r}"
# --- 4. Verify filesystem outcome ---
expected_dir = isolated_deer_flow_home / "users" / auth_uid / "agents" / agent_name
default_dir = isolated_deer_flow_home / "users" / "default" / "agents" / agent_name
# The setup_agent tool runs inside the background asyncio task spawned
# by start_run; SSE-drain typically waits for it, but we add a bounded
# poll to be robust against scheduler jitter.
assert _wait_for_file(expected_dir / "SOUL.md", timeout=15.0), (
"SOUL.md did not appear under users/<auth_uid>/agents/. "
f"Expected: {expected_dir / 'SOUL.md'}. "
f"tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}. "
f"SSE transcript tail: {transcript[-1000:]!r}"
)
soul_text = (expected_dir / "SOUL.md").read_text()
assert agent_name in soul_text, f"unexpected SOUL content: {soul_text!r}"
# The smoking-gun assertion: the agent must NOT have landed in default/
assert not default_dir.exists(), f"REGRESSION: agent landed under users/default/{agent_name} instead of the authenticated user. Default-dir contents: {list(default_dir.rglob('*')) if default_dir.exists() else 'n/a'}"
+3 -1
View File
@@ -7,6 +7,7 @@ from types import SimpleNamespace
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import skills as skills_router
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill
@@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.state.config = config # kept for any startup-style reads
app.dependency_overrides[get_config] = lambda: config
app.include_router(skills_router.router)
return app
+317 -1
View File
@@ -291,7 +291,7 @@ class TestAgentConstruction:
assert captured["agent"]["model"] is model
assert captured["agent"]["middleware"] is middlewares
assert captured["agent"]["tools"] == []
assert captured["agent"]["system_prompt"] == base_config.system_prompt
assert captured["agent"]["system_prompt"] is None # system_prompt is merged into initial state messages
@pytest.mark.anyio
async def test_load_skill_messages_uses_explicit_app_config_for_skill_storage(
@@ -331,6 +331,124 @@ class TestAgentConstruction:
assert len(messages) == 1
assert "Use demo skill" in messages[0].content
@pytest.mark.anyio
async def test_build_initial_state_consolidates_system_prompt_and_skills(
self,
classes,
base_config,
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
"""_build_initial_state merges system_prompt and skills into one SystemMessage."""
SubagentExecutor = classes["SubagentExecutor"]
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
skill_file = skill_dir / "SKILL.md"
skill_file.write_text("Skill instructions here", encoding="utf-8")
monkeypatch.setattr(
sys.modules["deerflow.skills.storage"],
"get_or_new_skill_storage",
lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="my-skill", skill_file=skill_file, allowed_tools=None)]),
)
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
)
state, _filtered_tools = await executor._build_initial_state("Do the task")
messages = state["messages"]
# Should have exactly 2 messages: one combined SystemMessage + one HumanMessage
assert len(messages) == 2
from langchain_core.messages import HumanMessage, SystemMessage
assert isinstance(messages[0], SystemMessage)
assert isinstance(messages[1], HumanMessage)
# SystemMessage should contain both the system_prompt and skill content
assert base_config.system_prompt in messages[0].content
assert "Skill instructions here" in messages[0].content
# HumanMessage should be the task
assert messages[1].content == "Do the task"
@pytest.mark.anyio
async def test_build_initial_state_no_skills_only_system_prompt(
self,
classes,
base_config,
monkeypatch: pytest.MonkeyPatch,
):
"""_build_initial_state works when there are no skills."""
SubagentExecutor = classes["SubagentExecutor"]
monkeypatch.setattr(
sys.modules["deerflow.skills.storage"],
"get_or_new_skill_storage",
lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: []),
)
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
)
state, _filtered_tools = await executor._build_initial_state("Do the task")
messages = state["messages"]
from langchain_core.messages import HumanMessage, SystemMessage
assert len(messages) == 2
assert isinstance(messages[0], SystemMessage)
assert base_config.system_prompt in messages[0].content
assert isinstance(messages[1], HumanMessage)
@pytest.mark.anyio
async def test_build_initial_state_no_system_prompt_with_skills(
self,
classes,
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
"""_build_initial_state works when there is no system_prompt but there are skills."""
SubagentConfig = classes["SubagentConfig"]
config = SubagentConfig(
name="test-agent",
description="Test agent",
system_prompt=None,
max_turns=10,
timeout_seconds=60,
)
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
skill_file = skill_dir / "SKILL.md"
skill_file.write_text("Skill content", encoding="utf-8")
monkeypatch.setattr(
sys.modules["deerflow.skills.storage"],
"get_or_new_skill_storage",
lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="my-skill", skill_file=skill_file, allowed_tools=None)]),
)
SubagentExecutor = classes["SubagentExecutor"]
executor = SubagentExecutor(config=config, tools=[], thread_id="test-thread")
state, _filtered_tools = await executor._build_initial_state("Do the task")
messages = state["messages"]
from langchain_core.messages import HumanMessage, SystemMessage
assert len(messages) == 2
assert isinstance(messages[0], SystemMessage)
assert "Skill content" in messages[0].content
assert isinstance(messages[1], HumanMessage)
# -----------------------------------------------------------------------------
# Async Execution Path Tests
@@ -514,6 +632,70 @@ class TestAsyncExecutionPath:
assert result.status == SubagentStatus.COMPLETED
assert "Task" in result.result
@pytest.mark.anyio
async def test_aexecute_passes_at_most_one_system_message_to_agent(
self,
classes,
base_config,
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
"""Regression: messages sent to agent.astream must contain at most one
SystemMessage and it must be the first message.
This catches any regression where system_prompt would be re-injected
via create_agent() (e.g. system_prompt not passed as None) and appear
as a second SystemMessage, which providers like vLLM and Xinference
reject with "System message must be at the beginning."
"""
from langchain_core.messages import AIMessage, SystemMessage
SubagentExecutor = classes["SubagentExecutor"]
SubagentStatus = classes["SubagentStatus"]
# Set up a skill so both system_prompt AND skill content are present,
# maximising the chance of catching a double-SystemMessage regression.
skill_dir = tmp_path / "regression-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("Skill instruction text", encoding="utf-8")
monkeypatch.setattr(
sys.modules["deerflow.skills.storage"],
"get_or_new_skill_storage",
lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="regression-skill", skill_file=skill_dir / "SKILL.md", allowed_tools=None)]),
)
captured_states: list[dict] = []
async def capturing_astream(state, **kwargs):
captured_states.append(state)
yield {"messages": [AIMessage(content="Done", id="msg-1")]}
mock_agent = MagicMock()
mock_agent.astream = capturing_astream
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
result = await executor._aexecute("Do something")
assert result.status == SubagentStatus.COMPLETED
assert len(captured_states) == 1, "astream should be called exactly once"
initial_messages = captured_states[0]["messages"]
system_messages = [m for m in initial_messages if isinstance(m, SystemMessage)]
assert len(system_messages) <= 1, f"Expected at most 1 SystemMessage but got {len(system_messages)}: {system_messages}"
if system_messages:
assert initial_messages[0] is system_messages[0], "SystemMessage must be the first message in the conversation"
# The consolidated SystemMessage must carry both the system_prompt
# and all skill content — nothing should be split across two messages.
assert base_config.system_prompt in system_messages[0].content
assert "Skill instruction text" in system_messages[0].content
class TestSkillAllowedTools:
@pytest.mark.anyio
@@ -943,6 +1125,15 @@ class TestAsyncToolSupport:
class TestThreadSafety:
"""Test thread safety of executor operations."""
@pytest.fixture
def executor_module(self, _setup_executor_classes):
"""Import the executor module with real classes."""
import importlib
from deerflow.subagents import executor
return importlib.reload(executor)
def test_multiple_executors_in_parallel(self, classes, base_config, msg):
"""Test multiple executors running in parallel via thread pool."""
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -988,6 +1179,68 @@ class TestThreadSafety:
assert result.status == SubagentStatus.COMPLETED
assert "Result" in result.result
def test_terminal_status_is_published_after_payload_fields(self, executor_module, monkeypatch):
"""Readers must not observe terminal status before terminal payload is complete."""
SubagentResult = executor_module.SubagentResult
SubagentStatus = executor_module.SubagentStatus
now_entered = threading.Event()
release_now = threading.Event()
completed_at = datetime(2026, 5, 1, 12, 0, 0)
writer_errors: list[BaseException] = []
class BlockingDateTime:
@staticmethod
def now():
now_entered.set()
release_now.wait(timeout=5)
return completed_at
monkeypatch.setattr(executor_module, "datetime", BlockingDateTime)
result = SubagentResult(
task_id="test-terminal-publication-order",
trace_id="test-trace",
status=SubagentStatus.RUNNING,
)
token_usage_records = [
{
"source_run_id": "run-1",
"caller": "subagent:test-agent",
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
}
]
def set_terminal():
try:
assert result.try_set_terminal(
SubagentStatus.COMPLETED,
result="done",
token_usage_records=token_usage_records,
)
except BaseException as exc:
writer_errors.append(exc)
writer = threading.Thread(target=set_terminal)
writer.start()
assert now_entered.wait(timeout=3), "try_set_terminal did not reach completed_at assignment"
assert result.completed_at is None
assert result.status == SubagentStatus.RUNNING
assert result.token_usage_records == token_usage_records
release_now.set()
writer.join(timeout=3)
assert not writer.is_alive(), "try_set_terminal did not finish"
assert writer_errors == []
assert result.completed_at == completed_at
assert result.status == SubagentStatus.COMPLETED
assert result.result == "done"
assert result.token_usage_records == token_usage_records
# -----------------------------------------------------------------------------
# Cleanup Background Task Tests
@@ -1422,6 +1675,69 @@ class TestCooperativeCancellation:
assert result.error == "Cancelled by user"
assert result.completed_at is not None
def test_late_completion_after_timeout_does_not_overwrite_timed_out(self, executor_module, classes, msg):
"""Late completion from the execution worker must not overwrite TIMED_OUT."""
SubagentExecutor = classes["SubagentExecutor"]
SubagentStatus = classes["SubagentStatus"]
short_config = classes["SubagentConfig"](
name="test-agent",
description="Test agent",
system_prompt="You are a test agent.",
max_turns=10,
timeout_seconds=0.05,
)
first_chunk_seen = threading.Event()
finish_stream = threading.Event()
execution_done = threading.Event()
async def mock_astream(*args, **kwargs):
yield {"messages": [msg.human("Task"), msg.ai("late completion", "msg-late")]}
first_chunk_seen.set()
deadline = asyncio.get_running_loop().time() + 5
while not finish_stream.is_set():
if asyncio.get_running_loop().time() >= deadline:
break
await asyncio.sleep(0.001)
mock_agent = MagicMock()
mock_agent.astream = mock_astream
executor = SubagentExecutor(
config=short_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
)
original_aexecute = executor._aexecute
async def tracked_aexecute(task, result_holder=None):
try:
return await original_aexecute(task, result_holder)
finally:
execution_done.set()
with patch.object(executor, "_create_agent", return_value=mock_agent), patch.object(executor, "_aexecute", tracked_aexecute):
task_id = executor.execute_async("Task")
assert first_chunk_seen.wait(timeout=3), "stream did not yield initial chunk"
result = executor_module._background_tasks[task_id]
assert result.cancel_event.wait(timeout=3), "timeout handler did not request cancellation"
assert result.status.value == SubagentStatus.TIMED_OUT.value
timed_out_error = result.error
timed_out_completed_at = result.completed_at
finish_stream.set()
assert execution_done.wait(timeout=3), "execution worker did not finish"
result = executor_module._background_tasks.get(task_id)
assert result is not None
assert result.status.value == SubagentStatus.TIMED_OUT.value
assert result.result is None
assert result.error == timed_out_error
assert result.completed_at == timed_out_completed_at
def test_cleanup_removes_cancelled_task(self, executor_module, classes):
"""Test that cleanup removes a CANCELLED task (terminal state)."""
SubagentResult = classes["SubagentResult"]
@@ -0,0 +1,161 @@
"""Tests for SubagentTokenCollector callback handler."""
from unittest.mock import MagicMock
from uuid import uuid4
from deerflow.subagents.token_collector import SubagentTokenCollector
def _make_llm_response(content="Hello", usage=None):
"""Create a mock LLM response with a message."""
msg = MagicMock()
msg.content = content
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
def _make_llm_response_from_usages(usages):
"""Create a mock LLM response with one generation per usage entry."""
generations = []
for usage in usages:
msg = MagicMock()
msg.content = "chunk"
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
generations.append([gen])
response = MagicMock()
response.generations = generations
return response
class TestSubagentTokenCollector:
def test_collects_usage_from_response(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["caller"] == "subagent:test"
assert records[0]["input_tokens"] == 100
assert records[0]["output_tokens"] == 50
assert records[0]["total_tokens"] == 150
assert "source_run_id" in records[0]
def test_total_tokens_zero_uses_input_plus_output(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 200, "output_tokens": 100, "total_tokens": 0}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["total_tokens"] == 300
def test_total_tokens_missing_uses_input_plus_output(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 30, "output_tokens": 20}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["total_tokens"] == 50
def test_dedup_same_run_id(self):
collector = SubagentTokenCollector(caller="subagent:test")
run_id = uuid4()
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id)
collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id)
records = collector.snapshot_records()
assert len(records) == 1
def test_no_usage_no_record(self):
collector = SubagentTokenCollector(caller="subagent:test")
collector.on_llm_end(_make_llm_response("Hi", usage=None), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
def test_zero_usage_no_record(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
def test_skips_empty_generation_and_records_later_usage(self):
collector = SubagentTokenCollector(caller="subagent:test")
response = _make_llm_response_from_usages(
[
None,
{"input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
]
)
collector.on_llm_end(response, run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["total_tokens"] == 30
def test_snapshot_returns_copy(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
snap1 = collector.snapshot_records()
snap2 = collector.snapshot_records()
assert snap1 == snap2
assert snap1 is not snap2
# Mutating snapshot does not affect internal records
snap1.append({"source_run_id": "fake"})
assert len(collector.snapshot_records()) == 1
def test_multiple_calls_accumulate(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4())
collector.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 2
def test_different_run_ids_accumulate_separately(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
collector.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4())
collector.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 2
assert records[0]["total_tokens"] == 15
assert records[1]["total_tokens"] == 30
def test_message_without_usage_metadata_skipped(self):
"""A response where message has no usage_metadata attribute must be skipped."""
collector = SubagentTokenCollector(caller="subagent:test")
msg = MagicMock(spec=[]) # object without usage_metadata
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
collector.on_llm_end(response, run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
def test_generation_without_message_skipped(self):
"""A generation without a message attribute must be skipped."""
collector = SubagentTokenCollector(caller="subagent:test")
gen = MagicMock(spec=[]) # object without message
response = MagicMock()
response.generations = [[gen]]
collector.on_llm_end(response, run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
+26 -1
View File
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
)
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
def _runtime(
thread_id: str | None = "thread-1",
agent_name: str | None = None,
user_id: str | None = None,
) -> SimpleNamespace:
context = {}
if thread_id is not None:
context["thread_id"] = thread_id
if agent_name is not None:
context["agent_name"] = agent_name
if user_id is not None:
context["user_id"] = user_id
return SimpleNamespace(context=context)
@@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="main",
agent_name="researcher",
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
+365 -47
View File
@@ -59,12 +59,15 @@ def _make_result(
ai_messages: list[dict] | None = None,
result: str | None = None,
error: str | None = None,
token_usage_records: list[dict] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
status=status,
ai_messages=ai_messages or [],
result=result,
error=error,
token_usage_records=token_usage_records or [],
usage_reported=False,
)
@@ -729,17 +732,27 @@ def test_cleanup_called_on_timed_out(monkeypatch):
def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
"""Verify cleanup_background_task is NOT called on polling safety timeout.
"""Verify cleanup_background_task is NOT called directly on polling safety timeout.
This prevents race conditions where the background task is still running
but the polling loop gives up. The cleanup should happen later when the
executor completes and sets a terminal status.
The task is still RUNNING so it cannot be safely removed yet. Instead,
cooperative cancellation is requested and a deferred cleanup is scheduled.
"""
config = _make_subagent_config()
# Keep max_poll_count small for test speed: (1 + 60) // 5 = 12
config.timeout_seconds = 1
events = []
cleanup_calls = []
cancel_requests = []
scheduled_cleanups = []
class DummyCleanupTask:
def add_done_callback(self, _callback):
return None
def fake_create_task(coro):
scheduled_cleanups.append(coro)
coro.close()
return DummyCleanupTask()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
@@ -756,12 +769,18 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
monkeypatch.setattr(
task_tool_module,
"request_cancel_background_task",
lambda task_id: cancel_requests.append(task_id),
)
output = _run_task_tool(
runtime=_make_runtime(),
@@ -772,27 +791,36 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
)
assert output.startswith("Task polling timed out after 0 minutes")
# cleanup should NOT be called because the task is still RUNNING
# cleanup_background_task must NOT be called directly (task is still RUNNING)
assert cleanup_calls == []
# cooperative cancellation must be requested
assert cancel_requests == ["tc-no-cleanup-safety-timeout"]
# a deferred cleanup coroutine must be scheduled
assert len(scheduled_cleanups) == 1
def test_cleanup_scheduled_on_cancellation(monkeypatch):
"""Verify cancellation schedules deferred cleanup for the background task."""
"""Verify cancellation handler synchronously cleans up after shielded wait."""
config = _make_subagent_config()
events = []
cleanup_calls = []
scheduled_cleanup_coros = []
poll_count = 0
def get_result(_: str):
nonlocal poll_count
poll_count += 1
if poll_count == 1:
# Main loop polls RUNNING twice, then shielded wait gets COMPLETED
if poll_count <= 2:
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
return _make_result(FakeSubagentStatus.COMPLETED, result="done")
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
sleep_count = 0
async def cancel_on_second_sleep(_: float) -> None:
nonlocal sleep_count
sleep_count += 1
if sleep_count == 2:
raise asyncio.CancelledError
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
@@ -804,12 +832,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_second_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
@@ -826,25 +849,48 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
tool_call_id="tc-cancelled-cleanup",
)
assert cleanup_calls == []
assert len(scheduled_cleanup_coros) == 1
asyncio.run(scheduled_cleanup_coros.pop())
# Cleanup happens synchronously within the cancellation handler
assert cleanup_calls == ["tc-cancelled-cleanup"]
def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
"""Verify deferred cleanup gives up after a bounded number of polls."""
"""Verify cancellation handler survives a shielded-wait timeout gracefully.
When the subagent never reaches a terminal state, the shielded wait times
out (or is interrupted), the handler reports whatever usage it can, calls
cleanup (which is a no-op for non-terminal tasks), and re-raises.
"""
config = _make_subagent_config()
config.timeout_seconds = 1
events = []
report_calls = []
cleanup_calls = []
scheduled_cleanup_coros = []
scheduled_cleanups = []
# Always return RUNNING — subagent never finishes
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
def fake_report_subagent_usage(runtime, result):
report_calls.append((runtime, result))
class DummyCleanupTask:
def __init__(self, coro):
self.coro = coro
def add_done_callback(self, callback):
self.callback = callback
def fake_create_task(coro):
scheduled_cleanups.append(coro)
coro.close()
return DummyCleanupTask(coro)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
@@ -852,19 +898,10 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
)
monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
@@ -881,13 +918,73 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
tool_call_id="tc-cancelled-timeout",
)
async def bounded_sleep(_seconds: float) -> None:
return None
monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep)
asyncio.run(scheduled_cleanup_coros.pop())
# Non-terminal tasks cannot be cleaned immediately; a deferred cleanup
# keeps polling after the parent cancellation path exits.
assert cleanup_calls == []
assert len(scheduled_cleanups) == 1
# _report_subagent_usage is called (but skips because result has no records)
assert len(report_calls) == 1
def test_cancellation_wait_uses_subagent_polling_budget(monkeypatch):
"""Cancelled parent waits on the existing subagent polling budget, not a fixed timeout."""
config = _make_subagent_config()
events = []
report_calls = []
cleanup_calls = []
sleep_count = 0
result_polls = 0
terminal_result = _make_result(FakeSubagentStatus.COMPLETED, result="done")
def get_result(_: str):
nonlocal result_polls
result_polls += 1
if result_polls < 5:
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
return terminal_result
async def cancel_then_continue(_: float) -> None:
nonlocal sleep_count
sleep_count += 1
if sleep_count == 1:
raise asyncio.CancelledError
def fake_report_subagent_usage(runtime, result):
report_calls.append((runtime, result))
async def fail_on_fixed_timeout(awaitable, *, timeout=None):
raise AssertionError(f"cancellation wait should not use fixed timeout={timeout}")
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_then_continue)
monkeypatch.setattr(task_tool_module.asyncio, "wait_for", fail_on_fixed_timeout)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
with pytest.raises(asyncio.CancelledError):
_run_task_tool(
runtime=_make_runtime(),
description="执行任务",
prompt="cancel task",
subagent_type="general-purpose",
tool_call_id="tc-cancel-budget",
)
assert report_calls == [(_make_runtime(), terminal_result)]
assert cleanup_calls == ["tc-cancel-budget"]
def test_cancellation_calls_request_cancel(monkeypatch):
@@ -895,7 +992,6 @@ def test_cancellation_calls_request_cancel(monkeypatch):
config = _make_subagent_config()
events = []
cancel_requests = []
scheduled_cleanup_coros = []
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
@@ -915,11 +1011,6 @@ def test_cancellation_calls_request_cancel(monkeypatch):
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: (coro.close(), scheduled_cleanup_coros.append(None))[-1] or _DummyScheduledTask(),
)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
@@ -987,3 +1078,230 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
assert output == "Task cancelled by user."
assert any(e.get("type") == "task_cancelled" for e in events)
assert cleanup_calls == ["tc-poll-cancelled"]
def test_cancellation_reports_subagent_usage(monkeypatch):
"""Verify cancellation handler waits (shielded) for subagent terminal state,
then reports the final token usage before re-raising CancelledError.
The report must happen synchronously within the cancellation handler so
the parent worker's finally block sees the updated journal totals.
"""
config = _make_subagent_config()
events = []
report_calls = []
cleanup_calls = []
# Terminal result with token usage collected after cancellation processing
cancel_result = _make_result(FakeSubagentStatus.CANCELLED, error="Cancelled by user")
cancel_result.token_usage_records = [{"source_run_id": "sub-run-1", "caller": "subagent:gp", "input_tokens": 50, "output_tokens": 25, "total_tokens": 75}]
cancel_result.usage_reported = False
poll_count = 0
def get_result(_: str):
nonlocal poll_count
poll_count += 1
# Main loop polls 3 times (RUNNING each time to keep looping)
if poll_count <= 3:
running = _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
running.token_usage_records = []
running.usage_reported = False
return running
# Shielded wait poll gets the terminal result
return cancel_result
sleep_count = 0
async def cancel_on_third_sleep(_: float) -> None:
nonlocal sleep_count
sleep_count += 1
if sleep_count == 3:
raise asyncio.CancelledError
def fake_report_subagent_usage(runtime, result):
report_calls.append((runtime, result))
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_third_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(task_tool_module, "request_cancel_background_task", lambda _: None)
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
with pytest.raises(asyncio.CancelledError):
_run_task_tool(
runtime=_make_runtime(),
description="执行任务",
prompt="cancel me",
subagent_type="general-purpose",
tool_call_id="tc-cancel-report",
)
# _report_subagent_usage is called synchronously within the cancellation
# handler (after the shielded wait), before CancelledError is re-raised.
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]
@pytest.mark.parametrize(
"status, expected_type",
[
(FakeSubagentStatus.COMPLETED, "task_completed"),
(FakeSubagentStatus.FAILED, "task_failed"),
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
"""Terminal task events include a usage summary from token_usage_records."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
records = [
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
]
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-usage",
)
terminal_events = [e for e in events if e["type"] == expected_type]
assert len(terminal_events) == 1
assert terminal_events[0]["usage"] == {
"input_tokens": 300,
"output_tokens": 130,
"total_tokens": 430,
}
def test_terminal_event_usage_none_when_no_records(monkeypatch):
"""Terminal event has usage=None when token_usage_records is empty."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-no-records",
)
completed = [e for e in events if e["type"] == "task_completed"]
assert len(completed) == 1
assert completed[0]["usage"] is None
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=FileNotFoundError("missing config")),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
runtime = _make_runtime(app_config=app_config)
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
task_tool_module._subagent_usage_cache.clear()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-disabled-cache",
)
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
runtime = _make_runtime(app_config=app_config)
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
with pytest.raises(RuntimeError, match="poll failed"):
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-error",
)
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
@@ -0,0 +1,91 @@
"""Regression tests for _find_usage_recorder callback shape handling.
Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as
an ``AsyncCallbackManager`` (instead of a plain list), the previous
``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object
is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task``
tool call into an error ToolMessage, losing the subagent result.
"""
from types import SimpleNamespace
from langchain_core.callbacks import AsyncCallbackManager, CallbackManager
from deerflow.tools.builtins.task_tool import _find_usage_recorder
class _RecorderHandler:
def record_external_llm_usage_records(self, records):
self.records = records
class _OtherHandler:
pass
def _make_runtime(callbacks):
return SimpleNamespace(config={"callbacks": callbacks})
def test_find_usage_recorder_with_plain_list():
recorder = _RecorderHandler()
runtime = _make_runtime([_OtherHandler(), recorder])
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_with_async_callback_manager():
"""LangChain wraps callbacks in AsyncCallbackManager for async tool runs.
The old implementation raised TypeError here. The recorder lives on
``manager.handlers``; we must look there too.
"""
recorder = _RecorderHandler()
manager = AsyncCallbackManager(handlers=[_OtherHandler(), recorder])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_with_sync_callback_manager():
"""Sync flavor of the same wrapper used by some langchain code paths."""
recorder = _RecorderHandler()
manager = CallbackManager(handlers=[recorder])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_returns_none_when_no_recorder():
manager = AsyncCallbackManager(handlers=[_OtherHandler()])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_handles_empty_manager():
manager = AsyncCallbackManager(handlers=[])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_for_none_runtime():
assert _find_usage_recorder(None) is None
def test_find_usage_recorder_returns_none_when_callbacks_is_none():
runtime = _make_runtime(None)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_for_single_handler_object():
"""A single handler instance (not wrapped in a list or manager) should not crash.
LangChain's contract is that ``config["callbacks"]`` is a list-or-manager,
but we treat any other shape defensively rather than letting a ``for`` loop
blow up at runtime.
"""
runtime = _make_runtime(_RecorderHandler())
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_when_config_not_dict():
"""Defensive: a runtime without a dict-shaped config should not raise."""
runtime = SimpleNamespace(config="not-a-dict")
assert _find_usage_recorder(runtime) is None
+438 -66
View File
@@ -1,28 +1,25 @@
"""Tests for ThreadMetaRepository (SQLAlchemy-backed)."""
import logging
import pytest
from deerflow.persistence.thread_meta import ThreadMetaRepository
from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository
async def _make_repo(tmp_path):
from deerflow.persistence.engine import get_session_factory, init_engine
@pytest.fixture
async def repo(tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
return ThreadMetaRepository(get_session_factory())
async def _cleanup():
from deerflow.persistence.engine import close_engine
yield ThreadMetaRepository(get_session_factory())
await close_engine()
class TestThreadMetaRepository:
@pytest.mark.anyio
async def test_create_and_get(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_and_get(self, repo):
record = await repo.create("t1")
assert record["thread_id"] == "t1"
assert record["status"] == "idle"
@@ -31,148 +28,523 @@ class TestThreadMetaRepository:
fetched = await repo.get("t1")
assert fetched is not None
assert fetched["thread_id"] == "t1"
await _cleanup()
@pytest.mark.anyio
async def test_create_with_assistant_id(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_with_assistant_id(self, repo):
record = await repo.create("t1", assistant_id="agent1")
assert record["assistant_id"] == "agent1"
await _cleanup()
@pytest.mark.anyio
async def test_create_with_owner_and_display_name(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_with_owner_and_display_name(self, repo):
record = await repo.create("t1", user_id="user1", display_name="My Thread")
assert record["user_id"] == "user1"
assert record["display_name"] == "My Thread"
await _cleanup()
@pytest.mark.anyio
async def test_create_with_metadata(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_with_metadata(self, repo):
record = await repo.create("t1", metadata={"key": "value"})
assert record["metadata"] == {"key": "value"}
await _cleanup()
@pytest.mark.anyio
async def test_get_nonexistent(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_get_nonexistent(self, repo):
assert await repo.get("nonexistent") is None
await _cleanup()
@pytest.mark.anyio
async def test_check_access_no_record_allows(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_no_record_allows(self, repo):
assert await repo.check_access("unknown", "user1") is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_owner_matches(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_owner_matches(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user1") is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_owner_mismatch(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_owner_mismatch(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user2") is False
await _cleanup()
@pytest.mark.anyio
async def test_check_access_no_owner_allows_all(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_no_owner_allows_all(self, repo):
# Explicit user_id=None to bypass the new AUTO default that
# would otherwise pick up the test user from the autouse fixture.
await repo.create("t1", user_id=None)
assert await repo.check_access("t1", "anyone") is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_missing_row_denied(self, tmp_path):
async def test_check_access_strict_missing_row_denied(self, repo):
"""require_existing=True flips the missing-row case to *denied*.
Closes the delete-idempotence cross-user gap: after a thread is
deleted, the row is gone, and the permissive default would let any
caller "claim" it as untracked. The strict mode demands a row.
"""
repo = await _make_repo(tmp_path)
assert await repo.check_access("never-existed", "user1", require_existing=True) is False
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_strict_owner_match_allowed(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user1", require_existing=True) is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_strict_owner_mismatch_denied(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user2", require_existing=True) is False
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
async def test_check_access_strict_null_owner_still_allowed(self, repo):
"""Even in strict mode, a row with NULL user_id stays shared.
The strict flag tightens the *missing row* case, not the *shared
row* case legacy pre-auth rows that survived a clean migration
without an owner are still everyone's.
"""
repo = await _make_repo(tmp_path)
await repo.create("t1", user_id=None)
assert await repo.check_access("t1", "anyone", require_existing=True) is True
await _cleanup()
@pytest.mark.anyio
async def test_update_status(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_status(self, repo):
await repo.create("t1")
await repo.update_status("t1", "busy")
record = await repo.get("t1")
assert record["status"] == "busy"
await _cleanup()
@pytest.mark.anyio
async def test_delete(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_delete(self, repo):
await repo.create("t1")
await repo.delete("t1")
assert await repo.get("t1") is None
await _cleanup()
@pytest.mark.anyio
async def test_delete_nonexistent_is_noop(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_delete_nonexistent_is_noop(self, repo):
await repo.delete("nonexistent") # should not raise
await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_merges(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_metadata_merges(self, repo):
await repo.create("t1", metadata={"a": 1, "b": 2})
await repo.update_metadata("t1", {"b": 99, "c": 3})
record = await repo.get("t1")
# Existing key preserved, overlapping key overwritten, new key added
assert record["metadata"] == {"a": 1, "b": 99, "c": 3}
await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_on_empty(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_metadata_on_empty(self, repo):
await repo.create("t1")
await repo.update_metadata("t1", {"k": "v"})
record = await repo.get("t1")
assert record["metadata"] == {"k": "v"}
await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_metadata_nonexistent_is_noop(self, repo):
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
await _cleanup()
# --- search with metadata filter (SQL push-down) ---
@pytest.mark.anyio
async def test_search_metadata_filter_string(self, repo):
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
await repo.create("t3", metadata={"env": "prod", "region": "us"})
results = await repo.search(metadata={"env": "prod"})
ids = {r["thread_id"] for r in results}
assert ids == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_filter_numeric(self, repo):
await repo.create("t1", metadata={"priority": 1})
await repo.create("t2", metadata={"priority": 2})
await repo.create("t3", metadata={"priority": 1, "extra": "x"})
results = await repo.search(metadata={"priority": 1})
ids = {r["thread_id"] for r in results}
assert ids == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_filter_multiple_keys(self, repo):
await repo.create("t1", metadata={"env": "prod", "region": "us"})
await repo.create("t2", metadata={"env": "prod", "region": "eu"})
await repo.create("t3", metadata={"env": "staging", "region": "us"})
results = await repo.search(metadata={"env": "prod", "region": "us"})
assert len(results) == 1
assert results[0]["thread_id"] == "t1"
@pytest.mark.anyio
async def test_search_metadata_no_match(self, repo):
await repo.create("t1", metadata={"env": "prod"})
results = await repo.search(metadata={"env": "dev"})
assert results == []
@pytest.mark.anyio
async def test_search_metadata_pagination_correct(self, repo):
"""Regression: SQL push-down makes limit/offset exact even when most rows don't match."""
for i in range(30):
meta = {"target": "yes"} if i % 3 == 0 else {"target": "no"}
await repo.create(f"t{i:03d}", metadata=meta)
# Total matching rows: i in {0,3,6,9,12,15,18,21,24,27} = 10 rows
all_matches = await repo.search(metadata={"target": "yes"}, limit=100)
assert len(all_matches) == 10
# Paginate: first page
page1 = await repo.search(metadata={"target": "yes"}, limit=3, offset=0)
assert len(page1) == 3
# Paginate: second page
page2 = await repo.search(metadata={"target": "yes"}, limit=3, offset=3)
assert len(page2) == 3
# No overlap between pages
page1_ids = {r["thread_id"] for r in page1}
page2_ids = {r["thread_id"] for r in page2}
assert page1_ids.isdisjoint(page2_ids)
# Last page
page_last = await repo.search(metadata={"target": "yes"}, limit=3, offset=9)
assert len(page_last) == 1
@pytest.mark.anyio
async def test_search_metadata_with_status_filter(self, repo):
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "prod"})
await repo.update_status("t1", "busy")
results = await repo.search(metadata={"env": "prod"}, status="busy")
assert len(results) == 1
assert results[0]["thread_id"] == "t1"
@pytest.mark.anyio
async def test_search_without_metadata_still_works(self, repo):
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2")
results = await repo.search(limit=10)
assert len(results) == 2
@pytest.mark.anyio
async def test_search_metadata_missing_key_no_match(self, repo):
"""Rows without the requested metadata key should not match."""
await repo.create("t1", metadata={"other": "val"})
await repo.create("t2", metadata={"env": "prod"})
results = await repo.search(metadata={"env": "prod"})
assert len(results) == 1
assert results[0]["thread_id"] == "t2"
@pytest.mark.anyio
async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog):
"""When ALL metadata keys are unsafe, raises InvalidMetadataFilterError."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info:
await repo.search(metadata={"bad;key": "x"})
assert any("bad;key" in r.message for r in caplog.records)
# Subclass of ValueError for backward compatibility
assert isinstance(exc_info.value, ValueError)
@pytest.mark.anyio
async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog):
"""Valid keys filter rows; only the invalid key is warned and skipped."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
results = await repo.search(metadata={"env": "prod", "bad;key": "x"})
ids = {r["thread_id"] for r in results}
assert ids == {"t1"}
assert any("bad;key" in r.message for r in caplog.records)
@pytest.mark.anyio
async def test_search_metadata_filter_boolean(self, repo):
"""True matches only boolean true, not integer 1."""
await repo.create("t1", metadata={"active": True})
await repo.create("t2", metadata={"active": False})
await repo.create("t3", metadata={"active": True, "extra": "x"})
await repo.create("t4", metadata={"active": 1})
results = await repo.search(metadata={"active": True})
ids = {r["thread_id"] for r in results}
assert ids == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_filter_none(self, repo):
"""Only rows with explicit JSON null match; missing key does not."""
await repo.create("t1", metadata={"tag": None})
await repo.create("t2", metadata={"tag": "present"})
await repo.create("t3", metadata={"other": "val"})
results = await repo.search(metadata={"tag": None})
ids = {r["thread_id"] for r in results}
assert ids == {"t1"}
@pytest.mark.anyio
async def test_search_metadata_non_string_key_skipped(self, repo, caplog):
"""Non-string keys raise ValueError from isinstance check; should be warned and skipped."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search(metadata={1: "x"})
assert any("1" in r.message for r in caplog.records)
@pytest.mark.anyio
async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog):
"""Unsupported value types (list, dict) raise TypeError; should be warned and skipped."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search(metadata={"env": ["prod", "staging"]})
@pytest.mark.anyio
async def test_search_metadata_dotted_key_raises(self, repo, caplog):
"""Dotted keys are rejected; when ALL keys are dotted, raises ValueError."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search(metadata={"a.b": "anything"})
assert any("a.b" in r.message for r in caplog.records)
# --- dialect-aware type-safe filtering edge cases ---
@pytest.mark.anyio
async def test_search_metadata_bool_vs_int_distinction(self, repo):
"""True must not match 1; False must not match 0."""
await repo.create("bool_true", metadata={"flag": True})
await repo.create("bool_false", metadata={"flag": False})
await repo.create("int_one", metadata={"flag": 1})
await repo.create("int_zero", metadata={"flag": 0})
true_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": True})}
assert true_hits == {"bool_true"}
false_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": False})}
assert false_hits == {"bool_false"}
@pytest.mark.anyio
async def test_search_metadata_int_does_not_match_bool(self, repo):
"""Integer 1 must not match boolean True."""
await repo.create("bool_true", metadata={"val": True})
await repo.create("int_one", metadata={"val": 1})
hits = {r["thread_id"] for r in await repo.search(metadata={"val": 1})}
assert hits == {"int_one"}
@pytest.mark.anyio
async def test_search_metadata_none_excludes_missing_key(self, repo):
"""Filtering by None matches explicit JSON null only, not missing key or empty {}."""
await repo.create("explicit_null", metadata={"k": None})
await repo.create("missing_key", metadata={"other": "x"})
await repo.create("empty_obj", metadata={})
hits = {r["thread_id"] for r in await repo.search(metadata={"k": None})}
assert hits == {"explicit_null"}
@pytest.mark.anyio
async def test_search_metadata_float_value(self, repo):
await repo.create("t1", metadata={"score": 3.14})
await repo.create("t2", metadata={"score": 2.71})
await repo.create("t3", metadata={"score": 3.14})
hits = {r["thread_id"] for r in await repo.search(metadata={"score": 3.14})}
assert hits == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_mixed_types_same_key(self, repo):
"""Each type query only matches its own type, even when the key is shared."""
await repo.create("str_row", metadata={"x": "hello"})
await repo.create("int_row", metadata={"x": 42})
await repo.create("bool_row", metadata={"x": True})
await repo.create("null_row", metadata={"x": None})
assert {r["thread_id"] for r in await repo.search(metadata={"x": "hello"})} == {"str_row"}
assert {r["thread_id"] for r in await repo.search(metadata={"x": 42})} == {"int_row"}
assert {r["thread_id"] for r in await repo.search(metadata={"x": True})} == {"bool_row"}
assert {r["thread_id"] for r in await repo.search(metadata={"x": None})} == {"null_row"}
@pytest.mark.anyio
async def test_search_metadata_large_int_precision(self, repo):
"""Integers beyond float precision (> 2**53) must match exactly."""
large = 2**53 + 1
await repo.create("t1", metadata={"id": large})
await repo.create("t2", metadata={"id": large - 1})
hits = {r["thread_id"] for r in await repo.search(metadata={"id": large})}
assert hits == {"t1"}
class TestJsonMatchCompilation:
"""Verify compiled SQL for both SQLite and PostgreSQL dialects."""
def test_json_match_compiles_sqlite(self):
from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
engine = create_engine("sqlite://")
cases = [
(None, "json_type(t.data, '$.\"k\"') = 'null'"),
(True, "json_type(t.data, '$.\"k\"') = 'true'"),
(False, "json_type(t.data, '$.\"k\"') = 'false'"),
]
for value, expected_fragment in cases:
expr = json_match(t.c.data, "k", value)
sql = expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})
assert str(sql) == expected_fragment, f"value={value!r}: {sql}"
# int: uses INTEGER cast for precision, type-check narrows to 'integer' only
int_expr = json_match(t.c.data, "k", 42)
sql = str(int_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
assert "json_type" in sql
assert "= 'integer'" in sql
assert "INTEGER" in sql
assert "CAST" in sql
# float: uses REAL cast, type-check spans 'integer' and 'real'
float_expr = json_match(t.c.data, "k", 3.14)
sql = str(float_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
assert "json_type" in sql
assert "IN ('integer', 'real')" in sql
assert "REAL" in sql
str_expr = json_match(t.c.data, "k", "hello")
sql = str(str_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
assert "json_type" in sql
assert "'text'" in sql
def test_json_match_compiles_pg(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
dialect = postgresql.dialect()
cases = [
(None, "json_typeof(t.data -> 'k') = 'null'"),
(True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"),
(False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"),
]
for value, expected_fragment in cases:
expr = json_match(t.c.data, "k", value)
sql = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
assert str(sql) == expected_fragment, f"value={value!r}: {sql}"
# int: CASE guard prevents CAST error when 'number' also matches floats
int_expr = json_match(t.c.data, "k", 42)
sql = str(int_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "json_typeof" in sql
assert "'number'" in sql
assert "BIGINT" in sql
assert "CASE WHEN" in sql
assert "'^-?[0-9]+$'" in sql
# float: uses DOUBLE PRECISION cast
float_expr = json_match(t.c.data, "k", 3.14)
sql = str(float_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "json_typeof" in sql
assert "'number'" in sql
assert "DOUBLE PRECISION" in sql
str_expr = json_match(t.c.data, "k", "hello")
sql = str(str_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "json_typeof" in sql
assert "'string'" in sql
def test_json_match_rejects_unsafe_key(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
for bad_key in ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]:
with pytest.raises(ValueError, match="JsonMatch key must match"):
json_match(t.c.data, bad_key, "x")
# Non-string keys must also raise ValueError (not TypeError from re.match)
for non_str_key in [42, None, ("k",)]:
with pytest.raises(ValueError, match="JsonMatch key must match"):
json_match(t.c.data, non_str_key, "x")
def test_json_match_rejects_unsupported_value_type(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
for bad_value in [[], {}, object()]:
with pytest.raises(TypeError, match="JsonMatch value must be"):
json_match(t.c.data, "k", bad_value)
def test_json_match_unsupported_dialect_raises(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
expr = json_match(t.c.data, "k", "v")
with pytest.raises(NotImplementedError, match="mysql"):
str(expr.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True}))
def test_json_match_rejects_out_of_range_int(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
# boundary values must be accepted
json_match(t.c.data, "k", 2**63 - 1)
json_match(t.c.data, "k", -(2**63))
# one beyond each boundary must be rejected
for out_of_range in [2**63, -(2**63) - 1, 10**30]:
with pytest.raises(TypeError, match="out of signed 64-bit range"):
json_match(t.c.data, "k", out_of_range)
def test_compiler_raises_on_escaped_key(self):
"""Compiler raises ValueError even when __init__ validation is bypassed."""
from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
engine = create_engine("sqlite://")
elem = json_match(t.c.data, "k", "v")
elem.key = "bad.key" # bypass __init__ to simulate -O stripping assert
with pytest.raises(ValueError, match="Key escaped validation"):
str(elem.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
with pytest.raises(ValueError, match="Key escaped validation"):
str(elem.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
@@ -2,25 +2,30 @@
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.gateway.routers import thread_runs
from deerflow.runtime import RunManager
from deerflow.runtime.runs.store.memory import MemoryRunStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app(event_store=None):
def _make_app(event_store=None, run_manager=None):
"""Build a test FastAPI app with stub auth and mocked state."""
app = make_authed_test_app()
app.include_router(thread_runs.router)
if event_store is not None:
app.state.run_event_store = event_store
if run_manager is not None:
app.state.run_manager = run_manager
return app
@@ -36,6 +41,23 @@ def _make_message(seq: int) -> dict:
return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"}
def _make_store_only_run_manager() -> RunManager:
store = MemoryRunStore()
asyncio.run(
store.put(
"store-only-run",
thread_id="thread-store",
assistant_id="lead_agent",
status="running",
multitask_strategy="reject",
metadata={},
kwargs={},
created_at="2026-01-01T00:00:00+00:00",
)
)
return RunManager(store=store)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -128,3 +150,46 @@ def test_empty_data_when_no_messages():
body = response.json()
assert body["data"] == []
assert body["has_more"] is False
def test_get_run_hydrates_store_only_run():
"""GET /api/threads/{tid}/runs/{rid} should read historical store rows."""
app = _make_app(run_manager=_make_store_only_run_manager())
with TestClient(app) as client:
response = client.get("/api/threads/thread-store/runs/store-only-run")
assert response.status_code == 200
body = response.json()
assert body["run_id"] == "store-only-run"
assert body["thread_id"] == "thread-store"
assert body["status"] == "running"
def test_cancel_store_only_run_returns_409():
"""Store-only runs are readable but not cancellable by this worker."""
app = _make_app(run_manager=_make_store_only_run_manager())
with TestClient(app) as client:
response = client.post("/api/threads/thread-store/runs/store-only-run/cancel")
assert response.status_code == 409
assert "not active on this worker" in response.json()["detail"]
def test_join_store_only_run_returns_409():
"""join endpoint should return 409 for store-only runs (no local stream state)."""
app = _make_app(run_manager=_make_store_only_run_manager())
with TestClient(app) as client:
response = client.get("/api/threads/thread-store/runs/store-only-run/join")
assert response.status_code == 409
assert "not active on this worker" in response.json()["detail"]
def test_stream_store_only_run_returns_409():
"""stream endpoint (action=None) should return 409 for store-only runs."""
app = _make_app(run_manager=_make_store_only_run_manager())
with TestClient(app) as client:
response = client.get("/api/threads/thread-store/runs/store-only-run/stream")
assert response.status_code == 409
assert "not active on this worker" in response.json()["detail"]
@@ -0,0 +1,97 @@
"""Unit tests for ThreadState reducers.
Regression coverage for issue #3123: todos list disappearing after streaming
completes because a downstream node's partial state update with `todos=None`
overwrites the previously accumulated value.
"""
from typing import get_type_hints
from deerflow.agents.thread_state import (
ThreadState,
merge_artifacts,
merge_todos,
merge_viewed_images,
)
class TestMergeTodos:
"""Reducer for ThreadState.todos - keeps last non-None value."""
def test_new_value_overrides_existing(self):
existing = [{"id": 1, "text": "old", "done": False}]
new = [{"id": 1, "text": "old", "done": True}]
assert merge_todos(existing, new) == new
def test_none_new_preserves_existing(self):
"""THE KEY FIX for #3123: a node that doesn't touch todos must NOT
wipe them out by returning an implicit None."""
existing = [{"id": 1, "text": "task", "done": False}]
assert merge_todos(existing, None) == existing
def test_none_existing_accepts_new(self):
new = [{"id": 1, "text": "first todo"}]
assert merge_todos(None, new) == new
def test_both_none_returns_none(self):
assert merge_todos(None, None) is None
def test_empty_list_is_explicit_clear(self):
"""An explicit empty list means 'user cleared all todos' and must
win over the previous list."""
existing = [{"id": 1, "text": "task"}]
assert merge_todos(existing, []) == []
class TestMergeArtifacts:
"""Sanity check for the existing artifacts reducer."""
def test_dedupes_and_preserves_order(self):
assert merge_artifacts(["a", "b"], ["b", "c"]) == ["a", "b", "c"]
def test_none_new_preserves_existing(self):
assert merge_artifacts(["a"], None) == ["a"]
def test_none_existing_accepts_new(self):
assert merge_artifacts(None, ["a"]) == ["a"]
class TestMergeViewedImages:
"""Sanity check for the existing viewed_images reducer."""
def test_merges_dicts(self):
existing = {"k1": {"base64": "x", "mime_type": "image/png"}}
new = {"k2": {"base64": "y", "mime_type": "image/jpeg"}}
merged = merge_viewed_images(existing, new)
assert set(merged.keys()) == {"k1", "k2"}
def test_empty_dict_clears(self):
existing = {"k1": {"base64": "x", "mime_type": "image/png"}}
assert merge_viewed_images(existing, {}) == {}
class TestThreadStateAnnotations:
"""Regression guards: ensure reducer wiring on ThreadState fields.
These tests protect against silent regressions where a field's
``Annotated[..., reducer]`` is reverted to a plain type, which would
re-introduce bugs even when the reducer functions themselves remain
correct.
"""
def test_todos_field_is_wired_to_merge_todos(self):
"""ThreadState.todos must use merge_todos.
Without this Annotated binding, LangGraph falls back to last-value-wins
behavior, and partial state updates that omit todos will silently clear
previously streamed values.
"""
hints = get_type_hints(ThreadState, include_extras=True)
todos_hint = hints["todos"]
assert hasattr(todos_hint, "__metadata__"), "ThreadState.todos must be Annotated with a reducer"
assert merge_todos in todos_hint.__metadata__, "ThreadState.todos must be wired to merge_todos reducer (see #3123)"
def test_artifacts_field_is_wired_to_merge_artifacts(self):
"""Sanity check that existing reducer wiring is preserved."""
hints = get_type_hints(ThreadState, include_extras=True)
assert merge_artifacts in hints["artifacts"].__metadata__
+27
View File
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
},
}
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 175,
"total_input_tokens": 120,
"total_output_tokens": 55,
"total_runs": 3,
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
"by_caller": {
"lead_agent": 145,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
assert response.status_code == 200
assert response.json()["total_tokens"] == 175
assert response.json()["total_runs"] == 3
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
+54
View File
@@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore
from app.gateway.routers import threads
from deerflow.config.paths import Paths
from deerflow.persistence.thread_meta import InvalidMetadataFilterError
from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore
_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None
assert entries, "expected at least one history entry"
for entry in entries:
assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
# ── Metadata filter validation at API boundary ────────────────────────────────
def test_search_threads_rejects_invalid_key_at_api_boundary() -> None:
"""Keys that don't match [A-Za-z0-9_-]+ are rejected by the Pydantic
validator on ThreadSearchRequest.metadata 422 from both backends.
"""
app, _store, _checkpointer = _build_thread_app()
with TestClient(app) as client:
response = client.post("/api/threads/search", json={"metadata": {"bad;key": "x"}})
assert response.status_code == 422
def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None:
"""Value types outside (None, bool, int, float, str) are rejected."""
app, _store, _checkpointer = _build_thread_app()
with TestClient(app) as client:
response = client.post("/api/threads/search", json={"metadata": {"env": ["a", "b"]}})
assert response.status_code == 422
def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None:
"""If the backend still raises InvalidMetadataFilterError (defense in
depth), the handler surfaces it as HTTP 400.
"""
app, _store, _checkpointer = _build_thread_app()
thread_store = app.state.thread_store
async def _raise(**kwargs):
raise InvalidMetadataFilterError("rejected")
with TestClient(app) as client:
with patch.object(thread_store, "search", side_effect=_raise):
response = client.post("/api/threads/search", json={"metadata": {"valid_key": "x"}})
assert response.status_code == 400
assert "rejected" in response.json()["detail"]
def test_search_threads_succeeds_with_valid_metadata() -> None:
"""Sanity check: valid metadata passes through without error."""
app, _store, _checkpointer = _build_thread_app()
with TestClient(app) as client:
response = client.post("/api/threads/search", json={"metadata": {"env": "prod"}})
assert response.status_code == 200
@@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic:
assert middleware._should_generate_title(state) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12)
_set_test_title_config(max_chars=12, model_name=None)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
@@ -109,7 +109,7 @@ class TestTitleMiddlewareCoreLogic:
title = result["title"]
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, attach_tracing=False)
model.ainvoke.assert_awaited_once()
assert model.ainvoke.await_args.kwargs["config"] == {
"run_name": "title_agent",
@@ -141,6 +141,7 @@ class TestTitleMiddlewareCoreLogic:
title_middleware_module.create_chat_model.assert_called_once_with(
name="title-model",
thinking_enabled=False,
attach_tracing=False,
app_config=app_config,
)
+401 -20
View File
@@ -1,17 +1,23 @@
"""Tests for TodoMiddleware context-loss detection."""
import asyncio
from unittest.mock import MagicMock
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import PrivateAttr
from deerflow.agents.middlewares.todo_middleware import (
TodoMiddleware,
_completion_reminder_count,
_format_todos,
_has_tool_call_intent_or_error,
_reminder_in_messages,
_todos_in_messages,
)
from deerflow.agents.thread_state import ThreadState
def _ai_with_write_todos():
@@ -22,9 +28,35 @@ def _reminder_msg():
return HumanMessage(name="todo_reminder", content="reminder")
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_runtime():
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
return runtime
def _make_runtime_for(thread_id: str, run_id: str):
runtime = _make_runtime()
runtime.context = {"thread_id": thread_id, "run_id": run_id}
return runtime
@@ -161,10 +193,62 @@ def _completion_reminder_msg():
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
def _todo_completion_reminders(messages):
reminders = []
for message in messages:
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
reminders.append(message)
return reminders
def _ai_no_tool_calls():
return AIMessage(content="I'm done!")
def _ai_with_invalid_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[
{
"type": "invalid_tool_call",
"id": "write_file:36",
"name": "write_file",
"args": "{invalid",
"error": "Failed to parse tool arguments",
}
],
)
def _ai_with_raw_provider_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[],
additional_kwargs={
"tool_calls": [
{
"id": "raw-tool-call",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
}
]
},
)
def _ai_with_legacy_function_call():
return AIMessage(
content="",
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
)
def _ai_with_tool_finish_reason():
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
def _incomplete_todos():
return [
{"status": "completed", "content": "Step 1"},
@@ -194,6 +278,36 @@ class TestCompletionReminderCount:
assert _completion_reminder_count(msgs) == 1
class TestToolCallIntentOrError:
def test_false_for_plain_final_answer(self):
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
def test_true_for_structured_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
def test_true_for_invalid_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
def test_true_for_raw_provider_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
def test_true_for_legacy_function_call(self):
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
def test_true_for_tool_finish_reason(self):
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
# Sentinel for LangChain compatibility: if future AIMessage versions add
# new top-level tool/function-call fields, this test should fail. When
# it does, update `_has_tool_call_intent_or_error()` so the completion
# reminder guard explicitly decides whether each new field means "not a
# clean final answer"; the helper has a matching comment pointing back
# to this sentinel.
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
class TestAfterModel:
def test_returns_none_when_agent_still_using_tools(self):
mw = TodoMiddleware()
@@ -235,68 +349,335 @@ class TestAfterModel:
}
assert mw.after_model(state, _make_runtime()) is None
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, _make_runtime())
result = mw.after_model(state, runtime)
assert result is not None
assert result["jump_to"] == "model"
assert len(result["messages"]) == 1
reminder = result["messages"][0]
assert "messages" not in result
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_called_once()
reminder = request.override.call_args.kwargs["messages"][-1]
assert isinstance(reminder, HumanMessage)
assert reminder.name == "todo_completion_reminder"
assert reminder.additional_kwargs["hide_from_ui"] is True
assert "Step 2" in reminder.content
assert "Step 3" in reminder.content
handler.assert_called_once_with("patched-request")
def test_reminder_lists_only_incomplete_items(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, _make_runtime())
content = result["messages"][0].content
result = mw.after_model(state, runtime)
assert result is not None
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
mw.wrap_model_call(request, MagicMock(return_value="response"))
content = request.override.call_args.kwargs["messages"][-1].content
assert "Step 1" not in content # completed — should not appear
assert "Step 2" in content
assert "Step 3" in content
def test_allows_exit_after_max_reminders(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_completion_reminder_msg(),
_completion_reminder_msg(),
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is None
def test_still_sends_reminder_before_cap(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
result = mw.after_model(state, runtime)
assert result is not None
assert result["jump_to"] == "model"
def test_does_not_trigger_for_invalid_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_invalid_tool_calls()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_still_sends_reminder_before_cap(self):
def test_does_not_trigger_for_raw_provider_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [
_completion_reminder_msg(), # 1 reminder so far
_ai_no_tool_calls(),
],
"messages": [_ai_with_raw_provider_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, _make_runtime())
assert result is not None
assert result["jump_to"] == "model"
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_legacy_function_call(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_legacy_function_call()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_tool_finish_reason(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_tool_finish_reason()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
class TestAafterModel:
def test_delegates_to_sync(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
result = asyncio.run(mw.aafter_model(state, runtime))
assert result is not None
assert result["jump_to"] == "model"
assert result["messages"][0].name == "todo_completion_reminder"
assert "messages" not in result
class TestWrapModelCall:
def test_no_pending_reminder_passthrough(self):
mw = TodoMiddleware()
request = MagicMock()
request.runtime = _make_runtime()
request.messages = [HumanMessage(content="hi")]
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
def test_pending_reminder_is_injected_once(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
request.override.reset_mock()
handler.reset_mock()
handler.return_value = "second-response"
assert mw.wrap_model_call(request, handler) == "second-response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
class TestTodoMiddlewareAgentGraphIntegration:
def test_reuses_thread_state_todos_schema_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "pending"},
]
},
}
],
),
AIMessage(content="final"),
],
)
graph = create_agent(
model=model,
tools=[],
middleware=[mw],
state_schema=ThreadState,
)
result = graph.invoke(
{"messages": [("user", "create a todo")]},
context={"thread_id": "schema-thread", "run_id": "schema-run"},
)
assert result["todos"] == [{"content": "Step 1", "status": "pending"}]
def test_completion_reminder_is_transient_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
},
}
],
),
AIMessage(content="premature final 1"),
AIMessage(content="premature final 2"),
AIMessage(content="premature final 3"),
],
)
graph = create_agent(model=model, tools=[], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "finish all todos")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
)
assert len(model.seen_messages) == 4
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
assert reminders_by_call[0] == []
assert reminders_by_call[1] == []
assert len(reminders_by_call[2]) == 1
assert len(reminders_by_call[3]) == 1
assert "Step 1" not in reminders_by_call[2][0].content
assert "Step 2" in reminders_by_call[2][0].content
persisted_reminders = _todo_completion_reminders(result["messages"])
assert persisted_reminders == []
assert result["messages"][-1].content == "premature final 3"
assert result["todos"] == [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
assert mw._pending_completion_reminders == {}
assert mw._completion_reminder_counts == {}
class TestRunScopedReminderCleanup:
def test_before_agent_clears_stale_count_without_pending_reminder(self):
mw = TodoMiddleware()
stale_runtime = _make_runtime()
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
current_runtime = _make_runtime()
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
other_thread_runtime = _make_runtime()
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, other_thread_runtime) is not None
# Simulate a model call that drained the pending message, followed by an
# abnormal run end where after_agent did not clear the reminder count.
assert mw._drain_completion_reminders(stale_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
mw.before_agent({}, current_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 2
first_runtime = _make_runtime_for("thread-a", "run-a")
second_runtime = _make_runtime_for("thread-b", "run-b")
third_runtime = _make_runtime_for("thread-c", "run-c")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, first_runtime) is not None
# Simulate the normal model request path: pending reminder is consumed,
# but the run count remains until after_agent() or stale cleanup.
assert mw._drain_completion_reminders(first_runtime)
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
assert mw.after_model(state, second_runtime) is not None
assert mw.after_model(state, third_runtime) is not None
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
def test_size_guard_prunes_pending_and_count_state_together(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 1
stale_runtime = _make_runtime_for("thread-a", "run-a")
current_runtime = _make_runtime_for("thread-b", "run-b")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, current_runtime) is not None
assert mw._drain_completion_reminders(stale_runtime) == []
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
class TestAwrapModelCall:
def test_async_pending_reminder_is_injected(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = AsyncMock(return_value="response")
result = asyncio.run(mw.awrap_model_call(request, handler))
assert result == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
handler.assert_awaited_once_with("patched-request")
+48 -1
View File
@@ -1,9 +1,10 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove",
}
]
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
middleware = TokenUsageMiddleware()
first_dispatch = AIMessage(
content="",
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
)
second_dispatch = AIMessage(
content="",
tool_calls=[
{"id": "task:second-a", "name": "task", "args": {}},
{"id": "task:second-b", "name": "task", "args": {}},
],
)
messages = [
first_dispatch,
ToolMessage(content="first", tool_call_id="task:first"),
second_dispatch,
ToolMessage(content="second-a", tool_call_id="task:second-a"),
ToolMessage(content="second-b", tool_call_id="task:second-b"),
AIMessage(content="done"),
]
cached_usage = {
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
}
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
monkeypatch.setattr(
task_tool_module,
"pop_cached_subagent_usage",
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
)
result = middleware.after_model({"messages": messages}, _make_runtime())
assert result is not None
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
assert len(usage_updates) == 1
updated = usage_updates[0]
assert updated.tool_calls == second_dispatch.tool_calls
assert updated.usage_metadata == {
"input_tokens": 30,
"output_tokens": 12,
"total_tokens": 42,
}
@@ -89,3 +89,20 @@ def test_tool_args_schema_does_not_emit_pydantic_context_warning(tool_obj, extra
pydantic_warnings = [w for w in caught if "PydanticSerializationUnexpectedValue" in str(w.message)]
assert not pydantic_warnings, f"{tool_obj.name} args_schema.model_dump() emitted Pydantic context serialization warnings: {[str(w.message) for w in pydantic_warnings]}"
def test_write_file_append_is_discoverable_in_tool_schema() -> None:
"""``append`` must be visible and described in the model-facing tool schema."""
assert "append" in write_file_tool.description
append_field = write_file_tool.tool_call_schema.model_fields["append"]
assert append_field.default is False
assert append_field.description
assert "append" in append_field.description
@pytest.mark.parametrize("tool_obj", [case[0] for case in _TOOL_CASES], ids=[case[0].name for case in _TOOL_CASES])
def test_model_facing_tool_parameters_have_descriptions(tool_obj) -> None:
"""Every model-facing tool parameter should explain when and how to use it."""
missing_descriptions = [field_name for field_name, field in tool_obj.tool_call_schema.model_fields.items() if not field.description]
assert missing_descriptions == [], f"{tool_obj.name} has model-facing parameters without descriptions: {missing_descriptions}. Add an Args: section to the tool's docstring and ensure @tool(parse_docstring=True) is set."
+101 -7
View File
@@ -10,7 +10,8 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
from langchain_core.tools import BaseTool, tool
from langchain_core.tools import BaseTool, StructuredTool, tool
from pydantic import BaseModel, Field
from deerflow.tools.tools import get_available_tools
@@ -19,6 +20,10 @@ from deerflow.tools.tools import get_available_tools
# ---------------------------------------------------------------------------
class AsyncToolArgs(BaseModel):
x: int = Field(..., description="test input")
@tool
def _tool_alpha(x: str) -> str:
"""Alpha tool."""
@@ -52,14 +57,105 @@ def _make_minimal_config(tools):
config.tools = tools
config.models = []
config.tool_search.enabled = False
config.skill_evolution.enabled = False
config.sandbox = MagicMock()
config.acp_agents = {}
return config
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Config-loaded async-only tools can still be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
return f"result: {x}"
async_tool = StructuredTool(
name="async_tool",
description="Async-only test tool.",
args_schema=AsyncToolArgs,
func=None,
coroutine=async_tool_impl,
)
tool_cfg = MagicMock()
tool_cfg.name = "async_tool"
tool_cfg.group = "test"
tool_cfg.use = "tests.fake:async_tool"
mock_cfg.return_value = _make_minimal_config([tool_cfg])
with (
patch("deerflow.tools.tools.resolve_variable", return_value=async_tool),
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
):
result = get_available_tools(include_mcp=False, app_config=mock_cfg.return_value)
assert async_tool in result
assert async_tool.func is not None
assert async_tool.invoke({"x": 42}) == "result: 42"
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_subagent_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Async-only tools added through the subagent path can be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
return f"subagent: {x}"
async_tool = StructuredTool(
name="async_subagent_tool",
description="Async-only subagent test tool.",
args_schema=AsyncToolArgs,
func=None,
coroutine=async_tool_impl,
)
mock_cfg.return_value = _make_minimal_config([])
with (
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
patch("deerflow.tools.tools.SUBAGENT_TOOLS", [async_tool]),
):
result = get_available_tools(include_mcp=False, subagent_enabled=True, app_config=mock_cfg.return_value)
assert async_tool in result
assert async_tool.func is not None
assert async_tool.invoke({"x": 7}) == "subagent: 7"
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_acp_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Async-only ACP tools can be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
return f"acp: {x}"
async_tool = StructuredTool(
name="invoke_acp_agent",
description="Async-only ACP test tool.",
args_schema=AsyncToolArgs,
func=None,
coroutine=async_tool_impl,
)
config = _make_minimal_config([])
config.acp_agents = {"codex": object()}
mock_cfg.return_value = config
with (
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
patch("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", return_value=async_tool),
):
result = get_available_tools(include_mcp=False, app_config=config)
assert async_tool in result
assert async_tool.func is not None
assert async_tool.invoke({"x": 9}) == "acp: 9"
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_no_duplicates_returned(mock_bash, mock_cfg):
"""get_available_tools() never returns two tools with the same name."""
mock_cfg.return_value = _make_minimal_config([])
@@ -73,8 +169,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
def test_first_occurrence_wins(mock_bash, mock_cfg):
"""When duplicates exist, the first occurrence is kept."""
mock_cfg.return_value = _make_minimal_config([])
@@ -92,8 +187,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
"""A warning is logged for every skipped duplicate."""
import logging
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
assert captured["app_config"] is app_config
assert len(middlewares) == 6
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
# (enabled by default — see SafetyFinishReasonConfig).
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
assert len(middlewares) == 7
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
def test_wrap_tool_call_passthrough_on_success():
+2 -1
View File
@@ -5,10 +5,11 @@ from __future__ import annotations
import pytest
from deerflow.config import tracing_config as tracing_module
from deerflow.config.tracing_config import reset_tracing_config
def _reset_tracing_cache() -> None:
tracing_module._tracing_config = None
reset_tracing_config()
@pytest.fixture(autouse=True)
+5 -5
View File
@@ -12,7 +12,7 @@ from deerflow.tracing import factory as tracing_factory
@pytest.fixture(autouse=True)
def clear_tracing_env(monkeypatch):
from deerflow.config import tracing_config as tracing_module
from deerflow.config.tracing_config import reset_tracing_config
for name in (
"LANGSMITH_TRACING",
@@ -30,9 +30,9 @@ def clear_tracing_env(monkeypatch):
"LANGFUSE_BASE_URL",
):
monkeypatch.delenv(name, raising=False)
tracing_module._tracing_config = None
reset_tracing_config()
yield
tracing_module._tracing_config = None
reset_tracing_config()
def test_build_tracing_callbacks_returns_empty_list_when_disabled(monkeypatch):
@@ -114,12 +114,12 @@ def test_build_tracing_callbacks_raises_when_enabled_provider_fails(monkeypatch)
def test_build_tracing_callbacks_raises_for_explicitly_enabled_misconfigured_provider(monkeypatch):
from deerflow.config import tracing_config as tracing_module
from deerflow.config.tracing_config import reset_tracing_config
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
tracing_module._tracing_config = None
reset_tracing_config()
with pytest.raises(ValueError, match="LANGFUSE_PUBLIC_KEY"):
tracing_factory.build_tracing_callbacks()
+137
View File
@@ -0,0 +1,137 @@
"""Tests for deerflow.tracing.metadata.build_langfuse_trace_metadata."""
from __future__ import annotations
import pytest
from deerflow.tracing import metadata as tracing_metadata
@pytest.fixture(autouse=True)
def _clear_tracing_env(monkeypatch):
from deerflow.config.tracing_config import reset_tracing_config
for name in (
"LANGFUSE_TRACING",
"LANGFUSE_PUBLIC_KEY",
"LANGFUSE_SECRET_KEY",
"LANGFUSE_BASE_URL",
"LANGSMITH_TRACING",
"LANGCHAIN_TRACING_V2",
"LANGCHAIN_TRACING",
"LANGSMITH_API_KEY",
"LANGCHAIN_API_KEY",
):
monkeypatch.delenv(name, raising=False)
reset_tracing_config()
yield
reset_tracing_config()
def _enable_langfuse(monkeypatch):
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
def test_returns_empty_when_langfuse_disabled(monkeypatch):
# No env vars set → langfuse not in enabled providers.
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="t-1",
user_id="u-1",
assistant_id="lead-agent",
model_name="gpt-4o",
)
assert result == {}
def test_session_id_maps_to_thread_id(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="thread-abc",
user_id="user-42",
)
assert result["langfuse_session_id"] == "thread-abc"
def test_user_id_falls_back_to_default(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="thread-abc",
user_id=None,
)
assert result["langfuse_user_id"] == "default"
def test_user_id_explicit_value_wins(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="thread-abc",
user_id="alice@example.com",
)
assert result["langfuse_user_id"] == "alice@example.com"
def test_trace_name_uses_assistant_id_when_provided(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="t",
assistant_id="custom-agent",
)
assert result["langfuse_trace_name"] == "custom-agent"
def test_trace_name_defaults_to_lead_agent(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="t",
assistant_id=None,
)
assert result["langfuse_trace_name"] == "lead-agent"
def test_tags_include_env_and_model(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="t",
environment="production",
model_name="gpt-4o",
)
assert result["langfuse_tags"] == ["env:production", "model:gpt-4o"]
def test_tags_omitted_when_no_tag_inputs(monkeypatch):
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id="t",
user_id="u",
)
assert "langfuse_tags" not in result
def test_thread_id_none_still_produces_metadata(monkeypatch):
# Stateless run paths may not have a thread_id — we still want
# user_id / trace_name to flow through so Users page works.
_enable_langfuse(monkeypatch)
result = tracing_metadata.build_langfuse_trace_metadata(
thread_id=None,
user_id="u-1",
)
assert result["langfuse_session_id"] is None
assert result["langfuse_user_id"] == "u-1"
@@ -0,0 +1,253 @@
"""End-to-end verification for update_agent's user_id resolution.
PR #2784 hardened setup_agent to prefer runtime.context["user_id"] over the
contextvar. update_agent had the same latent gap: it unconditionally called
get_effective_user_id() at module level, so any scenario where the contextvar
was unavailable while runtime.context carried user_id (a background task
scheduled outside the request task, a worker pool that doesn't copy_context,
checkpoint resume on a different task) would silently route writes to
users/default/agents/...
These tests are load-bearing under @no_auto_user (contextvar empty):
- The negative-control test confirms the fixture actually puts the tool in
the regime where the contextvar fallback would land in users/default/.
Without that, the positive test would be vacuously satisfied.
- The positive test verifies update_agent honours runtime.context["user_id"]
injected by inject_authenticated_user_context in the gateway. Before the
fix in this PR, this test failed; now it passes.
"""
from __future__ import annotations
from contextlib import ExitStack
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
import yaml
from _agent_e2e_helpers import build_single_tool_call_model
from langchain_core.messages import HumanMessage
from app.gateway.services import (
build_run_config,
inject_authenticated_user_context,
merge_run_context_overrides,
)
from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context
def _make_request(user_id_str: str | None) -> SimpleNamespace:
user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") if user_id_str else None
return SimpleNamespace(state=SimpleNamespace(user=user))
def _assemble_config(*, body_context: dict | None, request_user_id: str | None, thread_id: str) -> dict:
config = build_run_config(thread_id, {"recursion_limit": 50}, None, assistant_id="lead_agent")
merge_run_context_overrides(config, body_context)
inject_authenticated_user_context(config, _make_request(request_user_id))
return config
def _seed_existing_agent(tmp_path: Path, user_id: str, agent_name: str, soul: str = "# Original"):
"""Pre-create an agent on disk for update_agent to overwrite."""
agent_dir = tmp_path / "users" / user_id / "agents" / agent_name
agent_dir.mkdir(parents=True, exist_ok=True)
(agent_dir / "config.yaml").write_text(
yaml.dump({"name": agent_name, "description": "old"}, allow_unicode=True),
encoding="utf-8",
)
(agent_dir / "SOUL.md").write_text(soul, encoding="utf-8")
return agent_dir
def _make_paths_mock(tmp_path: Path):
paths = MagicMock()
paths.base_dir = tmp_path
paths.agent_dir = lambda name: tmp_path / "agents" / name
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
return paths
def _patch_update_agent_dependencies(tmp_path: Path):
"""update_agent reads load_agent_config + get_app_config — stub them
minimally so the tool can run without a real config file or LLM."""
fake_model_cfg = SimpleNamespace(name="fake-model")
fake_app_cfg = MagicMock()
fake_app_cfg.get_model_config = lambda name: fake_model_cfg if name == "fake-model" else None
return [
patch(
"deerflow.tools.builtins.update_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
),
patch(
"deerflow.tools.builtins.update_agent_tool.get_app_config",
return_value=fake_app_cfg,
),
# load_agent_config (used by update_agent to read existing config) also
# reads paths via its own module-level get_paths reference. Patch it too
# or the tool returns "Agent does not exist" before touching disk.
patch(
"deerflow.config.agents_config.get_paths",
return_value=_make_paths_mock(tmp_path),
),
]
def _build_update_graph(*, soul_payload: str):
from langchain.agents import create_agent
from deerflow.tools.builtins.update_agent_tool import update_agent
fake_model = build_single_tool_call_model(
tool_name="update_agent",
tool_args={"soul": soul_payload, "description": "refined"},
tool_call_id="call_update_1",
final_text="updated",
)
return create_agent(model=fake_model, tools=[update_agent], system_prompt="updater")
# ---------------------------------------------------------------------------
# Negative control — proves the test environment puts update_agent in the
# regime where the contextvar fallback would land in default/.
# ---------------------------------------------------------------------------
@pytest.mark.no_auto_user
def test_update_agent_falls_back_to_default_when_no_inject_and_no_contextvar(tmp_path: Path):
"""No request.state.user, no contextvar — update_agent must look in
users/default/agents/. We seed the file there so the tool succeeds and
we know which directory it actually consulted."""
from langgraph.runtime import Runtime
_seed_existing_agent(tmp_path, "default", "fallback-target")
config = _assemble_config(
body_context={"agent_name": "fallback-target"},
request_user_id=None, # no auth, inject is no-op
thread_id="thread-update-1",
)
runtime_ctx = _build_runtime_context("thread-update-1", "run-1", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_update_graph(soul_payload="# Fallback Updated")
with ExitStack() as stack:
for p in _patch_update_agent_dependencies(tmp_path):
stack.enter_context(p)
graph.invoke(
{"messages": [HumanMessage(content="update fallback-target")]},
config=config,
)
soul = (tmp_path / "users" / "default" / "agents" / "fallback-target" / "SOUL.md").read_text()
assert soul == "# Fallback Updated", "Sanity: tool should have written under default/"
# ---------------------------------------------------------------------------
# Regression guard — passes on this branch, would fail on main before the fix.
# ---------------------------------------------------------------------------
@pytest.mark.no_auto_user
def test_update_agent_should_use_runtime_context_user_id_when_contextvar_missing(tmp_path: Path):
"""update_agent prefers the authenticated user_id carried in
runtime.context (placed there by inject_authenticated_user_context)
over the contextvar same contract as setup_agent (PR #2784).
Before this PR's fix, update_agent unconditionally called
get_effective_user_id() and landed in default/ whenever the contextvar
was unavailable. This test pins the corrected behaviour.
"""
from langgraph.runtime import Runtime
auth_uid = "abcdef01-2345-6789-abcd-ef0123456789"
# Seed the agent in BOTH locations so we can prove which one was opened.
auth_dir = _seed_existing_agent(tmp_path, auth_uid, "shared-name", soul="# Auth Original")
default_dir = _seed_existing_agent(tmp_path, "default", "shared-name", soul="# Default Original")
config = _assemble_config(
body_context={"agent_name": "shared-name"},
request_user_id=auth_uid,
thread_id="thread-update-2",
)
runtime_ctx = _build_runtime_context("thread-update-2", "run-2", config.get("context"), None)
assert runtime_ctx["user_id"] == auth_uid, "Pre-condition: inject must have placed user_id into runtime_ctx"
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_update_graph(soul_payload="# Auth Updated")
with ExitStack() as stack:
for p in _patch_update_agent_dependencies(tmp_path):
stack.enter_context(p)
graph.invoke(
{"messages": [HumanMessage(content="update shared-name")]},
config=config,
)
auth_soul = (auth_dir / "SOUL.md").read_text()
default_soul = (default_dir / "SOUL.md").read_text()
assert auth_soul == "# Auth Updated", f"REGRESSION: update_agent ignored runtime.context['user_id']={auth_uid!r} and routed the write to users/default/ instead. auth_soul={auth_soul!r}, default_soul={default_soul!r}"
assert default_soul == "# Default Original", "REGRESSION: update_agent corrupted the shared default-user agent. It should have written under the authenticated user's path."
# ---------------------------------------------------------------------------
# Positive — when contextvar IS the auth user (the normal HTTP case), things
# already work. Pin it as a regression guard so future refactors don't
# accidentally break the contextvar path in pursuit of the runtime-context fix.
# ---------------------------------------------------------------------------
def test_update_agent_uses_contextvar_when_present(tmp_path: Path, monkeypatch):
"""The normal HTTP case: contextvar is set by auth_middleware. This must
keep working regardless of how runtime.context is populated."""
from types import SimpleNamespace as _SN
from deerflow.runtime.user_context import reset_current_user, set_current_user
auth_uid = "11112222-3333-4444-5555-666677778888"
user = _SN(id=auth_uid, email="ctxvar@local")
_seed_existing_agent(tmp_path, auth_uid, "ctxvar-agent", soul="# Original")
from langgraph.runtime import Runtime
config = _assemble_config(
body_context={"agent_name": "ctxvar-agent"},
request_user_id=auth_uid,
thread_id="thread-update-3",
)
runtime_ctx = _build_runtime_context("thread-update-3", "run-3", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_update_graph(soul_payload="# CtxVar Updated")
with ExitStack() as stack:
for p in _patch_update_agent_dependencies(tmp_path):
stack.enter_context(p)
token = set_current_user(user)
try:
final = graph.invoke(
{"messages": [HumanMessage(content="update ctxvar-agent")]},
config=config,
)
finally:
reset_current_user(token)
# surface the tool's reply for debug if it errored
tool_replies = [m.content for m in final["messages"] if getattr(m, "type", "") == "tool"]
soul = (tmp_path / "users" / auth_uid / "agents" / "ctxvar-agent" / "SOUL.md").read_text()
assert soul == "# CtxVar Updated", f"tool replies: {tool_replies}"
+61
View File
@@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException, UploadFile
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import uploads
@@ -218,6 +219,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
provider = MagicMock()
provider.uses_thread_data_mounts = True
provider.needs_upload_permission_adjustment = False
provider.acquire.return_value = "local"
sandbox = MagicMock()
provider.get.return_value = sandbox
@@ -227,12 +229,17 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
patch.object(uploads, "get_sandbox_provider", return_value=provider),
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
patch.object(uploads, "_make_file_sandbox_readable") as make_readable,
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
make_writable.assert_not_called()
# Readable adjustment is now always applied regardless of sandbox type
make_readable.assert_called_once()
called_path = make_readable.call_args[0][0]
assert called_path.name == "notes.txt"
def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path):
@@ -430,6 +437,59 @@ def test_make_file_sandbox_writable_skips_symlinks(tmp_path):
chmod.assert_not_called()
def test_make_file_sandbox_readable_adds_read_bits_for_regular_files(tmp_path):
file_path = tmp_path / "data.csv"
file_path.write_bytes(b"csv-data")
# Simulate the 0o600 permissions set by open_upload_file_no_symlink
file_path.chmod(0o600)
uploads._make_file_sandbox_readable(file_path)
updated_mode = stat.S_IMODE(file_path.stat().st_mode)
assert updated_mode & stat.S_IRUSR
assert updated_mode & stat.S_IRGRP
assert updated_mode & stat.S_IROTH
def test_make_file_sandbox_readable_skips_symlinks(tmp_path):
file_path = tmp_path / "target-link.txt"
file_path.write_text("hello", encoding="utf-8")
symlink_stat = MagicMock(st_mode=stat.S_IFLNK)
with (
patch.object(uploads.os, "lstat", return_value=symlink_stat),
patch.object(uploads.os, "chmod") as chmod,
):
uploads._make_file_sandbox_readable(file_path)
chmod.assert_not_called()
def test_upload_files_adjusts_read_permissions_for_mounted_non_local_sandbox(tmp_path):
thread_uploads_dir = tmp_path / "uploads"
thread_uploads_dir.mkdir(parents=True)
# AIO sandbox with LocalContainerBackend: uses_thread_data_mounts=True
# but needs_upload_permission_adjustment=True (default)
provider = MagicMock()
provider.uses_thread_data_mounts = True
provider.needs_upload_permission_adjustment = True
with (
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
patch.object(uploads, "get_sandbox_provider", return_value=provider),
patch.object(uploads, "_make_file_sandbox_readable") as make_readable,
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
make_readable.assert_called_once()
called_path = make_readable.call_args[0][0]
assert called_path.name == "notes.txt"
def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
thread_uploads_dir = tmp_path / "uploads"
thread_uploads_dir.mkdir(parents=True)
@@ -631,6 +691,7 @@ def test_upload_limits_endpoint_requires_thread_access():
cfg.uploads = {}
app = make_authed_test_app(owner_check_passes=False)
app.state.config = cfg
app.dependency_overrides[get_config] = lambda: cfg
app.include_router(uploads.router)
with TestClient(app) as client:
@@ -0,0 +1,248 @@
"""Integration test: worker.run_agent injects Langfuse trace metadata.
Verifies that the agent factory's resulting graph receives a
``RunnableConfig`` whose ``metadata`` carries the Langfuse reserved keys
(``langfuse_session_id`` / ``langfuse_user_id`` / ``langfuse_trace_name``).
"""
from __future__ import annotations
import asyncio
import pytest
from deerflow.runtime.runs.manager import RunRecord
from deerflow.runtime.runs.schemas import DisconnectMode, RunStatus
from deerflow.runtime.runs.worker import RunContext, run_agent
class _FakeAgent:
"""Minimal LangGraph-like graph that captures the runnable config."""
def __init__(self) -> None:
self.captured_config: dict | None = None
self.metadata: dict = {}
# Worker may assign these attributes; need them to exist.
self.checkpointer = None
self.store = None
self.interrupt_before_nodes: list[str] = []
self.interrupt_after_nodes: list[str] = []
async def astream(self, graph_input, *, config, stream_mode, **kwargs):
self.captured_config = config
# Empty async generator — no chunks produced.
return
yield # pragma: no cover (makes this an async generator)
class _FakeRunManager:
async def set_status(self, *_args, **_kwargs) -> None:
return None
async def update_model_name(self, *_args, **_kwargs) -> None:
return None
async def update_run_completion(self, *_args, **_kwargs) -> None:
return None
class _FakeBridge:
def __init__(self) -> None:
self.events: list[tuple[str, object]] = []
async def publish(self, _run_id, event, payload) -> None:
self.events.append((event, payload))
async def publish_end(self, _run_id) -> None:
self.events.append(("end", None))
async def cleanup(self, _run_id, *, delay: int = 0) -> None:
return None
@pytest.fixture(autouse=True)
def _clear_tracing_env(monkeypatch):
from deerflow.config.tracing_config import reset_tracing_config
for name in ("LANGFUSE_TRACING", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY", "LANGFUSE_BASE_URL"):
monkeypatch.delenv(name, raising=False)
reset_tracing_config()
yield
reset_tracing_config()
@pytest.mark.asyncio
async def test_run_agent_injects_langfuse_metadata(monkeypatch):
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
from deerflow.config.tracing_config import reset_tracing_config
reset_tracing_config()
fake_agent = _FakeAgent()
def agent_factory(config):
return fake_agent
record = RunRecord(
run_id="run-1",
thread_id="thread-xyz",
assistant_id="lead-agent",
status=RunStatus.pending,
on_disconnect=DisconnectMode.cancel,
model_name="gpt-4o",
)
record.abort_event = asyncio.Event()
ctx = RunContext(checkpointer=None)
await run_agent(
_FakeBridge(),
_FakeRunManager(),
record,
ctx=ctx,
agent_factory=agent_factory,
graph_input={"messages": []},
config={"configurable": {"thread_id": "thread-xyz"}},
)
assert fake_agent.captured_config is not None, "astream was not invoked"
metadata = fake_agent.captured_config.get("metadata") or {}
assert metadata.get("langfuse_session_id") == "thread-xyz"
# conftest.py autouse fixture injects ``test-user-autouse`` into the
# contextvar — the worker should read it via ``get_effective_user_id``.
user_id = metadata.get("langfuse_user_id")
assert user_id == "test-user-autouse", f"expected test-user-autouse, got {user_id}"
assert metadata.get("langfuse_trace_name") == "lead-agent"
tags = metadata.get("langfuse_tags") or []
assert "model:gpt-4o" in tags
@pytest.mark.asyncio
async def test_run_agent_falls_back_to_default_user_when_unset(monkeypatch):
"""When no user is in the contextvar, langfuse_user_id falls back to 'default'.
Uses ``monkeypatch.setattr`` to redirect ``get_effective_user_id`` to return
``"default"`` rather than directly mutating the contextvar direct contextvar
operations across pytest test boundaries have produced spooky cross-file
pollution when combined with the langfuse OTel global tracer provider.
"""
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
from deerflow.config.tracing_config import reset_tracing_config
from deerflow.runtime.runs import worker as worker_module
from deerflow.runtime.user_context import DEFAULT_USER_ID
reset_tracing_config()
monkeypatch.setattr(worker_module, "get_effective_user_id", lambda: DEFAULT_USER_ID)
fake_agent = _FakeAgent()
def agent_factory(config):
return fake_agent
record = RunRecord(
run_id="run-fallback",
thread_id="thread-fb",
assistant_id="lead-agent",
status=RunStatus.pending,
on_disconnect=DisconnectMode.cancel,
)
record.abort_event = asyncio.Event()
ctx = RunContext(checkpointer=None)
await run_agent(
_FakeBridge(),
_FakeRunManager(),
record,
ctx=ctx,
agent_factory=agent_factory,
graph_input={"messages": []},
config={"configurable": {"thread_id": "thread-fb"}},
)
metadata = fake_agent.captured_config.get("metadata") or {}
assert metadata.get("langfuse_user_id") == "default"
@pytest.mark.asyncio
async def test_run_agent_preserves_caller_metadata_overrides(monkeypatch):
"""Caller-provided langfuse_* keys must NOT be overridden by the default injection."""
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
from deerflow.config.tracing_config import reset_tracing_config
reset_tracing_config()
fake_agent = _FakeAgent()
def agent_factory(config):
return fake_agent
record = RunRecord(
run_id="run-2",
thread_id="thread-default",
assistant_id="lead-agent",
status=RunStatus.pending,
on_disconnect=DisconnectMode.cancel,
)
record.abort_event = asyncio.Event()
ctx = RunContext(checkpointer=None)
await run_agent(
_FakeBridge(),
_FakeRunManager(),
record,
ctx=ctx,
agent_factory=agent_factory,
graph_input={"messages": []},
config={
"configurable": {"thread_id": "thread-default"},
"metadata": {
"langfuse_session_id": "custom-session-id",
"langfuse_user_id": "explicit-user",
},
},
)
metadata = fake_agent.captured_config.get("metadata") or {}
# Caller-supplied keys win.
assert metadata["langfuse_session_id"] == "custom-session-id"
assert metadata["langfuse_user_id"] == "explicit-user"
# Worker still fills in keys that the caller didn't set.
assert metadata["langfuse_trace_name"] == "lead-agent"
@pytest.mark.asyncio
async def test_run_agent_skips_metadata_when_langfuse_disabled(monkeypatch):
fake_agent = _FakeAgent()
def agent_factory(config):
return fake_agent
record = RunRecord(
run_id="run-3",
thread_id="thread-noop",
assistant_id="lead-agent",
status=RunStatus.pending,
on_disconnect=DisconnectMode.cancel,
)
record.abort_event = asyncio.Event()
ctx = RunContext(checkpointer=None)
await run_agent(
_FakeBridge(),
_FakeRunManager(),
record,
ctx=ctx,
agent_factory=agent_factory,
graph_input={"messages": []},
config={"configurable": {"thread_id": "thread-noop"}},
)
metadata = fake_agent.captured_config.get("metadata") or {}
assert "langfuse_session_id" not in metadata
assert "langfuse_user_id" not in metadata
assert "langfuse_trace_name" not in metadata