Merge branch 'main' of https://github.com/bytedance/deer-flow into rayhpeng/storage-package-base

This commit is contained in:
rayhpeng
2026-05-15 16:41:34 +08:00
74 changed files with 4445 additions and 533 deletions
+68
View File
@@ -0,0 +1,68 @@
"""Shared helpers for user-isolation e2e tests on the custom-agent tooling.
Centralises the small fake-LLM shim and a few test-data builders that the
three e2e files in this PR (``test_setup_agent_e2e_user_isolation``,
``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``)
all need. The shim is what lets a real ``langchain.agents.create_agent``
graph run without an API key — every other layer in those tests is real
production code, which is the entire point of the test design.
"""
from __future__ import annotations
from typing import Any
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import Runnable
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent.
``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to
expose the tool schemas to the model; the upstream fake raises
``NotImplementedError`` there. We just return ``self`` because we
drive deterministic tool_call output via ``responses=...``, no schema
handling needed.
"""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
def build_single_tool_call_model(
*,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str = "call_e2e_1",
final_text: str = "done",
) -> FakeToolCallingModel:
"""Build a fake model that emits exactly one tool_call then finishes.
Two-turn behaviour, identical across our e2e tests:
turn 1 → AIMessage with a single tool_call for *tool_name*
turn 2 → AIMessage with *final_text* (terminates the agent loop)
"""
return FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": tool_name,
"args": tool_args,
"id": tool_call_id,
"type": "tool_call",
}
],
),
AIMessage(content=final_text),
]
)
+93
View File
@@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
issues when unit-testing lightweight config/registry code in isolation.
"""
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
@@ -11,11 +13,16 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
# Make 'app' and 'deerflow' importable from any working directory
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
# Break the circular import chain that exists in production code:
# deerflow.subagents.__init__
# -> .executor (SubagentExecutor, SubagentResult)
@@ -56,6 +63,92 @@ def provisioner_module():
return module
@pytest.fixture()
def blocking_io_detector():
"""Fail a focused test if blocking calls run on the event loop thread."""
with detect_blocking_io(fail_on_exit=True) as detector:
yield detector
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("blocking-io")
group.addoption(
"--detect-blocking-io",
action="store_true",
default=False,
help="Collect blocking calls made while an asyncio event loop is running and report a summary.",
)
group.addoption(
"--detect-blocking-io-fail",
action="store_true",
default=False,
help="Set a failing exit status when --detect-blocking-io records violations.",
)
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe")
def pytest_sessionstart(session: pytest.Session) -> None:
if _blocking_io_probe_enabled(session.config):
_blocking_io_probe.clear()
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item):
yield
return
detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
detector.__enter__()
setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
yield
detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
if detector is None:
return
try:
detector.__exit__(None, None, None)
_blocking_io_probe.record(item.nodeid, detector.violations)
finally:
delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
def pytest_sessionfinish(session: pytest.Session) -> None:
if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK:
session.exitstatus = pytest.ExitCode.TESTS_FAILED
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
if not _blocking_io_probe_enabled(terminalreporter.config):
return
header, *details = _blocking_io_probe.format_summary().splitlines()
terminalreporter.write_sep("=", header)
for line in details:
terminalreporter.write_line(line)
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io-fail"))
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
# ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user
# ---------------------------------------------------------------------------
+1
View File
@@ -0,0 +1 @@
"""Shared test support helpers."""
@@ -0,0 +1 @@
"""Runtime and static detectors used by tests."""
@@ -0,0 +1,287 @@
"""Test helper for detecting blocking calls on an asyncio event loop.
The detector is intentionally test-only. It monkeypatches a small set of
well-known blocking entry points and their already-loaded module-level aliases,
then records calls only when they happen on a thread that is currently running
an asyncio event loop. Aliases captured in closures or default arguments remain
out of scope.
"""
from __future__ import annotations
import asyncio
import importlib
import sys
import traceback
from collections import Counter
from collections.abc import Callable, Iterable, Iterator
from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from types import TracebackType
from typing import Any
BlockingCallable = Callable[..., Any]
@dataclass(frozen=True)
class BlockingCallSpec:
"""Describes one blocking callable to wrap during a detector run."""
name: str
target: str
record_on_iteration: bool = False
@dataclass(frozen=True)
class BlockingCall:
"""One blocking call observed on an asyncio event loop thread."""
name: str
target: str
stack: tuple[traceback.FrameSummary, ...]
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
BlockingCallSpec("time.sleep", "time:sleep"),
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
)
def _is_event_loop_thread() -> bool:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return False
return loop.is_running()
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
module_name, attr_path = target.split(":", maxsplit=1)
owner: object = importlib.import_module(module_name)
parts = attr_path.split(".")
for part in parts[:-1]:
owner = getattr(owner, part)
attr_name = parts[-1]
original = getattr(owner, attr_name)
return owner, attr_name, original
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
return tuple(frame for frame in stack if frame.filename != __file__)
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
"""Record blocking calls made from async runtime code.
By default the detector reports violations but does not fail on context
exit. Tests can set ``fail_on_exit=True`` or call
``assert_no_blocking_calls()`` explicitly.
"""
def __init__(
self,
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> None:
self._specs = tuple(specs)
self._fail_on_exit = fail_on_exit
self._patch_loaded_aliases_enabled = patch_loaded_aliases
self._stack_limit = stack_limit
self._patches: list[tuple[object, str, BlockingCallable]] = []
self._patch_keys: set[tuple[int, str]] = set()
self.violations: list[BlockingCall] = []
self._active = False
def __enter__(self) -> BlockingIODetector:
try:
self._active = True
alias_replacements: dict[int, BlockingCallable] = {}
for spec in self._specs:
owner, attr_name, original = _resolve_target(spec.target)
wrapper = self._wrap(spec, original)
self._patch_attribute(owner, attr_name, original, wrapper)
alias_replacements[id(original)] = wrapper
if self._patch_loaded_aliases_enabled:
self._patch_loaded_module_aliases(alias_replacements)
except Exception:
self._restore()
self._active = False
raise
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback_value: TracebackType | None,
) -> bool | None:
self._restore()
self._active = False
if exc_type is None and self._fail_on_exit:
self.assert_no_blocking_calls()
return None
def _restore(self) -> None:
for owner, attr_name, original in reversed(self._patches):
setattr(owner, attr_name, original)
self._patches.clear()
self._patch_keys.clear()
def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
key = (id(owner), attr_name)
if key in self._patch_keys:
return
setattr(owner, attr_name, replacement)
self._patches.append((owner, attr_name, original))
self._patch_keys.add(key)
def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
for module in tuple(sys.modules.values()):
namespace = getattr(module, "__dict__", None)
if not isinstance(namespace, dict):
continue
for attr_name, value in tuple(namespace.items()):
replacement = replacements_by_id.get(id(value))
if replacement is not None:
self._patch_attribute(module, attr_name, value, replacement)
def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
@wraps(original)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if spec.record_on_iteration:
result = original(*args, **kwargs)
return self._wrap_iteration(spec, result)
self._record_if_blocking(spec)
return original(*args, **kwargs)
return wrapper
def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
iterator = iter(iterable)
reported = False
while True:
if not reported:
reported = self._record_if_blocking(spec)
try:
yield next(iterator)
except StopIteration:
return
def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
if self._active and _is_event_loop_thread():
stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
self.violations.append(BlockingCall(spec.name, spec.target, stack))
return True
return False
def assert_no_blocking_calls(self) -> None:
if self.violations:
raise AssertionError(format_blocking_calls(self.violations))
class BlockingIOProbe:
"""Collect detector output across tests and format a compact summary."""
def __init__(self, project_root: Path) -> None:
self._project_root = project_root.resolve()
self._observed: list[tuple[str, BlockingCall]] = []
@property
def violation_count(self) -> int:
return len(self._observed)
@property
def test_count(self) -> int:
return len({nodeid for nodeid, _violation in self._observed})
def clear(self) -> None:
self._observed.clear()
def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
for violation in violations:
self._observed.append((nodeid, violation))
def format_summary(self, *, limit: int = 30) -> str:
if not self._observed:
return "blocking io probe: no violations"
call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
for _nodeid, violation in self._observed:
frame = self._local_call_site(violation.stack)
if frame is None:
call_sites[(violation.name, "<unknown>", 0, "<unknown>", "")] += 1
continue
call_sites[
(
violation.name,
self._relative(frame.filename),
frame.lineno,
frame.name,
(frame.line or "").strip(),
)
] += 1
lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
return "\n".join(lines)
def _relative(self, filename: str) -> str:
try:
return str(Path(filename).resolve().relative_to(self._project_root))
except ValueError:
return filename
def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")]
if local_frames:
return local_frames[-1]
test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename]
return test_frames[-1] if test_frames else None
def detect_blocking_io(
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> BlockingIODetector:
"""Create a detector context manager for a focused test scope."""
return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
"""Format detector output with enough stack context to locate call sites."""
lines = ["Blocking calls were executed on an asyncio event loop thread:"]
for index, violation in enumerate(violations, start=1):
lines.append(f"{index}. {violation.name} ({violation.target})")
lines.extend(_format_stack(violation.stack))
return "\n".join(lines)
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
for frame in stack:
location = f"{frame.filename}:{frame.lineno}"
lines = [f" at {frame.name} ({location})"]
if frame.line:
lines.append(f" {frame.line.strip()}")
yield from lines
+190
View File
@@ -0,0 +1,190 @@
from __future__ import annotations
import asyncio
import os
import time
from os import walk as imported_walk
from pathlib import Path
from time import sleep as imported_sleep
import httpx
import pytest
import requests
from support.detectors.blocking_io import (
BlockingCallSpec,
BlockingIOProbe,
detect_blocking_io,
)
pytestmark = pytest.mark.asyncio
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
async def test_records_time_sleep_on_event_loop() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
original_alias = imported_sleep
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
imported_sleep(0)
assert imported_sleep is original_alias
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_can_disable_loaded_alias_patching() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
imported_sleep(0)
assert detector.violations == []
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
await asyncio.to_thread(time.sleep, 0)
assert detector.violations == []
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
await asyncio.to_thread(time.sleep, 0)
assert blocking_io_detector.violations == []
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
def call_sleep() -> list[str]:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
return [violation.name for violation in detector.violations]
assert await asyncio.to_thread(call_sleep) == []
async def test_fail_on_exit_includes_call_site() -> None:
with pytest.raises(AssertionError) as exc_info:
with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
time.sleep(0)
message = str(exc_info.value)
assert "time.sleep" in message
assert "test_fail_on_exit_includes_call_site" in message
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
return f"{method}:{url}"
monkeypatch.setattr(requests.sessions.Session, "request", fake_request)
with detect_blocking_io(REQUESTS_ONLY) as detector:
assert requests.get("https://example.invalid") == "get:https://example.invalid"
assert [violation.name for violation in detector.violations] == ["requests.Session.request"]
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
return httpx.Response(200, request=httpx.Request(method, url))
monkeypatch.setattr(httpx.Client, "request", fake_request)
with detect_blocking_io(HTTPX_ONLY) as detector:
with httpx.Client() as client:
response = client.get("https://example.invalid")
assert response.status_code == 200
assert [violation.name for violation in detector.violations] == ["httpx.Client.request"]
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(os.walk(tmp_path))
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
original_alias = imported_walk
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(imported_walk(tmp_path))
assert imported_walk is original_alias
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert list(walker)
assert detector.violations == []
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert await asyncio.to_thread(lambda: list(walker))
assert detector.violations == []
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
assert path.read_text(encoding="utf-8") == "content"
assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"]
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
summary = probe.format_summary()
assert "blocking io probe: 1 violations across 1 tests" in summary
assert "pathlib.Path.read_text" in summary
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
assert probe.format_summary() == "blocking io probe: no violations"
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
assert probe.violation_count == 1
probe.clear()
assert probe.violation_count == 0
assert probe.format_summary() == "blocking io probe: no violations"
@@ -0,0 +1,22 @@
from __future__ import annotations
import time
import pytest
ORIGINAL_SLEEP = time.sleep
def replacement_sleep(seconds: float) -> None:
return None
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(time, "sleep", replacement_sleep)
assert time.sleep is replacement_sleep
@pytest.mark.no_blocking_io_probe
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
assert time.sleep is ORIGINAL_SLEEP
assert getattr(time.sleep, "__wrapped__", None) is None
@@ -0,0 +1,222 @@
"""Real-LLM end-to-end verification for issue #2884.
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
and the production ``get_available_tools`` pipeline. The only thing we mock is
the MCP tool source — we hand-roll two ``@tool``s and inject them through
``deerflow.mcp.cache.get_cached_mcp_tools``.
The flow exercised:
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
that re-enters ``get_available_tools`` on the same task — this is the
code path issue #2884 reports). It must call ``tool_search`` to
discover the deferred ``fake_calculator`` tool.
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
``fake_subagent_trigger`` re-enters ``get_available_tools``.
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
so it can actually call it. Without this PR's fix, the re-entry wipes
the promotion and the model can no longer invoke the tool.
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
test run. Run with::
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
PYTHONPATH=. uv run pytest \
tests/test_deferred_tool_promotion_real_llm.py -v -s
"""
from __future__ import annotations
import os
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool as as_tool
# ---------------------------------------------------------------------------
# Skip control: only run when explicitly opted in.
# ---------------------------------------------------------------------------
pytestmark = pytest.mark.skipif(
os.getenv("ONEAPI_E2E") != "1",
reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
# ---------------------------------------------------------------------------
# Fake "MCP" tools the agent should discover via tool_search.
# Keep them obviously synthetic so the model can pattern-match the search.
# ---------------------------------------------------------------------------
_calls: list[str] = []
@as_tool
def fake_calculator(expression: str) -> str:
"""Evaluate a tiny arithmetic expression like '2 + 2'.
Reserved for the user — only call this if the user asks for arithmetic.
"""
_calls.append(f"fake_calculator:{expression}")
try:
# Trivially safe-eval just for the e2e check
allowed = set("0123456789+-*/() .")
if not set(expression) <= allowed:
return "expression contains disallowed characters"
return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307
except Exception as e:
return f"error: {e}"
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
"""Translate text into the given language code. Decorative — not used."""
_calls.append(f"fake_translator:{text}:{target_lang}")
return f"[{target_lang}] {text}"
# ---------------------------------------------------------------------------
# Pipeline wiring (same shape as the in-process tests).
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Build a minimal mock AppConfig and patch the symbol — never call the
real loader, which would trigger ``_apply_singleton_configs`` and
permanently mutate cross-test singletons (memory, title, …)."""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Real-LLM e2e test
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
"""End-to-end against a real OpenAI-compatible LLM.
The model must:
Turn 1 — see ``tool_search`` (deferred tools aren't bound yet) and
batch-call BOTH ``tool_search(select:fake_calculator)`` AND
``fake_subagent_trigger(...)``.
Turn 2 — call ``fake_calculator`` and finish.
Pass criterion: ``fake_calculator`` actually gets invoked at the tool
layer — recorded in ``_calls`` — which proves the model received the
promoted schema after the re-entrant ``get_available_tools`` call.
"""
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
_force_tool_search_enabled(monkeypatch)
_calls.clear()
@as_tool
async def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset.
Use this whenever the user asks you to delegate work — pass a short
description as ``prompt``.
"""
# ``task_tool`` does this internally. Whether the registry-reset that
# used to happen here actually leaks back to the parent task depends
# on asyncio's implicit context-copying semantics (gather creates
# child tasks with copied contexts, so reset_deferred_registry is
# task-local) — but the fix in this PR is what GUARANTEES the
# promotion sticks regardless of which integration path triggers a
# re-entrant ``get_available_tools`` call.
get_available_tools(subagent_enabled=False)
_calls.append(f"fake_subagent_trigger:{prompt}")
return "subagent completed"
tools = get_available_tools() + [fake_subagent_trigger]
model = ChatOpenAI(
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
temperature=0,
max_retries=1,
)
system_prompt = (
"You are a meticulous assistant. Available deferred tools include a "
"calculator and a translator — their schemas are hidden until you "
"search for them via tool_search.\n\n"
"Procedure for the user's request:\n"
" 1. Call tool_search with query 'select:fake_calculator' AND "
"in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
"to delegate the side work. Put both tool_calls in your first response.\n"
" 2. After both tool messages come back, call fake_calculator with "
"the user's expression.\n"
" 3. Reply with just the numeric result."
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt=system_prompt,
)
result = await graph.ainvoke(
{"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]},
config={"recursion_limit": 12},
)
print("\n=== tool calls recorded ===")
for c in _calls:
print(f" {c}")
print("\n=== final message ===")
final_text = result["messages"][-1].content if result["messages"] else "(none)"
print(f" {final_text!r}")
# The smoking-gun assertion: fake_calculator was actually invoked at the
# tool layer. This is only possible if the promoted schema reached the
# model in turn 2, despite the subagent-style re-entry in turn 1.
calc_calls = [c for c in _calls if c.startswith("fake_calculator:")]
assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"
# And the math should actually be done correctly (sanity that the LLM
# really used the result, not just hallucinated the answer).
assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
@@ -0,0 +1,390 @@
"""Reproduce + regression-guard issue #2884.
Hypothesis from the issue:
``tools.tools.get_available_tools`` unconditionally calls
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
every time it is invoked. If anything calls ``get_available_tools`` again
during the same async context (after the agent has promoted tools via
``tool_search``), the promotion is wiped and the next model call hides the
tool's schema again.
These tests pin two things:
A. **At the unit boundary** — verify the failure mode directly. Promote a
tool in the registry, then call ``get_available_tools`` again and observe
that the ContextVar registry is reset and the promotion is lost.
B. **At the graph-execution boundary** — drive a real ``create_agent`` graph
with the real ``DeferredToolFilterMiddleware`` through two model turns.
The first turn calls ``tool_search`` which promotes a tool. The second
turn must see that tool's schema in ``request.tools``. If
``get_available_tools`` were to run again between the two turns and reset
the registry, the second turn's filter would strip the tool.
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
unmodified; mock only the LLM and the MCP tool source. Patch
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
``get_available_tools`` resolves via lazy import) to return our fixture
tools so we don't need a real MCP server.
"""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
# ---------------------------------------------------------------------------
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
# ---------------------------------------------------------------------------
@as_tool
def fake_mcp_search(query: str) -> str:
"""Pretend to search a knowledge base for the given query."""
return f"results for {query}"
@as_tool
def fake_mcp_fetch(url: str) -> str:
"""Pretend to fetch a page at the given URL."""
return f"content of {url}"
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
"""Each test must start with a clean ContextVar.
The registry lives in a module-level ContextVar with no per-task isolation
in a synchronous test runner, so one test's promotion can leak into the
next and silently break filter assertions.
"""
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
"""Make get_available_tools believe an MCP server is registered.
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
that both ``AppConfig.from_file`` (which calls
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
see a valid instance. Then point the MCP tool cache at our fixture tools.
"""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Force config.tool_search.enabled=True without touching the yaml.
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
which permanently mutates module-level singletons (``_memory_config``,
``_title_config``, …) to match the developer's ``config.yaml`` — even
after pytest restores our patch. That leaks across tests later in the
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
require ``_memory_config.enabled = True``, which is the dataclass default
but FALSE in the actual yaml).
Build a minimal mock AppConfig instead and never call the real loader.
"""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Section A — direct unit-level reproduction
# ---------------------------------------------------------------------------
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
Step 1: call get_available_tools() — registers MCP tools as deferred.
Step 2: simulate the agent calling tool_search by promoting one tool.
Step 3: call get_available_tools() again (the same code path
``task_tool`` exercises mid-run).
Assertion: after step 3, the promoted tool is STILL promoted (not
re-deferred). On ``main`` before the fix, step 3's
``reset_deferred_registry()`` wiped the promotion and re-registered
every MCP tool as deferred — this assertion fired with
``REGRESSION (#2884)``.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# Step 1: first call — both MCP tools start deferred
get_available_tools()
reg1 = get_deferred_registry()
assert reg1 is not None
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
# Step 2: simulate tool_search promoting one of them
reg1.promote({"fake_mcp_search"})
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
# Step 3: second call — registry must NOT silently undo the promotion
get_available_tools()
reg2 = get_deferred_registry()
assert reg2 is not None
deferred_after = {e.name for e in reg2.entries}
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
# ---------------------------------------------------------------------------
# Section B — graph-execution reproduction
# ---------------------------------------------------------------------------
class _ToolSearchPromotingModel(FakeToolCallingModel):
"""Two-turn model that:
Turn 1 → emit a tool_call for ``tool_search`` (the real one)
Turn 2 → emit a tool_call for ``fake_mcp_search`` (the promoted tool)
Records the tools it received on each turn so the test can inspect what
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
"""
bound_tools_per_turn: list[list[str]] = []
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
# Record the tool names the model would see in this turn
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
self.bound_tools_per_turn.append(names)
return self
def _build_promoting_model() -> _ToolSearchPromotingModel:
return _ToolSearchPromotingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
}
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
"""End-to-end: drive a real create_agent graph through two turns.
Without the fix, the second-turn bind_tools call should NOT contain
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
registry and strips it). With the fix, the model sees the schema and can
invoke it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
tools = get_available_tools()
# Sanity: the assembled tool list includes the deferred tools (they're in
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
# they reach the model)
tool_names = {getattr(t, "name", "") for t in tools}
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
model = _build_promoting_model()
model.bound_tools_per_turn = [] # reset class-level recorder
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
turn1 = set(model.bound_tools_per_turn[0])
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
# This is the load-bearing assertion for issue #2884.
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
turn2 = set(model.bound_tools_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
# ---------------------------------------------------------------------------
# Section C — the actual issue #2884 trigger: a re-entrant
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
# wipe the parent's promotion.
# ---------------------------------------------------------------------------
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
(the same pattern that happens when ``task_tool`` builds a subagent's
toolset mid-run) must not wipe the parent agent's tool_search promotions.
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
``get_available_tools`` again — exactly what ``task_tool`` does when it
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
promoted tool. Without the fix, the re-entry wipes the registry and
the filter re-hides it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# The trigger tool simulates what task_tool does internally: rebuild the
# toolset by calling get_available_tools while the registry is live.
@as_tool
def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
get_available_tools(subagent_enabled=False)
return f"spawned subagent for: {prompt}"
tools = get_available_tools() + [fake_subagent_trigger]
bound_per_turn: list[list[str]] = []
class _Model(FakeToolCallingModel):
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
return self
model = _Model(
responses=[
# Turn 1: do both in one batch — promote AND trigger the
# subagent-style rebuild. LangGraph executes them in order in the
# same agent step.
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
},
{
"name": "fake_subagent_trigger",
"args": {"prompt": "go"},
"id": "call_trigger_1",
"type": "tool_call",
},
],
),
# Turn 2: try to invoke the promoted tool. The model gets this
# turn only if turn 1's bind_tools recorded what the filter sent.
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-subagent-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1 sanity: deferred tool not visible yet
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
# re-entrant get_available_tools call that happened in turn 1's tool batch.
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
turn2 = set(bound_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
+83 -1
View File
@@ -1,6 +1,6 @@
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
assert elapsed < 0.1
assert finished.is_set() is False
assert finished.wait(1.0) is True
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
assert queue.pending_count == 2
assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(
thread_id="thread-1",
messages=["first"],
agent_name="agent-a",
correction_detected=True,
)
queue.add(
thread_id="thread-1",
messages=["second"],
agent_name="agent-a",
correction_detected=False,
)
assert queue.pending_count == 1
assert queue._queue[0].agent_name == "agent-a"
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].correction_detected is True
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with (
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
patch("deerflow.agents.memory.queue.time.sleep"),
):
queue.flush()
assert mock_updater.update_memory.call_count == 2
mock_updater.update_memory.assert_has_calls(
[
call(
messages=["agent-a"],
thread_id="thread-1",
agent_name="agent-a",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
call(
messages=["agent-b"],
thread_id="thread-1",
agent_name="agent-b",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
]
)
@@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id():
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
assert q._queue[0].user_id == "alice"
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock()
@@ -37,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
mock_updater.update_memory.assert_called_once()
call_kwargs = mock_updater.update_memory.call_args.kwargs
assert call_kwargs["user_id"] == "alice"
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
assert q.pending_count == 1
assert q._queue[0].messages == ["second"]
assert q._queue[0].user_id == "alice"
assert q._queue[0].agent_name == "researcher"
def test_add_nowait_keeps_different_users_separate():
q = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
patch.object(q, "_schedule_timer"),
):
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
+33
View File
@@ -268,6 +268,39 @@ class TestEdgeCases:
class TestDbRunEventStore:
"""Tests for DbRunEventStore with temp SQLite."""
@pytest.mark.anyio
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
from sqlalchemy.dialects import postgresql
from deerflow.runtime.events.store.db import DbRunEventStore
class FakeSession:
def __init__(self):
self.dialect = postgresql.dialect()
self.execute_calls = []
self.scalar_stmt = None
def get_bind(self):
return self
async def execute(self, stmt, params=None):
self.execute_calls.append((stmt, params))
async def scalar(self, stmt):
self.scalar_stmt = stmt
return 41
session = FakeSession()
max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
assert max_seq == 41
assert session.execute_calls
assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
assert "FOR UPDATE" not in compiled
@pytest.mark.anyio
async def test_basic_crud(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+48
View File
@@ -3,7 +3,10 @@
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import re
import pytest
from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
@@ -278,3 +281,48 @@ class TestRunRepository:
assert row4["model_name"] is None
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
captured = []
class FakeResult:
def all(self):
return []
class FakeSession:
async def execute(self, stmt):
captured.append(stmt)
return FakeResult()
class FakeSessionContext:
async def __aenter__(self):
return FakeSession()
async def __aexit__(self, exc_type, exc, tb):
return None
repo = RunRepository(lambda: FakeSessionContext())
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg == {
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_runs": 0,
"by_model": {},
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
}
assert len(captured) == 1
stmt = captured[0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
assert select_match is not None
assert group_by_match is not None
assert select_match.group(1) == group_by_match.group(1)
@@ -0,0 +1,429 @@
"""End-to-end verification for issue #2862 (and the regression of #2782).
Goal: prove — without trusting any single layer's claim — that an authenticated
user creating a custom agent through the real ``setup_agent`` tool, driven by a
real LangGraph ``create_agent`` graph, ends up with files under
``users/<auth_uid>/agents/<name>`` and **not** under ``users/default/agents/...``.
We intentionally exercise the full pipeline:
HTTP body shape (mimics LangGraph SDK wire format)
-> app.gateway.services.start_run config-assembly chain
-> deerflow.runtime.runs.worker._build_runtime_context
-> langchain.agents.create_agent graph
-> ToolNode dispatch
-> setup_agent tool
The only thing we mock is the LLM (FakeMessagesListChatModel) — every layer
that handles ``user_id`` is the real production code path. If the
``user_id`` propagation is broken anywhere in this chain, these tests will
fail.
These tests intentionally ``no_auto_user`` so that the ``contextvar``
fallback would put files into ``default/`` if propagation breaks.
"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
from uuid import UUID
import pytest
from _agent_e2e_helpers import FakeToolCallingModel
from langchain_core.messages import AIMessage, HumanMessage
from app.gateway.services import (
build_run_config,
inject_authenticated_user_context,
merge_run_context_overrides,
)
from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context
# ---------------------------------------------------------------------------
# Helpers — real production code paths
# ---------------------------------------------------------------------------
def _make_request(user_id_str: str | None) -> SimpleNamespace:
"""Build a fake FastAPI Request that carries an authenticated user."""
if user_id_str is None:
user = None
else:
# User.id is UUID in production; honour that
user = SimpleNamespace(id=UUID(user_id_str), email="alice@local")
return SimpleNamespace(state=SimpleNamespace(user=user))
def _assemble_config(
*,
body_config: dict | None,
body_context: dict | None,
request_user_id: str | None,
thread_id: str = "thread-e2e",
assistant_id: str = "lead_agent",
) -> dict:
"""Replay the **exact** start_run config-assembly sequence."""
config = build_run_config(thread_id, body_config, None, assistant_id=assistant_id)
merge_run_context_overrides(config, body_context)
inject_authenticated_user_context(config, _make_request(request_user_id))
return config
def _make_paths_mock(tmp_path: Path):
"""Mirror the production paths.user_agent_dir signature."""
from unittest.mock import MagicMock
paths = MagicMock()
paths.base_dir = tmp_path
paths.agent_dir = lambda name: tmp_path / "agents" / name
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
return paths
# ---------------------------------------------------------------------------
# L1-L3: HTTP wire format → start_run → worker._build_runtime_context
# ---------------------------------------------------------------------------
class TestConfigAssembly:
"""Covers L1-L3: validate that user_id reaches runtime_ctx for every wire shape."""
def test_typical_wire_format_user_id_in_runtime_ctx(self):
"""Real frontend: body.config={recursion_limit}, body.context={agent_name,...}."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={"agent_name": "myagent", "is_bootstrap": True, "mode": "flash"},
request_user_id="11111111-2222-3333-4444-555555555555",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555"
assert runtime_ctx["agent_name"] == "myagent"
def test_body_context_none_still_injects_user_id(self):
"""If frontend omits body.context entirely, inject must still create it."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context=None,
request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
def test_body_context_empty_dict_still_injects_user_id(self):
"""body.context={} (falsy) path: inject must still produce user_id."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={},
request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
def test_body_config_already_contains_context_field(self):
"""body.config={'context': {...}} (LG 0.6 alt wire): inject still wins."""
config = _assemble_config(
body_config={"context": {"agent_name": "myagent"}, "recursion_limit": 1000},
body_context=None,
request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
def test_client_supplied_user_id_is_overridden(self):
"""Spoofed client user_id must be overwritten by inject (auth-trusted source)."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={"agent_name": "myagent", "user_id": "spoofed"},
request_user_id="11111111-2222-3333-4444-555555555555",
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555"
def test_unauthenticated_request_does_not_inject(self):
"""If request.state.user is missing (impossible under fail-closed auth, but
verify defensively), inject must not write user_id and runtime_ctx must
therefore lack it — forcing the tool fallback path to reveal itself."""
config = _assemble_config(
body_config={"recursion_limit": 1000},
body_context={"agent_name": "myagent"},
request_user_id=None,
)
runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None)
assert "user_id" not in runtime_ctx
# ---------------------------------------------------------------------------
# L4-L7: Real LangGraph create_agent driving the real setup_agent tool
# ---------------------------------------------------------------------------
def _build_real_bootstrap_graph(authenticated_user_id: str):
"""Construct a real LangGraph using create_agent + the real setup_agent tool.
The LLM is faked (FakeMessagesListChatModel) so we don't need an API key.
Everything else — ToolNode dispatch, runtime injection, middleware — is
the real production code path.
"""
from langchain.agents import create_agent
from deerflow.tools.builtins.setup_agent_tool import setup_agent
# First model turn: emit a tool_call for setup_agent
# Second model turn (after tool result): final answer (terminates the loop)
fake_model = FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "setup_agent",
"args": {
"soul": "# My E2E Agent\n\nA SOUL written by the model.",
"description": "End-to-end test agent",
},
"id": "call_setup_1",
"type": "tool_call",
}
],
),
AIMessage(content=f"Done. Agent created for user {authenticated_user_id}."),
]
)
graph = create_agent(
model=fake_model,
tools=[setup_agent],
system_prompt="You are a bootstrap agent. Call setup_agent immediately.",
)
return graph
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_real_graph_real_setup_agent_writes_to_authenticated_user_dir(tmp_path: Path):
"""The smoking-gun test for issue #2862.
Under no_auto_user (contextvar = empty), if user_id propagation through
runtime.context is broken, setup_agent will fall back to DEFAULT_USER_ID
and write to users/default/agents/... The assertion that this directory
DOES NOT exist is what makes this test load-bearing.
"""
from langgraph.runtime import Runtime
auth_uid = "abcdef01-2345-6789-abcd-ef0123456789"
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "e2e-agent", "is_bootstrap": True},
request_user_id=auth_uid,
thread_id="thread-e2e-1",
)
# Replay worker.run_agent's runtime construction. This is the key step:
# it is what makes ToolRuntime.context contain user_id when the tool
# actually fires.
runtime_ctx = _build_runtime_context("thread-e2e-1", "run-1", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_real_bootstrap_graph(auth_uid)
# Patch get_paths only (the file-system rooting); everything else is real
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
# Drive the real graph. This goes through real ToolNode + real Runtime merge.
final_state = await graph.ainvoke(
{"messages": [HumanMessage(content="Create an agent named e2e-agent")]},
config=config,
)
expected_dir = tmp_path / "users" / auth_uid / "agents" / "e2e-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "e2e-agent"
# Load-bearing assertions:
assert expected_dir.exists(), f"Agent directory not found at the authenticated user's path. Expected: {expected_dir}. tmp_path tree: {[str(p) for p in tmp_path.rglob('*')]}"
assert (expected_dir / "SOUL.md").read_text() == "# My E2E Agent\n\nA SOUL written by the model."
assert (expected_dir / "config.yaml").exists()
assert not default_dir.exists(), "REGRESSION: agent landed under users/default/. user_id propagation broke somewhere between HTTP layer and ToolRuntime.context."
# And final state should reflect tool success
last = final_state["messages"][-1]
assert "Done" in (last.content if isinstance(last.content, str) else str(last.content))
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_inject_failure_falls_back_to_default_proving_test_is_load_bearing(tmp_path: Path):
"""Negative control: if inject does NOT happen (no user in request), and
contextvar is empty (no_auto_user), setup_agent must land in default/.
This proves the positive test is actually load-bearing — i.e. it would
have failed before PR #2784, not passed accidentally.
"""
from langgraph.runtime import Runtime
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "fallback-agent", "is_bootstrap": True},
request_user_id=None, # no auth — inject is a no-op
thread_id="thread-e2e-2",
)
runtime_ctx = _build_runtime_context("thread-e2e-2", "run-2", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_real_bootstrap_graph("does-not-matter")
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
await graph.ainvoke(
{"messages": [HumanMessage(content="Create fallback-agent")]},
config=config,
)
default_dir = tmp_path / "users" / "default" / "agents" / "fallback-agent"
assert default_dir.exists(), "Negative control failed: even without inject + contextvar, agent did not land in default/. The test infrastructure may not be reproducing the bug condition."
# ---------------------------------------------------------------------------
# L5: Sub-graph runtime propagation (the task tool case)
# ---------------------------------------------------------------------------
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_subgraph_invocation_preserves_user_id_in_runtime(tmp_path: Path):
"""When a parent graph invokes a child graph (the pattern used by
subagents), parent_runtime.merge() must keep user_id intact.
We construct a child graph that contains setup_agent and call it from
a parent graph's tool. If LangGraph re-creates the Runtime and drops
user_id at the sub-graph boundary, this fails.
"""
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from deerflow.tools.builtins.setup_agent_tool import setup_agent
auth_uid = "deadbeef-0000-1111-2222-333344445555"
# Inner graph: same as the bootstrap flow
inner_model = FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "setup_agent",
"args": {"soul": "# Inner", "description": "subgraph"},
"id": "call_inner_1",
"type": "tool_call",
}
],
),
AIMessage(content="inner done"),
]
)
inner_graph = create_agent(
model=inner_model,
tools=[setup_agent],
system_prompt="inner",
)
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "subgraph-agent", "is_bootstrap": True},
request_user_id=auth_uid,
thread_id="thread-e2e-3",
)
runtime_ctx = _build_runtime_context("thread-e2e-3", "run-3", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
# Direct sub-graph invoke (mimics what a subagent invocation looks like
# — distinct ainvoke call, but parent config carries the same runtime).
await inner_graph.ainvoke(
{"messages": [HumanMessage(content="Create subgraph-agent")]},
config=config,
)
expected_dir = tmp_path / "users" / auth_uid / "agents" / "subgraph-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "subgraph-agent"
assert expected_dir.exists()
assert not default_dir.exists()
# ---------------------------------------------------------------------------
# L6: Sync tool path through ContextThreadPoolExecutor
# ---------------------------------------------------------------------------
def test_sync_tool_dispatch_through_thread_pool_uses_runtime_context(tmp_path: Path):
"""setup_agent is a sync function. When dispatched through ToolNode's
ContextThreadPoolExecutor, runtime.context must still carry user_id —
not via thread-local copy_context (which only carries contextvars), but
because it was passed in as the ToolRuntime constructor argument.
"""
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from deerflow.tools.builtins.setup_agent_tool import setup_agent
auth_uid = "11112222-3333-4444-5555-666677778888"
fake_model = FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "setup_agent",
"args": {"soul": "# Sync", "description": "sync path"},
"id": "call_sync_1",
"type": "tool_call",
}
],
),
AIMessage(content="sync done"),
]
)
graph = create_agent(model=fake_model, tools=[setup_agent], system_prompt="sync")
config = _assemble_config(
body_config={"recursion_limit": 50},
body_context={"agent_name": "sync-agent", "is_bootstrap": True},
request_user_id=auth_uid,
thread_id="thread-e2e-4",
)
runtime_ctx = _build_runtime_context("thread-e2e-4", "run-4", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
with patch(
"deerflow.tools.builtins.setup_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
):
# Use SYNC invoke to hit the ContextThreadPoolExecutor path
graph.invoke(
{"messages": [HumanMessage(content="Create sync-agent")]},
config=config,
)
expected_dir = tmp_path / "users" / auth_uid / "agents" / "sync-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "sync-agent"
assert expected_dir.exists()
assert not default_dir.exists()
@@ -0,0 +1,326 @@
"""Real HTTP end-to-end verification for issue #2862's setup_agent path.
This test drives the **entire** FastAPI gateway through ``starlette.testclient.TestClient``:
starlette.testclient.TestClient (real ASGI stack)
-> AuthMiddleware (real cookie parsing, real JWT decode)
-> /api/v1/auth/register endpoint (real password hash + sqlite write)
-> /api/threads/{id}/runs/stream endpoint (real start_run config-assembly)
-> background asyncio.create_task(run_agent) (real worker, real Runtime)
-> langchain.agents.create_agent graph (real, with fake LLM)
-> ToolNode dispatch (real)
-> setup_agent tool (real file I/O)
The only mock is the LLM (no API key needed). Every layer that participates
in ``user_id`` propagation — auth, ContextVar, ``inject_authenticated_user_context``,
``worker._build_runtime_context``, ``Runtime.merge`` — is the real production
code path. If the chain is broken at any layer, this test fails.
This is what "真实验证" looks like for a server that lives behind authentication:
register a user, log in (cookie), POST to /runs/stream, wait for the run to
finish, then read the filesystem.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model
def _build_fake_create_chat_model(agent_name: str):
"""Return a callable matching the real ``create_chat_model`` signature.
Whenever the lead agent constructs a chat model during the bootstrap flow,
we hand it a fake that emits a single setup_agent tool_call on its first
turn, then a benign final answer on its second turn.
"""
def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel:
return build_single_tool_call_model(
tool_name="setup_agent",
tool_args={
"soul": f"# Real HTTP E2E SOUL for {agent_name}",
"description": "real-http-e2e agent",
},
tool_call_id="call_real_http_1",
final_text=f"Agent {agent_name} created via real HTTP e2e.",
)
return fake_create_chat_model
@pytest.fixture
def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Stand up an isolated DeerFlow data root + config under tmp_path.
- Sets ``DEER_FLOW_HOME`` so paths land under tmp_path, not the real
``.deer-flow`` directory.
- Stages a copy of the project's ``config.yaml`` (or ``config.example.yaml``
on a fresh CI checkout where ``config.yaml`` is gitignored) and pins
``DEER_FLOW_CONFIG_PATH`` to it, so lifespan boot doesn't depend on the
developer's local config layout.
- Sets a placeholder OPENAI_API_KEY because the config has
``$OPENAI_API_KEY`` that gets resolved at parse time; the LLM itself is
mocked, so any non-empty value works.
"""
home = tmp_path / "deer-flow-home"
home.mkdir()
monkeypatch.setenv("DEER_FLOW_HOME", str(home))
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used-because-llm-is-mocked")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
# Hermetic config: do not depend on whether the dev machine has a real
# ``config.yaml`` at the repo root. CI's ``actions/checkout`` only ships
# ``config.example.yaml`` (and its ``models:`` list is commented out, so
# AppConfig validation would reject it). Write a minimal, self-sufficient
# config to tmp_path and pin ``DEER_FLOW_CONFIG_PATH`` to it.
staged_config = tmp_path / "config.yaml"
staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config))
return home
# Minimal config that satisfies AppConfig + LeadAgent's _resolve_model_name.
# The model `use` path must resolve to a real class for config parsing to
# succeed; the test patches ``create_chat_model`` on the lead agent module,
# so the model is never actually instantiated. SandboxConfig.use is required
# at schema level; LocalSandboxProvider is the only sandbox that runs without
# Docker.
_MINIMAL_CONFIG_YAML = """\
log_level: info
models:
- name: fake-test-model
display_name: Fake Test Model
use: langchain_openai:ChatOpenAI
model: gpt-4o-mini
api_key: $OPENAI_API_KEY
base_url: $OPENAI_API_BASE
sandbox:
use: deerflow.sandbox.local:LocalSandboxProvider
agents_api:
enabled: true
database:
backend: sqlite
"""
def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None:
"""Reset every process-wide cache that would survive across tests.
This fixture stands up a full FastAPI app + sqlite DB + LangGraph runtime
inside ``tmp_path``. To get true per-test isolation we have to invalidate
a handful of module-level caches that production normally never resets,
so they pick up our test-only ``DEER_FLOW_HOME`` and sqlite path:
- ``deerflow.config.app_config`` caches the parsed ``config.yaml``.
- ``deerflow.config.paths`` caches the ``Paths`` singleton derived from
``DEER_FLOW_HOME`` at first access.
- ``deerflow.persistence.engine`` caches the SQLAlchemy engine and
session factory after the first call to ``init_engine_from_config``.
``raising=False`` keeps the fixture resilient if upstream renames or
drops one of these attributes — the test will simply skip that reset
instead of failing with a confusing AttributeError, and the next test
to call ``get_app_config()``/``get_paths()`` will surface the real
incompatibility loudly.
"""
from deerflow.config import app_config as app_config_module
from deerflow.config import paths as paths_module
from deerflow.persistence import engine as engine_module
for module, attr in (
(app_config_module, "_app_config"),
(app_config_module, "_app_config_path"),
(app_config_module, "_app_config_mtime"),
(paths_module, "_paths_singleton"),
(engine_module, "_engine"),
(engine_module, "_session_factory"),
):
monkeypatch.setattr(module, attr, None, raising=False)
@pytest.fixture
def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch):
"""Build a fresh FastAPI app inside a clean DEER_FLOW_HOME.
Each test gets its own sqlite DB and checkpoint store under ``tmp_path``,
with no cross-test contamination.
"""
_reset_process_singletons(monkeypatch)
# Re-resolve the config from the test-only DEER_FLOW_HOME and pin its
# sqlite path into tmp_path so the lifespan-time engine init lands there.
from deerflow.config import app_config as app_config_module
cfg = app_config_module.get_app_config()
cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db")
from app.gateway.app import create_app
return create_app()
def _drain_stream(response, *, timeout: float = 30.0, max_bytes: int = 4 * 1024 * 1024) -> str:
"""Consume an SSE response body until the run terminates and return the text.
Bounded to keep the test fail-fast:
- Stops as soon as an ``event: end`` SSE frame is observed (the gateway
sends this when the background run finishes — see ``services.format_sse``
and ``StreamBridge.publish_end``).
- Stops at ``timeout`` seconds wall-clock so a stuck run / runaway heartbeat
loop surfaces a real failure instead of hanging pytest.
- Stops at ``max_bytes`` so a runaway producer can't OOM the test process.
"""
import time as _time
deadline = _time.monotonic() + timeout
body = b""
for chunk in response.iter_bytes():
body += chunk
if b"event: end" in body:
break
if len(body) >= max_bytes:
break
if _time.monotonic() >= deadline:
break
return body.decode("utf-8", errors="replace")
def _wait_for_file(path: Path, *, timeout: float = 10.0) -> bool:
"""Block until *path* exists or *timeout* elapses.
The run completes inside ``asyncio.create_task`` after start_run returns,
so the test must wait for the background task to flush its writes.
"""
import time as _time
deadline = _time.monotonic() + timeout
while _time.monotonic() < deadline:
if path.exists():
return True
_time.sleep(0.05)
return False
@pytest.mark.no_auto_user
def test_real_http_create_agent_lands_in_authenticated_user_dir(
isolated_app: Any,
isolated_deer_flow_home: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""The full real-server contract test.
1. Register a real user via POST /api/v1/auth/register (also auto-logs in)
2. POST to /api/threads/{tid}/runs/stream with the **exact** body shape the
frontend (LangGraph SDK) sends during the bootstrap flow.
3. Wait for the background run to finish.
4. Assert SOUL.md exists under users/<authenticated_uid>/agents/<name>/.
5. Assert NOTHING exists under users/default/agents/<name>/.
"""
# ``deerflow.agents.lead_agent.agent`` imports ``create_chat_model`` with
# ``from deerflow.models import create_chat_model`` at module load time,
# rebinding the symbol into its own namespace. So the only patch that
# intercepts the call is the bound name on ``lead_agent.agent`` — patching
# ``deerflow.models.create_chat_model`` would be too late.
agent_name = "real-http-agent"
from starlette.testclient import TestClient
with (
patch(
"deerflow.agents.lead_agent.agent.create_chat_model",
new=_build_fake_create_chat_model(agent_name),
),
TestClient(isolated_app) as client,
):
# --- 1. Register & auto-login ---
register = client.post(
"/api/v1/auth/register",
json={"email": "e2e-user@example.com", "password": "very-strong-password-123"},
)
assert register.status_code == 201, register.text
registered = register.json()
auth_uid = registered["id"]
# The endpoint sets both access_token (auth) and csrf_token (CSRF Double
# Submit Cookie) cookies; the TestClient cookie jar propagates them.
assert client.cookies.get("access_token"), "register endpoint must set session cookie"
csrf_token = client.cookies.get("csrf_token")
assert csrf_token, "register endpoint must set csrf_token cookie"
# --- 2. Create a thread (require_existing=True on /runs/stream means
# we must call POST /api/threads first; the React frontend does the
# same via the LangGraph SDK's threads.create) ---
import uuid as _uuid
thread_id = str(_uuid.uuid4())
created = client.post(
"/api/threads",
json={"thread_id": thread_id, "metadata": {}},
headers={"X-CSRF-Token": csrf_token},
)
assert created.status_code == 200, created.text
# --- 3. POST /runs/stream with the bootstrap wire format ---
# This is the EXACT shape the React frontend sends after PR #2784:
# thread.submit(input, {config, context}) ->
# POST /api/threads/{id}/runs/stream body =
# {assistant_id, input, config, context}
body = {
"assistant_id": "lead_agent",
"input": {
"messages": [
{
"role": "user",
"content": (f"The new custom agent name is {agent_name}. Help me design its SOUL.md before saving it."),
}
]
},
"config": {"recursion_limit": 50},
"context": {
"agent_name": agent_name,
"is_bootstrap": True,
"mode": "flash",
"thinking_enabled": False,
"is_plan_mode": False,
"subagent_enabled": False,
},
"stream_mode": ["values"],
}
# The /stream endpoint returns SSE; we drain it so the server-side
# background task (run_agent) gets to completion before we look at disk.
with client.stream(
"POST",
f"/api/threads/{thread_id}/runs/stream",
json=body,
headers={"X-CSRF-Token": csrf_token},
) as resp:
assert resp.status_code == 200, resp.read().decode()
transcript = _drain_stream(resp)
# Sanity: the stream should have produced at least one event
assert "event:" in transcript, f"no SSE events in response: {transcript[:500]!r}"
# --- 4. Verify filesystem outcome ---
expected_dir = isolated_deer_flow_home / "users" / auth_uid / "agents" / agent_name
default_dir = isolated_deer_flow_home / "users" / "default" / "agents" / agent_name
# The setup_agent tool runs inside the background asyncio task spawned
# by start_run; SSE-drain typically waits for it, but we add a bounded
# poll to be robust against scheduler jitter.
assert _wait_for_file(expected_dir / "SOUL.md", timeout=15.0), (
"SOUL.md did not appear under users/<auth_uid>/agents/. "
f"Expected: {expected_dir / 'SOUL.md'}. "
f"tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}. "
f"SSE transcript tail: {transcript[-1000:]!r}"
)
soul_text = (expected_dir / "SOUL.md").read_text()
assert agent_name in soul_text, f"unexpected SOUL content: {soul_text!r}"
# The smoking-gun assertion: the agent must NOT have landed in default/
assert not default_dir.exists(), f"REGRESSION: agent landed under users/default/{agent_name} instead of the authenticated user. Default-dir contents: {list(default_dir.rglob('*')) if default_dir.exists() else 'n/a'}"
+26 -1
View File
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
)
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
def _runtime(
thread_id: str | None = "thread-1",
agent_name: str | None = None,
user_id: str | None = None,
) -> SimpleNamespace:
context = {}
if thread_id is not None:
context["thread_id"] = thread_id
if agent_name is not None:
context["agent_name"] = agent_name
if user_id is not None:
context["user_id"] = user_id
return SimpleNamespace(context=context)
@@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="main",
agent_name="researcher",
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
+153
View File
@@ -59,12 +59,15 @@ def _make_result(
ai_messages: list[dict] | None = None,
result: str | None = None,
error: str | None = None,
token_usage_records: list[dict] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
status=status,
ai_messages=ai_messages or [],
result=result,
error=error,
token_usage_records=token_usage_records or [],
usage_reported=False,
)
@@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]
@pytest.mark.parametrize(
"status, expected_type",
[
(FakeSubagentStatus.COMPLETED, "task_completed"),
(FakeSubagentStatus.FAILED, "task_failed"),
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
"""Terminal task events include a usage summary from token_usage_records."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
records = [
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
]
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-usage",
)
terminal_events = [e for e in events if e["type"] == expected_type]
assert len(terminal_events) == 1
assert terminal_events[0]["usage"] == {
"input_tokens": 300,
"output_tokens": 130,
"total_tokens": 430,
}
def test_terminal_event_usage_none_when_no_records(monkeypatch):
"""Terminal event has usage=None when token_usage_records is empty."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-no-records",
)
completed = [e for e in events if e["type"] == "task_completed"]
assert len(completed) == 1
assert completed[0]["usage"] is None
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=FileNotFoundError("missing config")),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
runtime = _make_runtime(app_config=app_config)
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
task_tool_module._subagent_usage_cache.clear()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-disabled-cache",
)
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
runtime = _make_runtime(app_config=app_config)
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
with pytest.raises(RuntimeError, match="poll failed"):
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-error",
)
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
+438 -66
View File
@@ -1,28 +1,25 @@
"""Tests for ThreadMetaRepository (SQLAlchemy-backed)."""
import logging
import pytest
from deerflow.persistence.thread_meta import ThreadMetaRepository
from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository
async def _make_repo(tmp_path):
from deerflow.persistence.engine import get_session_factory, init_engine
@pytest.fixture
async def repo(tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
return ThreadMetaRepository(get_session_factory())
async def _cleanup():
from deerflow.persistence.engine import close_engine
yield ThreadMetaRepository(get_session_factory())
await close_engine()
class TestThreadMetaRepository:
@pytest.mark.anyio
async def test_create_and_get(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_and_get(self, repo):
record = await repo.create("t1")
assert record["thread_id"] == "t1"
assert record["status"] == "idle"
@@ -31,148 +28,523 @@ class TestThreadMetaRepository:
fetched = await repo.get("t1")
assert fetched is not None
assert fetched["thread_id"] == "t1"
await _cleanup()
@pytest.mark.anyio
async def test_create_with_assistant_id(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_with_assistant_id(self, repo):
record = await repo.create("t1", assistant_id="agent1")
assert record["assistant_id"] == "agent1"
await _cleanup()
@pytest.mark.anyio
async def test_create_with_owner_and_display_name(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_with_owner_and_display_name(self, repo):
record = await repo.create("t1", user_id="user1", display_name="My Thread")
assert record["user_id"] == "user1"
assert record["display_name"] == "My Thread"
await _cleanup()
@pytest.mark.anyio
async def test_create_with_metadata(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_create_with_metadata(self, repo):
record = await repo.create("t1", metadata={"key": "value"})
assert record["metadata"] == {"key": "value"}
await _cleanup()
@pytest.mark.anyio
async def test_get_nonexistent(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_get_nonexistent(self, repo):
assert await repo.get("nonexistent") is None
await _cleanup()
@pytest.mark.anyio
async def test_check_access_no_record_allows(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_no_record_allows(self, repo):
assert await repo.check_access("unknown", "user1") is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_owner_matches(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_owner_matches(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user1") is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_owner_mismatch(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_owner_mismatch(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user2") is False
await _cleanup()
@pytest.mark.anyio
async def test_check_access_no_owner_allows_all(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_no_owner_allows_all(self, repo):
# Explicit user_id=None to bypass the new AUTO default that
# would otherwise pick up the test user from the autouse fixture.
await repo.create("t1", user_id=None)
assert await repo.check_access("t1", "anyone") is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_missing_row_denied(self, tmp_path):
async def test_check_access_strict_missing_row_denied(self, repo):
"""require_existing=True flips the missing-row case to *denied*.
Closes the delete-idempotence cross-user gap: after a thread is
deleted, the row is gone, and the permissive default would let any
caller "claim" it as untracked. The strict mode demands a row.
"""
repo = await _make_repo(tmp_path)
assert await repo.check_access("never-existed", "user1", require_existing=True) is False
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_strict_owner_match_allowed(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user1", require_existing=True) is True
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_check_access_strict_owner_mismatch_denied(self, repo):
await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user2", require_existing=True) is False
await _cleanup()
@pytest.mark.anyio
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
async def test_check_access_strict_null_owner_still_allowed(self, repo):
"""Even in strict mode, a row with NULL user_id stays shared.
The strict flag tightens the *missing row* case, not the *shared
row* case — legacy pre-auth rows that survived a clean migration
without an owner are still everyone's.
"""
repo = await _make_repo(tmp_path)
await repo.create("t1", user_id=None)
assert await repo.check_access("t1", "anyone", require_existing=True) is True
await _cleanup()
@pytest.mark.anyio
async def test_update_status(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_status(self, repo):
await repo.create("t1")
await repo.update_status("t1", "busy")
record = await repo.get("t1")
assert record["status"] == "busy"
await _cleanup()
@pytest.mark.anyio
async def test_delete(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_delete(self, repo):
await repo.create("t1")
await repo.delete("t1")
assert await repo.get("t1") is None
await _cleanup()
@pytest.mark.anyio
async def test_delete_nonexistent_is_noop(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_delete_nonexistent_is_noop(self, repo):
await repo.delete("nonexistent") # should not raise
await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_merges(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_metadata_merges(self, repo):
await repo.create("t1", metadata={"a": 1, "b": 2})
await repo.update_metadata("t1", {"b": 99, "c": 3})
record = await repo.get("t1")
# Existing key preserved, overlapping key overwritten, new key added
assert record["metadata"] == {"a": 1, "b": 99, "c": 3}
await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_on_empty(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_metadata_on_empty(self, repo):
await repo.create("t1")
await repo.update_metadata("t1", {"k": "v"})
record = await repo.get("t1")
assert record["metadata"] == {"k": "v"}
await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
repo = await _make_repo(tmp_path)
async def test_update_metadata_nonexistent_is_noop(self, repo):
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
await _cleanup()
# --- search with metadata filter (SQL push-down) ---
@pytest.mark.anyio
async def test_search_metadata_filter_string(self, repo):
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
await repo.create("t3", metadata={"env": "prod", "region": "us"})
results = await repo.search(metadata={"env": "prod"})
ids = {r["thread_id"] for r in results}
assert ids == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_filter_numeric(self, repo):
await repo.create("t1", metadata={"priority": 1})
await repo.create("t2", metadata={"priority": 2})
await repo.create("t3", metadata={"priority": 1, "extra": "x"})
results = await repo.search(metadata={"priority": 1})
ids = {r["thread_id"] for r in results}
assert ids == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_filter_multiple_keys(self, repo):
await repo.create("t1", metadata={"env": "prod", "region": "us"})
await repo.create("t2", metadata={"env": "prod", "region": "eu"})
await repo.create("t3", metadata={"env": "staging", "region": "us"})
results = await repo.search(metadata={"env": "prod", "region": "us"})
assert len(results) == 1
assert results[0]["thread_id"] == "t1"
@pytest.mark.anyio
async def test_search_metadata_no_match(self, repo):
await repo.create("t1", metadata={"env": "prod"})
results = await repo.search(metadata={"env": "dev"})
assert results == []
@pytest.mark.anyio
async def test_search_metadata_pagination_correct(self, repo):
"""Regression: SQL push-down makes limit/offset exact even when most rows don't match."""
for i in range(30):
meta = {"target": "yes"} if i % 3 == 0 else {"target": "no"}
await repo.create(f"t{i:03d}", metadata=meta)
# Total matching rows: i in {0,3,6,9,12,15,18,21,24,27} = 10 rows
all_matches = await repo.search(metadata={"target": "yes"}, limit=100)
assert len(all_matches) == 10
# Paginate: first page
page1 = await repo.search(metadata={"target": "yes"}, limit=3, offset=0)
assert len(page1) == 3
# Paginate: second page
page2 = await repo.search(metadata={"target": "yes"}, limit=3, offset=3)
assert len(page2) == 3
# No overlap between pages
page1_ids = {r["thread_id"] for r in page1}
page2_ids = {r["thread_id"] for r in page2}
assert page1_ids.isdisjoint(page2_ids)
# Last page
page_last = await repo.search(metadata={"target": "yes"}, limit=3, offset=9)
assert len(page_last) == 1
@pytest.mark.anyio
async def test_search_metadata_with_status_filter(self, repo):
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "prod"})
await repo.update_status("t1", "busy")
results = await repo.search(metadata={"env": "prod"}, status="busy")
assert len(results) == 1
assert results[0]["thread_id"] == "t1"
@pytest.mark.anyio
async def test_search_without_metadata_still_works(self, repo):
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2")
results = await repo.search(limit=10)
assert len(results) == 2
@pytest.mark.anyio
async def test_search_metadata_missing_key_no_match(self, repo):
"""Rows without the requested metadata key should not match."""
await repo.create("t1", metadata={"other": "val"})
await repo.create("t2", metadata={"env": "prod"})
results = await repo.search(metadata={"env": "prod"})
assert len(results) == 1
assert results[0]["thread_id"] == "t2"
@pytest.mark.anyio
async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog):
"""When ALL metadata keys are unsafe, raises InvalidMetadataFilterError."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info:
await repo.search(metadata={"bad;key": "x"})
assert any("bad;key" in r.message for r in caplog.records)
# Subclass of ValueError for backward compatibility
assert isinstance(exc_info.value, ValueError)
@pytest.mark.anyio
async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog):
"""Valid keys filter rows; only the invalid key is warned and skipped."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
results = await repo.search(metadata={"env": "prod", "bad;key": "x"})
ids = {r["thread_id"] for r in results}
assert ids == {"t1"}
assert any("bad;key" in r.message for r in caplog.records)
@pytest.mark.anyio
async def test_search_metadata_filter_boolean(self, repo):
"""True matches only boolean true, not integer 1."""
await repo.create("t1", metadata={"active": True})
await repo.create("t2", metadata={"active": False})
await repo.create("t3", metadata={"active": True, "extra": "x"})
await repo.create("t4", metadata={"active": 1})
results = await repo.search(metadata={"active": True})
ids = {r["thread_id"] for r in results}
assert ids == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_filter_none(self, repo):
"""Only rows with explicit JSON null match; missing key does not."""
await repo.create("t1", metadata={"tag": None})
await repo.create("t2", metadata={"tag": "present"})
await repo.create("t3", metadata={"other": "val"})
results = await repo.search(metadata={"tag": None})
ids = {r["thread_id"] for r in results}
assert ids == {"t1"}
@pytest.mark.anyio
async def test_search_metadata_non_string_key_skipped(self, repo, caplog):
"""Non-string keys raise ValueError from isinstance check; should be warned and skipped."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search(metadata={1: "x"})
assert any("1" in r.message for r in caplog.records)
@pytest.mark.anyio
async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog):
"""Unsupported value types (list, dict) raise TypeError; should be warned and skipped."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search(metadata={"env": ["prod", "staging"]})
@pytest.mark.anyio
async def test_search_metadata_dotted_key_raises(self, repo, caplog):
"""Dotted keys are rejected; when ALL keys are dotted, raises ValueError."""
await repo.create("t1", metadata={"env": "prod"})
await repo.create("t2", metadata={"env": "staging"})
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search(metadata={"a.b": "anything"})
assert any("a.b" in r.message for r in caplog.records)
# --- dialect-aware type-safe filtering edge cases ---
@pytest.mark.anyio
async def test_search_metadata_bool_vs_int_distinction(self, repo):
"""True must not match 1; False must not match 0."""
await repo.create("bool_true", metadata={"flag": True})
await repo.create("bool_false", metadata={"flag": False})
await repo.create("int_one", metadata={"flag": 1})
await repo.create("int_zero", metadata={"flag": 0})
true_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": True})}
assert true_hits == {"bool_true"}
false_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": False})}
assert false_hits == {"bool_false"}
@pytest.mark.anyio
async def test_search_metadata_int_does_not_match_bool(self, repo):
"""Integer 1 must not match boolean True."""
await repo.create("bool_true", metadata={"val": True})
await repo.create("int_one", metadata={"val": 1})
hits = {r["thread_id"] for r in await repo.search(metadata={"val": 1})}
assert hits == {"int_one"}
@pytest.mark.anyio
async def test_search_metadata_none_excludes_missing_key(self, repo):
"""Filtering by None matches explicit JSON null only, not missing key or empty {}."""
await repo.create("explicit_null", metadata={"k": None})
await repo.create("missing_key", metadata={"other": "x"})
await repo.create("empty_obj", metadata={})
hits = {r["thread_id"] for r in await repo.search(metadata={"k": None})}
assert hits == {"explicit_null"}
@pytest.mark.anyio
async def test_search_metadata_float_value(self, repo):
await repo.create("t1", metadata={"score": 3.14})
await repo.create("t2", metadata={"score": 2.71})
await repo.create("t3", metadata={"score": 3.14})
hits = {r["thread_id"] for r in await repo.search(metadata={"score": 3.14})}
assert hits == {"t1", "t3"}
@pytest.mark.anyio
async def test_search_metadata_mixed_types_same_key(self, repo):
"""Each type query only matches its own type, even when the key is shared."""
await repo.create("str_row", metadata={"x": "hello"})
await repo.create("int_row", metadata={"x": 42})
await repo.create("bool_row", metadata={"x": True})
await repo.create("null_row", metadata={"x": None})
assert {r["thread_id"] for r in await repo.search(metadata={"x": "hello"})} == {"str_row"}
assert {r["thread_id"] for r in await repo.search(metadata={"x": 42})} == {"int_row"}
assert {r["thread_id"] for r in await repo.search(metadata={"x": True})} == {"bool_row"}
assert {r["thread_id"] for r in await repo.search(metadata={"x": None})} == {"null_row"}
@pytest.mark.anyio
async def test_search_metadata_large_int_precision(self, repo):
"""Integers beyond float precision (> 2**53) must match exactly."""
large = 2**53 + 1
await repo.create("t1", metadata={"id": large})
await repo.create("t2", metadata={"id": large - 1})
hits = {r["thread_id"] for r in await repo.search(metadata={"id": large})}
assert hits == {"t1"}
class TestJsonMatchCompilation:
"""Verify compiled SQL for both SQLite and PostgreSQL dialects."""
def test_json_match_compiles_sqlite(self):
from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
engine = create_engine("sqlite://")
cases = [
(None, "json_type(t.data, '$.\"k\"') = 'null'"),
(True, "json_type(t.data, '$.\"k\"') = 'true'"),
(False, "json_type(t.data, '$.\"k\"') = 'false'"),
]
for value, expected_fragment in cases:
expr = json_match(t.c.data, "k", value)
sql = expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})
assert str(sql) == expected_fragment, f"value={value!r}: {sql}"
# int: uses INTEGER cast for precision, type-check narrows to 'integer' only
int_expr = json_match(t.c.data, "k", 42)
sql = str(int_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
assert "json_type" in sql
assert "= 'integer'" in sql
assert "INTEGER" in sql
assert "CAST" in sql
# float: uses REAL cast, type-check spans 'integer' and 'real'
float_expr = json_match(t.c.data, "k", 3.14)
sql = str(float_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
assert "json_type" in sql
assert "IN ('integer', 'real')" in sql
assert "REAL" in sql
str_expr = json_match(t.c.data, "k", "hello")
sql = str(str_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
assert "json_type" in sql
assert "'text'" in sql
def test_json_match_compiles_pg(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
dialect = postgresql.dialect()
cases = [
(None, "json_typeof(t.data -> 'k') = 'null'"),
(True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"),
(False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"),
]
for value, expected_fragment in cases:
expr = json_match(t.c.data, "k", value)
sql = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
assert str(sql) == expected_fragment, f"value={value!r}: {sql}"
# int: CASE guard prevents CAST error when 'number' also matches floats
int_expr = json_match(t.c.data, "k", 42)
sql = str(int_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "json_typeof" in sql
assert "'number'" in sql
assert "BIGINT" in sql
assert "CASE WHEN" in sql
assert "'^-?[0-9]+$'" in sql
# float: uses DOUBLE PRECISION cast
float_expr = json_match(t.c.data, "k", 3.14)
sql = str(float_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "json_typeof" in sql
assert "'number'" in sql
assert "DOUBLE PRECISION" in sql
str_expr = json_match(t.c.data, "k", "hello")
sql = str(str_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "json_typeof" in sql
assert "'string'" in sql
def test_json_match_rejects_unsafe_key(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
for bad_key in ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]:
with pytest.raises(ValueError, match="JsonMatch key must match"):
json_match(t.c.data, bad_key, "x")
# Non-string keys must also raise ValueError (not TypeError from re.match)
for non_str_key in [42, None, ("k",)]:
with pytest.raises(ValueError, match="JsonMatch key must match"):
json_match(t.c.data, non_str_key, "x")
def test_json_match_rejects_unsupported_value_type(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
for bad_value in [[], {}, object()]:
with pytest.raises(TypeError, match="JsonMatch value must be"):
json_match(t.c.data, "k", bad_value)
def test_json_match_unsupported_dialect_raises(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
expr = json_match(t.c.data, "k", "v")
with pytest.raises(NotImplementedError, match="mysql"):
str(expr.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True}))
def test_json_match_rejects_out_of_range_int(self):
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
# boundary values must be accepted
json_match(t.c.data, "k", 2**63 - 1)
json_match(t.c.data, "k", -(2**63))
# one beyond each boundary must be rejected
for out_of_range in [2**63, -(2**63) - 1, 10**30]:
with pytest.raises(TypeError, match="out of signed 64-bit range"):
json_match(t.c.data, "k", out_of_range)
def test_compiler_raises_on_escaped_key(self):
"""Compiler raises ValueError even when __init__ validation is bypassed."""
from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import JSON
from deerflow.persistence.json_compat import json_match
metadata = MetaData()
t = Table("t", metadata, Column("data", JSON), Column("id", String))
engine = create_engine("sqlite://")
elem = json_match(t.c.data, "k", "v")
elem.key = "bad.key" # bypass __init__ to simulate -O stripping assert
with pytest.raises(ValueError, match="Key escaped validation"):
str(elem.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
with pytest.raises(ValueError, match="Key escaped validation"):
str(elem.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
+54
View File
@@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore
from app.gateway.routers import threads
from deerflow.config.paths import Paths
from deerflow.persistence.thread_meta import InvalidMetadataFilterError
from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore
_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None
assert entries, "expected at least one history entry"
for entry in entries:
assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
# ── Metadata filter validation at API boundary ────────────────────────────────
def test_search_threads_rejects_invalid_key_at_api_boundary() -> None:
"""Keys that don't match [A-Za-z0-9_-]+ are rejected by the Pydantic
validator on ThreadSearchRequest.metadata — 422 from both backends.
"""
app, _store, _checkpointer = _build_thread_app()
with TestClient(app) as client:
response = client.post("/api/threads/search", json={"metadata": {"bad;key": "x"}})
assert response.status_code == 422
def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None:
"""Value types outside (None, bool, int, float, str) are rejected."""
app, _store, _checkpointer = _build_thread_app()
with TestClient(app) as client:
response = client.post("/api/threads/search", json={"metadata": {"env": ["a", "b"]}})
assert response.status_code == 422
def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None:
"""If the backend still raises InvalidMetadataFilterError (defense in
depth), the handler surfaces it as HTTP 400.
"""
app, _store, _checkpointer = _build_thread_app()
thread_store = app.state.thread_store
async def _raise(**kwargs):
raise InvalidMetadataFilterError("rejected")
with TestClient(app) as client:
with patch.object(thread_store, "search", side_effect=_raise):
response = client.post("/api/threads/search", json={"metadata": {"valid_key": "x"}})
assert response.status_code == 400
assert "rejected" in response.json()["detail"]
def test_search_threads_succeeds_with_valid_metadata() -> None:
"""Sanity check: valid metadata passes through without error."""
app, _store, _checkpointer = _build_thread_app()
with TestClient(app) as client:
response = client.post("/api/threads/search", json={"metadata": {"env": "prod"}})
assert response.status_code == 200
+48 -1
View File
@@ -1,9 +1,10 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove",
}
]
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
middleware = TokenUsageMiddleware()
first_dispatch = AIMessage(
content="",
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
)
second_dispatch = AIMessage(
content="",
tool_calls=[
{"id": "task:second-a", "name": "task", "args": {}},
{"id": "task:second-b", "name": "task", "args": {}},
],
)
messages = [
first_dispatch,
ToolMessage(content="first", tool_call_id="task:first"),
second_dispatch,
ToolMessage(content="second-a", tool_call_id="task:second-a"),
ToolMessage(content="second-b", tool_call_id="task:second-b"),
AIMessage(content="done"),
]
cached_usage = {
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
}
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
monkeypatch.setattr(
task_tool_module,
"pop_cached_subagent_usage",
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
)
result = middleware.after_model({"messages": messages}, _make_runtime())
assert result is not None
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
assert len(usage_updates) == 1
updated = usage_updates[0]
assert updated.tool_calls == second_dispatch.tool_calls
assert updated.usage_metadata == {
"input_tokens": 30,
"output_tokens": 12,
"total_tokens": 42,
}
+4 -8
View File
@@ -65,8 +65,7 @@ def _make_minimal_config(tools):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Config-loaded async-only tools can still be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
@@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash,
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
def test_no_duplicates_returned(mock_bash, mock_cfg):
"""get_available_tools() never returns two tools with the same name."""
mock_cfg.return_value = _make_minimal_config([])
@@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
def test_first_occurrence_wins(mock_bash, mock_cfg):
"""When duplicates exist, the first occurrence is kept."""
mock_cfg.return_value = _make_minimal_config([])
@@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
"""A warning is logged for every skipped duplicate."""
import logging
@@ -0,0 +1,253 @@
"""End-to-end verification for update_agent's user_id resolution.
PR #2784 hardened setup_agent to prefer runtime.context["user_id"] over the
contextvar. update_agent had the same latent gap: it unconditionally called
get_effective_user_id() at module level, so any scenario where the contextvar
was unavailable while runtime.context carried user_id (a background task
scheduled outside the request task, a worker pool that doesn't copy_context,
checkpoint resume on a different task) would silently route writes to
users/default/agents/...
These tests are load-bearing under @no_auto_user (contextvar empty):
- The negative-control test confirms the fixture actually puts the tool in
the regime where the contextvar fallback would land in users/default/.
Without that, the positive test would be vacuously satisfied.
- The positive test verifies update_agent honours runtime.context["user_id"]
injected by inject_authenticated_user_context in the gateway. Before the
fix in this PR, this test failed; now it passes.
"""
from __future__ import annotations
from contextlib import ExitStack
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
import yaml
from _agent_e2e_helpers import build_single_tool_call_model
from langchain_core.messages import HumanMessage
from app.gateway.services import (
build_run_config,
inject_authenticated_user_context,
merge_run_context_overrides,
)
from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context
def _make_request(user_id_str: str | None) -> SimpleNamespace:
user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") if user_id_str else None
return SimpleNamespace(state=SimpleNamespace(user=user))
def _assemble_config(*, body_context: dict | None, request_user_id: str | None, thread_id: str) -> dict:
config = build_run_config(thread_id, {"recursion_limit": 50}, None, assistant_id="lead_agent")
merge_run_context_overrides(config, body_context)
inject_authenticated_user_context(config, _make_request(request_user_id))
return config
def _seed_existing_agent(tmp_path: Path, user_id: str, agent_name: str, soul: str = "# Original"):
"""Pre-create an agent on disk for update_agent to overwrite."""
agent_dir = tmp_path / "users" / user_id / "agents" / agent_name
agent_dir.mkdir(parents=True, exist_ok=True)
(agent_dir / "config.yaml").write_text(
yaml.dump({"name": agent_name, "description": "old"}, allow_unicode=True),
encoding="utf-8",
)
(agent_dir / "SOUL.md").write_text(soul, encoding="utf-8")
return agent_dir
def _make_paths_mock(tmp_path: Path):
paths = MagicMock()
paths.base_dir = tmp_path
paths.agent_dir = lambda name: tmp_path / "agents" / name
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
return paths
def _patch_update_agent_dependencies(tmp_path: Path):
"""update_agent reads load_agent_config + get_app_config — stub them
minimally so the tool can run without a real config file or LLM."""
fake_model_cfg = SimpleNamespace(name="fake-model")
fake_app_cfg = MagicMock()
fake_app_cfg.get_model_config = lambda name: fake_model_cfg if name == "fake-model" else None
return [
patch(
"deerflow.tools.builtins.update_agent_tool.get_paths",
return_value=_make_paths_mock(tmp_path),
),
patch(
"deerflow.tools.builtins.update_agent_tool.get_app_config",
return_value=fake_app_cfg,
),
# load_agent_config (used by update_agent to read existing config) also
# reads paths via its own module-level get_paths reference. Patch it too
# or the tool returns "Agent does not exist" before touching disk.
patch(
"deerflow.config.agents_config.get_paths",
return_value=_make_paths_mock(tmp_path),
),
]
def _build_update_graph(*, soul_payload: str):
from langchain.agents import create_agent
from deerflow.tools.builtins.update_agent_tool import update_agent
fake_model = build_single_tool_call_model(
tool_name="update_agent",
tool_args={"soul": soul_payload, "description": "refined"},
tool_call_id="call_update_1",
final_text="updated",
)
return create_agent(model=fake_model, tools=[update_agent], system_prompt="updater")
# ---------------------------------------------------------------------------
# Negative control — proves the test environment puts update_agent in the
# regime where the contextvar fallback would land in default/.
# ---------------------------------------------------------------------------
@pytest.mark.no_auto_user
def test_update_agent_falls_back_to_default_when_no_inject_and_no_contextvar(tmp_path: Path):
"""No request.state.user, no contextvar — update_agent must look in
users/default/agents/. We seed the file there so the tool succeeds and
we know which directory it actually consulted."""
from langgraph.runtime import Runtime
_seed_existing_agent(tmp_path, "default", "fallback-target")
config = _assemble_config(
body_context={"agent_name": "fallback-target"},
request_user_id=None, # no auth, inject is no-op
thread_id="thread-update-1",
)
runtime_ctx = _build_runtime_context("thread-update-1", "run-1", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_update_graph(soul_payload="# Fallback Updated")
with ExitStack() as stack:
for p in _patch_update_agent_dependencies(tmp_path):
stack.enter_context(p)
graph.invoke(
{"messages": [HumanMessage(content="update fallback-target")]},
config=config,
)
soul = (tmp_path / "users" / "default" / "agents" / "fallback-target" / "SOUL.md").read_text()
assert soul == "# Fallback Updated", "Sanity: tool should have written under default/"
# ---------------------------------------------------------------------------
# Regression guard — passes on this branch, would fail on main before the fix.
# ---------------------------------------------------------------------------
@pytest.mark.no_auto_user
def test_update_agent_should_use_runtime_context_user_id_when_contextvar_missing(tmp_path: Path):
"""update_agent prefers the authenticated user_id carried in
runtime.context (placed there by inject_authenticated_user_context)
over the contextvar — same contract as setup_agent (PR #2784).
Before this PR's fix, update_agent unconditionally called
get_effective_user_id() and landed in default/ whenever the contextvar
was unavailable. This test pins the corrected behaviour.
"""
from langgraph.runtime import Runtime
auth_uid = "abcdef01-2345-6789-abcd-ef0123456789"
# Seed the agent in BOTH locations so we can prove which one was opened.
auth_dir = _seed_existing_agent(tmp_path, auth_uid, "shared-name", soul="# Auth Original")
default_dir = _seed_existing_agent(tmp_path, "default", "shared-name", soul="# Default Original")
config = _assemble_config(
body_context={"agent_name": "shared-name"},
request_user_id=auth_uid,
thread_id="thread-update-2",
)
runtime_ctx = _build_runtime_context("thread-update-2", "run-2", config.get("context"), None)
assert runtime_ctx["user_id"] == auth_uid, "Pre-condition: inject must have placed user_id into runtime_ctx"
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_update_graph(soul_payload="# Auth Updated")
with ExitStack() as stack:
for p in _patch_update_agent_dependencies(tmp_path):
stack.enter_context(p)
graph.invoke(
{"messages": [HumanMessage(content="update shared-name")]},
config=config,
)
auth_soul = (auth_dir / "SOUL.md").read_text()
default_soul = (default_dir / "SOUL.md").read_text()
assert auth_soul == "# Auth Updated", f"REGRESSION: update_agent ignored runtime.context['user_id']={auth_uid!r} and routed the write to users/default/ instead. auth_soul={auth_soul!r}, default_soul={default_soul!r}"
assert default_soul == "# Default Original", "REGRESSION: update_agent corrupted the shared default-user agent. It should have written under the authenticated user's path."
# ---------------------------------------------------------------------------
# Positive — when contextvar IS the auth user (the normal HTTP case), things
# already work. Pin it as a regression guard so future refactors don't
# accidentally break the contextvar path in pursuit of the runtime-context fix.
# ---------------------------------------------------------------------------
def test_update_agent_uses_contextvar_when_present(tmp_path: Path, monkeypatch):
"""The normal HTTP case: contextvar is set by auth_middleware. This must
keep working regardless of how runtime.context is populated."""
from types import SimpleNamespace as _SN
from deerflow.runtime.user_context import reset_current_user, set_current_user
auth_uid = "11112222-3333-4444-5555-666677778888"
user = _SN(id=auth_uid, email="ctxvar@local")
_seed_existing_agent(tmp_path, auth_uid, "ctxvar-agent", soul="# Original")
from langgraph.runtime import Runtime
config = _assemble_config(
body_context={"agent_name": "ctxvar-agent"},
request_user_id=auth_uid,
thread_id="thread-update-3",
)
runtime_ctx = _build_runtime_context("thread-update-3", "run-3", config.get("context"), None)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=runtime_ctx, store=None)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
graph = _build_update_graph(soul_payload="# CtxVar Updated")
with ExitStack() as stack:
for p in _patch_update_agent_dependencies(tmp_path):
stack.enter_context(p)
token = set_current_user(user)
try:
final = graph.invoke(
{"messages": [HumanMessage(content="update ctxvar-agent")]},
config=config,
)
finally:
reset_current_user(token)
# surface the tool's reply for debug if it errored
tool_replies = [m.content for m in final["messages"] if getattr(m, "type", "") == "tool"]
soul = (tmp_path / "users" / auth_uid / "agents" / "ctxvar-agent" / "SOUL.md").read_text()
assert soul == "# CtxVar Updated", f"tool replies: {tool_replies}"