mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
Merge branch 'main' of https://github.com/bytedance/deer-flow into rayhpeng/storage-package-base
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
"""Shared helpers for user-isolation e2e tests on the custom-agent tooling.
|
||||
|
||||
Centralises the small fake-LLM shim and a few test-data builders that the
|
||||
three e2e files in this PR (``test_setup_agent_e2e_user_isolation``,
|
||||
``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``)
|
||||
all need. The shim is what lets a real ``langchain.agents.create_agent``
|
||||
graph run without an API key — every other layer in those tests is real
|
||||
production code, which is the entire point of the test design.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
|
||||
class FakeToolCallingModel(FakeMessagesListChatModel):
    """FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent.

    ``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to
    expose the tool schemas to the model; the upstream fake raises
    ``NotImplementedError`` there. We just return ``self`` because we
    drive deterministic tool_call output via ``responses=...``, no schema
    handling needed.
    """

    def bind_tools(  # type: ignore[override]
        self,
        tools: Any,
        *,
        tool_choice: Any = None,
        **kwargs: Any,
    ) -> Runnable:
        """Accept and ignore the tool binding request.

        All arguments (``tools``, ``tool_choice`` and any extra keyword
        arguments) are discarded; the same model instance is returned so the
        scripted ``responses`` keep driving the conversation.
        """
        return self
|
||||
|
||||
|
||||
def build_single_tool_call_model(
    *,
    tool_name: str,
    tool_args: dict[str, Any],
    tool_call_id: str = "call_e2e_1",
    final_text: str = "done",
) -> FakeToolCallingModel:
    """Return a scripted fake model: one tool call, then a final answer.

    The model replays two messages in order:
      1. an empty-content ``AIMessage`` carrying a single tool_call for
         *tool_name* with *tool_args* and id *tool_call_id*;
      2. an ``AIMessage`` whose content is *final_text*, which ends the
         agent loop.
    """
    scripted_call = {
        "name": tool_name,
        "args": tool_args,
        "id": tool_call_id,
        "type": "tool_call",
    }
    tool_turn = AIMessage(content="", tool_calls=[scripted_call])
    closing_turn = AIMessage(content=final_text)
    return FakeToolCallingModel(responses=[tool_turn, closing_turn])
|
||||
@@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
|
||||
issues when unit-testing lightweight config/registry code in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -11,11 +13,16 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
|
||||
|
||||
# Make 'app' and 'deerflow' importable from any working directory
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
|
||||
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
|
||||
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
|
||||
|
||||
# Break the circular import chain that exists in production code:
|
||||
# deerflow.subagents.__init__
|
||||
# -> .executor (SubagentExecutor, SubagentResult)
|
||||
@@ -56,6 +63,92 @@ def provisioner_module():
|
||||
return module
|
||||
|
||||
|
||||
@pytest.fixture()
def blocking_io_detector():
    """Yield an active detector that fails the test on exit if it saw blocking calls."""
    strict_detector = detect_blocking_io(fail_on_exit=True)
    with strict_detector as active_detector:
        yield active_detector
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the two opt-in command line flags of the blocking IO probe."""
    probe_group = parser.getgroup("blocking-io")
    # Both flags are plain boolean switches that default to off.
    flag_help = (
        (
            "--detect-blocking-io",
            "Collect blocking calls made while an asyncio event loop is running and report a summary.",
        ),
        (
            "--detect-blocking-io-fail",
            "Set a failing exit status when --detect-blocking-io records violations.",
        ),
    )
    for flag, help_text in flag_help:
        probe_group.addoption(flag, action="store_true", default=False, help=help_text)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
    """Declare the ``no_blocking_io_probe`` marker with pytest."""
    marker_spec = "no_blocking_io_probe: skip the optional blocking IO probe"
    config.addinivalue_line("markers", marker_spec)
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
    """Empty the shared probe at session start when the probe is enabled."""
    if not _blocking_io_probe_enabled(session.config):
        return
    _blocking_io_probe.clear()
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
    """Activate a per-test detector around the test call when the probe is on.

    The detector is entered here and stashed on the item; it is exited later
    by ``pytest_runtest_teardown``.
    """
    should_probe = _blocking_io_probe_enabled(item.config) and not _blocking_io_probe_skipped(item)
    if should_probe:
        detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
        # Entered manually because the matching exit happens in another hook.
        detector.__enter__()
        setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
    yield
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
    """Exit the detector stashed on *item* after teardown and record its findings."""
    yield

    stashed = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
    if stashed is not None:
        try:
            stashed.__exit__(None, None, None)
            _blocking_io_probe.record(item.nodeid, stashed.violations)
        finally:
            # Always drop the reference so the item does not keep a stale detector.
            delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session) -> None:
    """Turn an otherwise green session red when fail mode saw violations."""
    if not _blocking_io_fail_enabled(session.config):
        return
    if not _blocking_io_probe.violation_count:
        return
    # Only override a passing status; other exit codes are left as they are.
    if session.exitstatus == pytest.ExitCode.OK:
        session.exitstatus = pytest.ExitCode.TESTS_FAILED
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
    """Print the probe summary: first line as a separator title, the rest as plain lines."""
    if _blocking_io_probe_enabled(terminalreporter.config):
        title, *body = _blocking_io_probe.format_summary().splitlines()
        terminalreporter.write_sep("=", title)
        for text in body:
            terminalreporter.write_line(text)
|
||||
|
||||
|
||||
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
    """Return True when either blocking IO command line flag is set."""
    probe_flags = ("--detect-blocking-io", "--detect-blocking-io-fail")
    return any(config.getoption(flag) for flag in probe_flags)
|
||||
|
||||
|
||||
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
    """Return True when ``--detect-blocking-io-fail`` is set."""
    fail_flag = config.getoption("--detect-blocking-io-fail")
    return True if fail_flag else False
|
||||
|
||||
|
||||
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
    """Return True for items the probe must not wrap.

    That is every test in ``test_blocking_io_detector.py`` and any test
    carrying the ``no_blocking_io_probe`` marker.
    """
    if item.path.name == "test_blocking_io_detector.py":
        return True
    opt_out_marker = item.get_closest_marker("no_blocking_io_probe")
    return opt_out_marker is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Shared test support helpers."""
|
||||
@@ -0,0 +1 @@
|
||||
"""Runtime and static detectors used by tests."""
|
||||
@@ -0,0 +1,287 @@
|
||||
"""Test helper for detecting blocking calls on an asyncio event loop.
|
||||
|
||||
The detector is intentionally test-only. It monkeypatches a small set of
|
||||
well-known blocking entry points and their already-loaded module-level aliases,
|
||||
then records calls only when they happen on a thread that is currently running
|
||||
an asyncio event loop. Aliases captured in closures or default arguments remain
|
||||
out of scope.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from collections.abc import Callable, Iterable, Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
BlockingCallable = Callable[..., Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BlockingCallSpec:
    """Describes one blocking callable to wrap during a detector run."""

    # Label reported for violations of this spec (e.g. "time.sleep").
    name: str
    # Import target in "module:attr.path" form, as parsed by ``_resolve_target``.
    target: str
    # When true, the wrapper records while the returned iterable is consumed
    # instead of when the callable itself is invoked.
    record_on_iteration: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BlockingCall:
    """One blocking call observed on an asyncio event loop thread."""

    # ``name`` of the spec that was triggered.
    name: str
    # ``target`` of the spec that was triggered.
    target: str
    # Call stack captured at record time, with the detector's own frames removed.
    stack: tuple[traceback.FrameSummary, ...]
|
||||
|
||||
|
||||
# Blocking entry points wrapped when a detector is created without explicit specs.
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
    BlockingCallSpec("time.sleep", "time:sleep"),
    BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
    BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
    # os.walk is recorded when its result is iterated, not when it is called.
    BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
    BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
    BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
    BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
)
|
||||
|
||||
|
||||
def _is_event_loop_thread() -> bool:
    """Return True when the calling thread is currently running an asyncio loop."""
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises when no loop is running in this thread.
        return False
    else:
        return current_loop.is_running()
|
||||
|
||||
|
||||
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
    """Resolve a ``"module:attr.path"`` target to its owner, name and callable.

    Args:
        target: Import target such as ``"time:sleep"`` or
            ``"pathlib:Path.resolve"``. The part before the colon is imported
            as a module; the part after it is a dotted attribute path.

    Returns:
        ``(owner, attr_name, original)`` where ``owner`` is the object holding
        the final attribute, ``attr_name`` is that attribute's name and
        ``original`` is its current value.

    Raises:
        ValueError: If *target* contains no ``":"`` separator.
        ImportError: If the module cannot be imported.
        AttributeError: If any step of the attribute path is missing.
    """
    module_name, separator, attr_path = target.partition(":")
    if not separator:
        # Fail with a message naming the bad spec instead of an opaque unpacking error.
        raise ValueError(f"blocking call target must look like 'module:attr.path', got {target!r}")

    owner: object = importlib.import_module(module_name)
    *owner_path, attr_name = attr_path.split(".")
    for part in owner_path:
        owner = getattr(owner, part)

    original = getattr(owner, attr_name)
    return owner, attr_name, original
|
||||
|
||||
|
||||
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
    """Drop frames that belong to this module and return the rest as a tuple."""
    kept = []
    for frame in stack:
        if frame.filename == __file__:
            continue
        kept.append(frame)
    return tuple(kept)
|
||||
|
||||
|
||||
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
    """Record blocking calls made from async runtime code.

    While the context is active, every target named by the specs is replaced
    by a recording wrapper (optionally together with module-level aliases
    that are already loaded). All replacements are undone on exit.

    By default the detector reports violations but does not fail on context
    exit. Tests can set ``fail_on_exit=True`` or call
    ``assert_no_blocking_calls()`` explicitly.
    """

    def __init__(
        self,
        specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
        *,
        fail_on_exit: bool = False,
        patch_loaded_aliases: bool = True,
        stack_limit: int = 12,
    ) -> None:
        """Configure the detector; nothing is patched until ``__enter__``.

        Args:
            specs: Blocking callables to wrap.
            fail_on_exit: Raise ``AssertionError`` on a clean exit if any
                violation was recorded.
            patch_loaded_aliases: Also replace module-level names in
                ``sys.modules`` that are bound to a wrapped original.
            stack_limit: ``limit`` passed to ``traceback.extract_stack`` when
                a violation is recorded.
        """
        self._specs = tuple(specs)
        self._fail_on_exit = fail_on_exit
        self._patch_loaded_aliases_enabled = patch_loaded_aliases
        self._stack_limit = stack_limit
        # (owner, attribute name, original value) for every applied patch.
        self._patches: list[tuple[object, str, BlockingCallable]] = []
        # (id(owner), attribute name) pairs, used to patch each slot only once.
        self._patch_keys: set[tuple[int, str]] = set()
        self.violations: list[BlockingCall] = []
        self._active = False

    def __enter__(self) -> BlockingIODetector:
        """Install all wrappers; roll back and re-raise if any step fails."""
        try:
            self._active = True
            alias_replacements: dict[int, BlockingCallable] = {}
            for spec in self._specs:
                owner, attr_name, original = _resolve_target(spec.target)
                wrapper = self._wrap(spec, original)
                self._patch_attribute(owner, attr_name, original, wrapper)
                alias_replacements[id(original)] = wrapper

            if self._patch_loaded_aliases_enabled:
                self._patch_loaded_module_aliases(alias_replacements)
        except Exception:
            self._restore()
            self._active = False
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback_value: TracebackType | None,
    ) -> bool | None:
        """Undo all patches; optionally assert that nothing was recorded.

        Exceptions from the ``with`` body are never suppressed. The
        ``fail_on_exit`` assertion only runs when the body exited cleanly.
        """
        try:
            self._restore()
        finally:
            # Stop recording even if restoring a patch failed.
            self._active = False
        if exc_type is None and self._fail_on_exit:
            self.assert_no_blocking_calls()
        return None

    def _restore(self) -> None:
        """Put every patched attribute back, newest patch first."""
        for owner, attr_name, original in reversed(self._patches):
            setattr(owner, attr_name, original)
        self._patches.clear()
        self._patch_keys.clear()

    def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
        """Replace ``owner.attr_name`` once and remember how to undo it."""
        key = (id(owner), attr_name)
        if key in self._patch_keys:
            return
        setattr(owner, attr_name, replacement)
        self._patches.append((owner, attr_name, original))
        self._patch_keys.add(key)

    def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
        """Replace module-level names that are bound to a wrapped original.

        Matching is by object identity, so ``from time import sleep`` style
        aliases in already-imported modules are redirected to the wrapper.
        """
        # Iterate over snapshots: patching mutates the namespaces being scanned.
        for module in tuple(sys.modules.values()):
            namespace = getattr(module, "__dict__", None)
            if not isinstance(namespace, dict):
                continue

            for attr_name, value in tuple(namespace.items()):
                replacement = replacements_by_id.get(id(value))
                if replacement is not None:
                    self._patch_attribute(module, attr_name, value, replacement)

    def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
        """Build the recording wrapper that stands in for *original*."""

        @wraps(original)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if spec.record_on_iteration:
                # Recording is deferred until the returned iterable is consumed.
                result = original(*args, **kwargs)
                return self._wrap_iteration(spec, result)
            self._record_if_blocking(spec)
            return original(*args, **kwargs)

        return wrapper

    def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
        """Yield the items of *iterable*, recording at most one violation.

        The check runs before each item is fetched until a violation has been
        recorded, so the thread that actually iterates is the one inspected.
        """
        iterator = iter(iterable)
        reported = False

        while True:
            if not reported:
                reported = self._record_if_blocking(spec)
            try:
                item = next(iterator)
            except StopIteration:
                return
            # Yield outside the ``try`` so that only exhaustion of the wrapped
            # iterator ends this generator, not an exception thrown in by the
            # consumer at the yield point.
            yield item

    def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
        """Record a violation if active on an event loop thread; return whether one was recorded."""
        if self._active and _is_event_loop_thread():
            stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
            self.violations.append(BlockingCall(spec.name, spec.target, stack))
            return True
        return False

    def assert_no_blocking_calls(self) -> None:
        """Raise ``AssertionError`` describing every recorded violation, if any."""
        if self.violations:
            raise AssertionError(format_blocking_calls(self.violations))
|
||||
|
||||
|
||||
class BlockingIOProbe:
    """Collect detector output across tests and format a compact summary."""

    def __init__(self, project_root: Path) -> None:
        """Create an empty probe.

        Args:
            project_root: Directory treated as the project; it is resolved
                once and used to pick and shorten call-site file names.
        """
        self._project_root = project_root.resolve()
        # (test node id, violation) pairs in the order they were recorded.
        self._observed: list[tuple[str, BlockingCall]] = []

    @property
    def violation_count(self) -> int:
        """Total number of recorded violations."""
        return len(self._observed)

    @property
    def test_count(self) -> int:
        """Number of distinct test node ids with at least one violation."""
        return len({nodeid for nodeid, _violation in self._observed})

    def clear(self) -> None:
        """Forget everything recorded so far."""
        self._observed.clear()

    def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
        """Attribute each of *violations* to the test identified by *nodeid*."""
        self._observed.extend((nodeid, violation) for violation in violations)

    def format_summary(self, *, limit: int = 30) -> str:
        """Return a multi-line summary grouped by call site.

        Args:
            limit: Maximum number of call sites listed, most frequent first.

        Returns:
            A single line when nothing was recorded; otherwise a header line,
            a ``Top call sites:`` line and one line per call site.
        """
        if not self._observed:
            return "blocking io probe: no violations"

        call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
        for _nodeid, violation in self._observed:
            frame = self._local_call_site(violation.stack)
            if frame is None:
                site = (violation.name, "<unknown>", 0, "<unknown>", "")
            else:
                site = (
                    violation.name,
                    self._relative(frame.filename),
                    frame.lineno,
                    frame.name,
                    (frame.line or "").strip(),
                )
            call_sites[site] += 1

        lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
        for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
            lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
        return "\n".join(lines)

    def _relative(self, filename: str) -> str:
        """Return *filename* relative to the project root, or unchanged if outside it."""
        try:
            return str(Path(filename).resolve().relative_to(self._project_root))
        except ValueError:
            return filename

    def _project_parts(self, filename: str) -> tuple[str, ...] | None:
        """Return the path components of *filename* below the project root.

        Returns ``None`` for pseudo file names such as ``<string>`` and for
        files that are not located inside the project root.
        """
        if filename.startswith("<"):
            return None
        try:
            return Path(filename).resolve().relative_to(self._project_root).parts
        except (OSError, ValueError):
            return None

    def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
        """Pick the innermost project frame, preferring non-test code.

        Frames under a ``.venv`` directory are ignored. Paths are compared by
        component, so a sibling directory that merely shares the root's name
        as a prefix is not mistaken for project code and the check does not
        depend on the path separator.
        """
        runtime_frames: list[traceback.FrameSummary] = []
        project_frames: list[traceback.FrameSummary] = []
        for frame in stack:
            parts = self._project_parts(frame.filename)
            if not parts or ".venv" in parts:
                continue
            project_frames.append(frame)
            if parts[0] != "tests":
                runtime_frames.append(frame)

        if runtime_frames:
            return runtime_frames[-1]
        # Fall back to test code when no runtime frame is on the stack.
        return project_frames[-1] if project_frames else None
|
||||
|
||||
|
||||
def detect_blocking_io(
    specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
    *,
    fail_on_exit: bool = False,
    patch_loaded_aliases: bool = True,
    stack_limit: int = 12,
) -> BlockingIODetector:
    """Build a ``BlockingIODetector`` for use as a context manager in one test scope.

    All arguments are forwarded unchanged to the detector's constructor.
    """
    detector = BlockingIODetector(
        specs,
        fail_on_exit=fail_on_exit,
        patch_loaded_aliases=patch_loaded_aliases,
        stack_limit=stack_limit,
    )
    return detector
|
||||
|
||||
|
||||
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
    """Render violations as a numbered list, each followed by its stack lines."""
    report = ["Blocking calls were executed on an asyncio event loop thread:"]
    number = 0
    for violation in violations:
        number += 1
        report.append(f"{number}. {violation.name} ({violation.target})")
        for stack_line in _format_stack(violation.stack):
            report.append(stack_line)
    return "\n".join(report)
|
||||
|
||||
|
||||
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
    """Yield one location line per frame, plus its source line when available."""
    for frame in stack:
        yield f" at {frame.name} ({frame.filename}:{frame.lineno})"
        source_text = frame.line
        if source_text:
            yield f" {source_text.strip()}"
|
||||
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from os import walk as imported_walk
|
||||
from pathlib import Path
|
||||
from time import sleep as imported_sleep
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from support.detectors.blocking_io import (
|
||||
BlockingCallSpec,
|
||||
BlockingIOProbe,
|
||||
detect_blocking_io,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# Single-spec tuples so each test wraps only the callable it exercises.
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
|
||||
|
||||
|
||||
async def test_records_time_sleep_on_event_loop() -> None:
    """A direct ``time.sleep`` call made on the loop thread is recorded once."""
    with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
        time.sleep(0)

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["time.sleep"]
|
||||
|
||||
|
||||
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
    """A ``from time import sleep`` alias is recorded and restored afterwards."""
    alias_before = imported_sleep

    with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
        imported_sleep(0)

    recorded_names = [call.name for call in detector.violations]
    assert imported_sleep is alias_before
    assert recorded_names == ["time.sleep"]
|
||||
|
||||
|
||||
async def test_can_disable_loaded_alias_patching() -> None:
    """With alias patching off, a call through the imported alias is not recorded."""
    detector_cm = detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False)
    with detector_cm as detector:
        imported_sleep(0)

    assert detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
    """``time.sleep`` run through ``asyncio.to_thread`` is not a violation."""
    with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
        offloaded = asyncio.to_thread(time.sleep, 0)
        await offloaded

    assert detector.violations == []
|
||||
|
||||
|
||||
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
    """The fixture's detector records nothing for work offloaded to a thread."""
    offloaded = asyncio.to_thread(time.sleep, 0)
    await offloaded

    assert blocking_io_detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
    """A detector used entirely inside a worker thread records nothing."""

    def sleep_in_worker() -> list[str]:
        with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
            time.sleep(0)
        return [call.name for call in detector.violations]

    recorded_names = await asyncio.to_thread(sleep_in_worker)
    assert recorded_names == []
|
||||
|
||||
|
||||
async def test_fail_on_exit_includes_call_site() -> None:
    """``fail_on_exit`` raises, and the message names the call and this test."""
    with pytest.raises(AssertionError) as raised:
        with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
            time.sleep(0)

    failure_text = str(raised.value)
    for expected in ("time.sleep", "test_fail_on_exit_includes_call_site"):
        assert expected in failure_text
|
||||
|
||||
|
||||
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
    """``requests.get`` is recorded through a stubbed ``Session.request``."""

    def stub_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
        return f"{method}:{url}"

    monkeypatch.setattr(requests.sessions.Session, "request", stub_request)

    with detect_blocking_io(REQUESTS_ONLY) as detector:
        reply = requests.get("https://example.invalid")
        assert reply == "get:https://example.invalid"

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["requests.Session.request"]
|
||||
|
||||
|
||||
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
    """A sync ``httpx.Client.get`` is recorded through a stubbed ``Client.request``."""

    def stub_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
        outgoing = httpx.Request(method, url)
        return httpx.Response(200, request=outgoing)

    monkeypatch.setattr(httpx.Client, "request", stub_request)

    with detect_blocking_io(HTTPX_ONLY) as detector:
        with httpx.Client() as client:
            response = client.get("https://example.invalid")

    recorded_names = [call.name for call in detector.violations]
    assert response.status_code == 200
    assert recorded_names == ["httpx.Client.request"]
|
||||
|
||||
|
||||
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
    """Consuming ``os.walk`` on the loop thread is recorded exactly once."""
    nested_dir = tmp_path / "nested"
    nested_dir.mkdir()

    with detect_blocking_io(OS_WALK_ONLY) as detector:
        assert list(os.walk(tmp_path))

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["os.walk"]
|
||||
|
||||
|
||||
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
    """A ``from os import walk`` alias is recorded on iteration and restored afterwards."""
    nested_dir = tmp_path / "nested"
    nested_dir.mkdir()
    alias_before = imported_walk

    with detect_blocking_io(OS_WALK_ONLY) as detector:
        assert list(imported_walk(tmp_path))

    recorded_names = [call.name for call in detector.violations]
    assert imported_walk is alias_before
    assert recorded_names == ["os.walk"]
|
||||
|
||||
|
||||
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
    """Creating the walker inside the detector but consuming it afterwards records nothing."""
    with detect_blocking_io(OS_WALK_ONLY) as detector:
        pending_walk = os.walk(tmp_path)

    # Iteration happens only after the detector has exited.
    assert list(pending_walk)
    assert detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
    """A walker created on the loop but consumed in a worker thread is not recorded."""
    nested_dir = tmp_path / "nested"
    nested_dir.mkdir()

    with detect_blocking_io(OS_WALK_ONLY) as detector:
        pending_walk = os.walk(tmp_path)
        entries = await asyncio.to_thread(lambda: list(pending_walk))
        assert entries

    assert detector.violations == []
|
||||
|
||||
|
||||
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
    """``Path.read_text`` on the loop thread is recorded once."""
    data_file = tmp_path / "data.txt"
    data_file.write_text("content", encoding="utf-8")

    with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
        assert data_file.read_text(encoding="utf-8") == "content"

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["pathlib.Path.read_text"]
|
||||
|
||||
|
||||
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
    """The probe summary reports the counts and the name of the recorded call."""
    backend_root = Path(__file__).resolve().parents[1]
    probe = BlockingIOProbe(backend_root)
    data_file = tmp_path / "data.txt"
    data_file.write_text("content", encoding="utf-8")

    with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
        assert data_file.read_text(encoding="utf-8") == "content"

    probe.record("tests/test_example.py::test_example", detector.violations)
    summary = probe.format_summary()

    for expected in ("blocking io probe: 1 violations across 1 tests", "pathlib.Path.read_text"):
        assert expected in summary
|
||||
|
||||
|
||||
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
    """An empty probe prints the no-violations line, before recording and after ``clear()``."""
    empty_summary = "blocking io probe: no violations"
    backend_root = Path(__file__).resolve().parents[1]
    probe = BlockingIOProbe(backend_root)

    assert probe.format_summary() == empty_summary

    data_file = tmp_path / "data.txt"
    data_file.write_text("content", encoding="utf-8")
    with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
        assert data_file.read_text(encoding="utf-8") == "content"

    probe.record("tests/test_example.py::test_example", detector.violations)
    assert probe.violation_count == 1

    probe.clear()

    assert probe.violation_count == 0
    assert probe.format_summary() == empty_summary
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
ORIGINAL_SLEEP = time.sleep
|
||||
|
||||
|
||||
def replacement_sleep(seconds: float) -> None:
    """Stand-in for ``time.sleep`` that returns immediately without sleeping."""
    return
|
||||
|
||||
|
||||
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``time.sleep`` via monkeypatch and check the replacement is in effect."""
    monkeypatch.setattr(time, "sleep", replacement_sleep)
    assert time.sleep is replacement_sleep
|
||||
|
||||
|
||||
@pytest.mark.no_blocking_io_probe
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
    """Check that ``time.sleep`` is the original function again, not a leftover wrapper."""
    assert time.sleep is ORIGINAL_SLEEP
    # A functools.wraps-style wrapper would expose ``__wrapped__``.
    assert getattr(time.sleep, "__wrapped__", None) is None
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Real-LLM end-to-end verification for issue #2884.
|
||||
|
||||
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
|
||||
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
|
||||
and the production ``get_available_tools`` pipeline. The only thing we mock is
|
||||
the MCP tool source — we hand-roll two ``@tool``s and inject them through
|
||||
``deerflow.mcp.cache.get_cached_mcp_tools``.
|
||||
|
||||
The flow exercised:
|
||||
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
|
||||
that re-enters ``get_available_tools`` on the same task — this is the
|
||||
code path issue #2884 reports). It must call ``tool_search`` to
|
||||
discover the deferred ``fake_calculator`` tool.
|
||||
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
|
||||
``fake_subagent_trigger`` re-enters ``get_available_tools``.
|
||||
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
|
||||
so it can actually call it. Without this PR's fix, the re-entry wipes
|
||||
the promotion and the model can no longer invoke the tool.
|
||||
|
||||
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
|
||||
test run. Run with::
|
||||
|
||||
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
|
||||
PYTHONPATH=. uv run pytest \
|
||||
tests/test_deferred_tool_promotion_real_llm.py -v -s
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skip control: only run when explicitly opted in.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("ONEAPI_E2E") != "1",
|
||||
reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake "MCP" tools the agent should discover via tool_search.
|
||||
# Keep them obviously synthetic so the model can pattern-match the search.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_calls: list[str] = []
|
||||
|
||||
|
||||
@as_tool
def fake_calculator(expression: str) -> str:
    """Evaluate a tiny arithmetic expression like '2 + 2'.

    Reserved for the user — only call this if the user asks for arithmetic.
    """
    _calls.append(f"fake_calculator:{expression}")
    # Character whitelist: digits, arithmetic operators, parentheses, space, dot.
    permitted = set("0123456789+-*/() .")
    if any(char not in permitted for char in expression):
        return "expression contains disallowed characters"
    try:
        # Trivially safe-eval just for the e2e check (builtins are emptied).
        return str(eval(expression, {"__builtins__": {}}, {}))  # noqa: S307
    except Exception as exc:
        return f"error: {exc}"
|
||||
|
||||
|
||||
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
    """Translate text into the given language code. Decorative — not used."""
    # Log the invocation so tests can see whether the tool was called.
    _calls.append(f"fake_translator:{text}:{target_lang}")
    return f"[{target_lang}] {text}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline wiring (same shape as the in-process tests).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
    """Reset the deferred-tool registry before and after every test in this module."""
    from deerflow.tools.builtins.tool_search import reset_deferred_registry as wipe_registry

    wipe_registry()
    yield
    wipe_registry()
|
||||
|
||||
|
||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
    """Make the MCP layer report one enabled fake server that serves *mcp_tools*.

    ``ExtensionsConfig.from_file`` is patched to return a config with a single
    stdio server, and ``get_cached_mcp_tools`` is patched to return a fresh
    copy of *mcp_tools* on each call.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    fake_server = McpServerConfig(type="stdio", command="echo", enabled=True)
    real_ext = ExtensionsConfig(mcpServers={"fake-server": fake_server})

    def _from_file(cls):
        return real_ext

    def _cached_tools():
        return list(mcp_tools)

    monkeypatch.setattr(
        "deerflow.config.extensions_config.ExtensionsConfig.from_file",
        classmethod(_from_file),
    )
    monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", _cached_tools)
|
||||
|
||||
|
||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``get_app_config`` with a minimal config that enables tool search.

    The config is built with ``model_construct`` rather than the real loader,
    which would trigger ``_apply_singleton_configs`` and permanently mutate
    cross-test singletons (memory, title, …).
    """
    from deerflow.config.app_config import AppConfig
    from deerflow.config.tool_search_config import ToolSearchConfig

    sandbox_model = AppConfig.model_fields["sandbox"].annotation
    minimal_fields = {
        "log_level": "info",
        "models": [],
        "tools": [],
        "tool_groups": [],
        "sandbox": sandbox_model.model_construct(use="x"),
        "tool_search": ToolSearchConfig(enabled=True),
    }
    mock_cfg = AppConfig.model_construct(**minimal_fields)
    monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real-LLM e2e test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
    """Run the deferred-tool flow end to end against a real OpenAI-compatible LLM.

    The system prompt asks the model to:
      Turn 1: call ``tool_search(select:fake_calculator)`` and
              ``fake_subagent_trigger(...)`` in the same tool batch.
      Turn 2: call ``fake_calculator`` and answer with the result.

    The test passes when ``fake_calculator`` shows up in the module-level
    ``_calls`` log and the final message contains ``391``. The trigger tool
    calls ``get_available_tools`` again mid-run, so reaching the calculator
    means the promotion survived that re-entrant call.
    """
    from langchain.agents import create_agent
    from langchain_openai import ChatOpenAI

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
    _force_tool_search_enabled(monkeypatch)
    _calls.clear()

    @as_tool
    async def fake_subagent_trigger(prompt: str) -> str:
        """Pretend to spawn a subagent. Internally rebuilds the toolset.

        Use this whenever the user asks you to delegate work — pass a short
        description as ``prompt``.
        """
        # Re-entrant toolset rebuild, mirroring what the original author says
        # ``task_tool`` does internally. NOTE(review): whether a registry reset
        # here would reach the parent task depends on asyncio context-copying
        # behaviour not visible in this file; the test only requires that the
        # promotion survives this call.
        get_available_tools(subagent_enabled=False)
        _calls.append(f"fake_subagent_trigger:{prompt}")
        return "subagent completed"

    toolset = [*get_available_tools(), fake_subagent_trigger]

    env = os.environ
    llm = ChatOpenAI(
        model=env.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
        api_key=env["OPENAI_API_KEY"],
        base_url=env["OPENAI_API_BASE"],
        temperature=0,
        max_retries=1,
    )

    instructions = (
        "You are a meticulous assistant. Available deferred tools include a "
        "calculator and a translator — their schemas are hidden until you "
        "search for them via tool_search.\n\n"
        "Procedure for the user's request:\n"
        " 1. Call tool_search with query 'select:fake_calculator' AND "
        "in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
        "to delegate the side work. Put both tool_calls in your first response.\n"
        " 2. After both tool messages come back, call fake_calculator with "
        "the user's expression.\n"
        " 3. Reply with just the numeric result."
    )

    agent = create_agent(
        model=llm,
        tools=toolset,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt=instructions,
    )

    question = HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")
    outcome = await agent.ainvoke(
        {"messages": [question]},
        config={"recursion_limit": 12},
    )

    # Diagnostics for whoever reads the captured output of a failing run.
    print("\n=== tool calls recorded ===")
    for entry in _calls:
        print(f" {entry}")
    print("\n=== final message ===")
    if outcome["messages"]:
        final_text = outcome["messages"][-1].content
    else:
        final_text = "(none)"
    print(f" {final_text!r}")

    # Primary check: the calculator tool body actually ran.
    calculator_hits = list(filter(lambda entry: entry.startswith("fake_calculator:"), _calls))
    assert calculator_hits, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"

    # Secondary check: the reply carries the right product.
    assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
|
||||
@@ -0,0 +1,390 @@
|
||||
"""Reproduce + regression-guard issue #2884.
|
||||
|
||||
Hypothesis from the issue:
|
||||
``tools.tools.get_available_tools`` unconditionally calls
|
||||
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
|
||||
every time it is invoked. If anything calls ``get_available_tools`` again
|
||||
during the same async context (after the agent has promoted tools via
|
||||
``tool_search``), the promotion is wiped and the next model call hides the
|
||||
tool's schema again.
|
||||
|
||||
These tests pin the following (Section C at the end of the file adds the re-entrant variant of B):
|
||||
|
||||
A. **At the unit boundary** — verify the failure mode directly. Promote a
|
||||
tool in the registry, then call ``get_available_tools`` again and observe
|
||||
that the ContextVar registry is reset and the promotion is lost.
|
||||
|
||||
B. **At the graph-execution boundary** — drive a real ``create_agent`` graph
|
||||
with the real ``DeferredToolFilterMiddleware`` through two model turns.
|
||||
The first turn calls ``tool_search`` which promotes a tool. The second
|
||||
turn must see that tool's schema in ``request.tools``. If
|
||||
``get_available_tools`` were to run again between the two turns and reset
|
||||
the registry, the second turn's filter would strip the tool.
|
||||
|
||||
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
|
||||
unmodified; mock only the LLM and the MCP tool source. Patch
|
||||
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
|
||||
``get_available_tools`` resolves via lazy import) to return our fixture
|
||||
tools so we don't need a real MCP server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
|
||||
class FakeToolCallingModel(FakeMessagesListChatModel):
    """Fake chat model whose ``bind_tools`` is a no-op, so ``create_agent`` can use it.

    Callers in this file construct it with a ``responses=[...]`` list of
    scripted messages.
    """

    def bind_tools(self, tools: Any, *, tool_choice: Any = None, **kwargs: Any) -> Runnable:  # type: ignore[override]
        # Every argument is ignored; the same instance is handed back unchanged.
        return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@as_tool
def fake_mcp_search(query: str) -> str:
    """Pretend to search a knowledge base for the given query."""
    # NOTE(review): the docstring is presumably picked up by ``as_tool`` as the
    # tool description, so it is kept verbatim; confirm before rewording it.
    canned_reply = f"results for {query}"
    return canned_reply
|
||||
|
||||
|
||||
@as_tool
def fake_mcp_fetch(url: str) -> str:
    """Pretend to fetch a page at the given URL."""
    # NOTE(review): the docstring is presumably picked up by ``as_tool`` as the
    # tool description, so it is kept verbatim; confirm before rewording it.
    canned_page = f"content of {url}"
    return canned_page
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
    """Export placeholder OpenAI settings for every test in this module.

    Per the original author's note, config.yaml references $OPENAI_API_KEY at
    parse time, so a dummy value has to be present.
    """
    placeholders = (
        ("OPENAI_API_KEY", "sk-fake-not-used"),
        ("OPENAI_API_BASE", "https://example.invalid"),
    )
    for var_name, var_value in placeholders:
        monkeypatch.setenv(var_name, var_value)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
    """Give every test a clean deferred-tool registry, and clean up afterwards.

    Per the original author's note, the registry is held in a module-level
    ContextVar, so without this reset one test's promotion could carry over
    into the next and silently break the filter assertions.
    """
    from deerflow.tools.builtins.tool_search import reset_deferred_registry as _wipe_registry

    _wipe_registry()  # setup
    yield
    _wipe_registry()  # teardown
|
||||
|
||||
|
||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
    """Simulate a registered MCP server whose tools are ``mcp_tools``.

    ``ExtensionsConfig.from_file`` is replaced with a classmethod returning a
    real ``ExtensionsConfig`` that lists one enabled stdio server, and
    ``deerflow.mcp.cache.get_cached_mcp_tools`` is replaced with a function
    returning a copy of ``mcp_tools``. A real config instance is used (rather
    than a mock) so that, per the original author's note, both
    ``AppConfig.from_file`` and ``get_available_tools`` see a valid object.

    Args:
        monkeypatch: pytest's patching fixture; it owns both patches applied here.
        mcp_tools: fixture tools the fake cache should serve.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    fake_server = McpServerConfig(type="stdio", command="echo", enabled=True)
    real_ext = ExtensionsConfig(mcpServers={"fake-server": fake_server})

    def _from_file(cls):
        # Stand-in for the file-backed loader: always the in-memory config above.
        return real_ext

    def _cached_tools():
        # Copy on every call so a caller mutating the result cannot affect the fixture list.
        return list(mcp_tools)

    monkeypatch.setattr(
        "deerflow.config.extensions_config.ExtensionsConfig.from_file",
        classmethod(_from_file),
    )
    monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", _cached_tools)
|
||||
|
||||
|
||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Turn on ``tool_search`` for the test without reading config.yaml.

    ``deerflow.tools.tools.get_app_config`` is patched to return a minimal
    ``AppConfig`` built with ``model_construct``. The real loader is never
    called: per the original author's note it runs
    ``_apply_singleton_configs``, which permanently rewrites module-level
    singletons (``_memory_config``, ``_title_config``, …) from the
    developer's yaml and would leak into later tests that rely on those
    singletons' defaults.
    """
    from deerflow.config.app_config import AppConfig
    from deerflow.config.tool_search_config import ToolSearchConfig

    # Take the sandbox config class from the field annotation instead of importing it.
    sandbox_cls = AppConfig.model_fields["sandbox"].annotation
    stub_fields = {
        "log_level": "info",
        "models": [],
        "tools": [],
        "tool_groups": [],
        "sandbox": sandbox_cls.model_construct(use="x"),
        "tool_search": ToolSearchConfig(enabled=True),
    }
    mock_cfg = AppConfig.model_construct(**stub_fields)

    def _stub_app_config():
        return mock_cfg

    monkeypatch.setattr("deerflow.tools.tools.get_app_config", _stub_app_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section A — direct unit-level reproduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
    """A second ``get_available_tools()`` call must not re-defer a promoted tool.

    Sequence:
      1. ``get_available_tools()`` registers both fake MCP tools as deferred.
      2. One tool is promoted directly on the registry, standing in for a
         ``tool_search`` call.
      3. ``get_available_tools()`` runs again.

    After step 3 the promoted tool must still be absent from the registry's
    deferred entries. The original author notes that before the fix the
    second call reset the registry and re-deferred every MCP tool.
    """
    from deerflow.tools.builtins.tool_search import get_deferred_registry
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    def _deferred_names(registry):
        return {entry.name for entry in registry.entries}

    # Step 1: the first build defers both MCP tools.
    get_available_tools()
    registry_before = get_deferred_registry()
    assert registry_before is not None
    assert _deferred_names(registry_before) == {"fake_mcp_search", "fake_mcp_fetch"}

    # Step 2: promote one tool, as tool_search would.
    registry_before.promote({"fake_mcp_search"})
    assert _deferred_names(registry_before) == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"

    # Step 3: the re-entrant build must leave the promotion in place.
    get_available_tools()
    registry_after = get_deferred_registry()
    assert registry_after is not None
    deferred_after = _deferred_names(registry_after)
    assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section B — graph-execution reproduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ToolSearchPromotingModel(FakeToolCallingModel):
    """Fake model that logs the tool names passed to each ``bind_tools`` call.

    It is scripted (see ``_build_promoting_model``) to call ``tool_search``
    first and the promoted ``fake_mcp_search`` second. The log lets a test
    inspect what ``DeferredToolFilterMiddleware`` let through on each turn.
    """

    # One inner list of tool names per bind_tools call.
    # NOTE(review): declared with a mutable default; the test using this class
    # assigns a fresh list before running, which sidesteps any sharing.
    bound_tools_per_turn: list[list[str]] = []

    def bind_tools(self, tools: Any, *, tool_choice: Any = None, **kwargs: Any) -> Runnable:  # type: ignore[override]
        seen: list[str] = []
        for candidate in tools:
            # Prefer ``.name``, then ``__name__``, then the repr.
            seen.append(getattr(candidate, "name", getattr(candidate, "__name__", repr(candidate))))
        self.bound_tools_per_turn.append(seen)
        return self
|
||||
|
||||
|
||||
def _build_promoting_model() -> _ToolSearchPromotingModel:
    """Return a recorder model scripted with three replies.

    Reply 1 calls ``tool_search`` with ``select:fake_mcp_search``, reply 2
    calls ``fake_mcp_search``, and reply 3 is the plain text ``all done``.
    """

    def _tool_call(name: str, args: dict[str, Any], call_id: str) -> dict[str, Any]:
        return {"name": name, "args": args, "id": call_id, "type": "tool_call"}

    promote_turn = AIMessage(
        content="",
        tool_calls=[_tool_call("tool_search", {"query": "select:fake_mcp_search"}, "call_search_1")],
    )
    invoke_turn = AIMessage(
        content="",
        tool_calls=[_tool_call("fake_mcp_search", {"query": "hello"}, "call_mcp_1")],
    )
    closing_turn = AIMessage(content="all done")

    return _ToolSearchPromotingModel(responses=[promote_turn, invoke_turn, closing_turn])
|
||||
|
||||
|
||||
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
    """Drive a real ``create_agent`` graph through two model turns.

    Turn 1 must hide the deferred ``fake_mcp_search`` and expose
    ``tool_search``. After the scripted ``tool_search`` call promotes it,
    turn 2's ``bind_tools`` call must include ``fake_mcp_search``. If the
    registry were reset between the turns, ``DeferredToolFilterMiddleware``
    would strip the tool again and the final assertion would fail.
    """
    from langchain.agents import create_agent

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    tools = get_available_tools()
    # Sanity: the assembled list still contains the deferred tools; hiding
    # them from the model is the middleware's job, not get_available_tools'.
    tool_names = {getattr(t, "name", "") for t in tools}
    assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names

    model = _build_promoting_model()
    model.bound_tools_per_turn = []  # start this test with an empty recorder

    graph = create_agent(
        model=model,
        tools=tools,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt="bug-2884-repro",
    )

    graph.invoke({"messages": [HumanMessage(content="use the search tool")]})

    recorded = model.bound_tools_per_turn
    # Guard before indexing: with zero recorded turns the lookup below would
    # raise a bare IndexError instead of a readable failure.
    assert recorded, "Expected at least 1 model turn, got 0 (bind_tools was never called)"

    # Turn 1: model should NOT see fake_mcp_search (it's deferred)
    turn1 = set(recorded[0])
    assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
    assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"

    # Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
    # This is the load-bearing assertion for issue #2884.
    assert len(recorded) >= 2, f"Expected at least 2 model turns, got {len(recorded)}"
    turn2 = set(recorded[1])
    assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section C — the actual issue #2884 trigger: a re-entrant
|
||||
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
|
||||
# wipe the parent's promotion.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
    """A re-entrant ``get_available_tools`` call must not undo a promotion.

    Turn 1's tool batch contains both ``tool_search`` (promoting
    ``fake_mcp_search``) and ``fake_subagent_trigger``, whose body calls
    ``get_available_tools`` again. The original author describes this as the
    pattern ``task_tool`` follows when it builds a subagent's toolset. Turn
    2's ``bind_tools`` call must still include the promoted tool.
    """
    from langchain.agents import create_agent

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    # Stand-in for a subagent spawn: rebuild the toolset while the parent's
    # registry is live.
    @as_tool
    def fake_subagent_trigger(prompt: str) -> str:
        """Pretend to spawn a subagent. Internally rebuilds the toolset."""
        get_available_tools(subagent_enabled=False)
        return f"spawned subagent for: {prompt}"

    tools = get_available_tools() + [fake_subagent_trigger]

    # One list of tool names per bind_tools call, filled by _Model below.
    bound_per_turn: list[list[str]] = []

    class _Model(FakeToolCallingModel):
        def bind_tools(self, tools_arg, **kwargs):  # type: ignore[override]
            bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
            return self

    model = _Model(
        responses=[
            # Turn 1: promote AND trigger the subagent-style rebuild in one
            # tool batch.
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "tool_search",
                        "args": {"query": "select:fake_mcp_search"},
                        "id": "call_search_1",
                        "type": "tool_call",
                    },
                    {
                        "name": "fake_subagent_trigger",
                        "args": {"prompt": "go"},
                        "id": "call_trigger_1",
                        "type": "tool_call",
                    },
                ],
            ),
            # Turn 2: invoke the promoted tool.
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "fake_mcp_search",
                        "args": {"query": "hello"},
                        "id": "call_mcp_1",
                        "type": "tool_call",
                    }
                ],
            ),
            AIMessage(content="all done"),
        ]
    )

    graph = create_agent(
        model=model,
        tools=tools,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt="bug-2884-subagent-repro",
    )
    graph.invoke({"messages": [HumanMessage(content="use the search tool")]})

    # Guard before indexing: with zero recorded turns the lookup below would
    # raise a bare IndexError instead of a readable failure.
    assert bound_per_turn, "Expected at least 1 model turn, got 0 (bind_tools was never called)"

    # Turn 1 sanity: deferred tool not visible yet
    assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]

    # Key assertion: turn 2 sees the promoted tool DESPITE the re-entrant
    # get_available_tools call that happened in turn 1's tool batch.
    assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
    turn2 = set(bound_per_turn[1])
    assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
|
||||
@@ -1,6 +1,6 @@
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
|
||||
assert elapsed < 0.1
|
||||
assert finished.is_set() is False
|
||||
assert finished.wait(1.0) is True
|
||||
|
||||
|
||||
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
    """Two agents updating the same thread must yield two pending entries, in order."""
    queue = MemoryUpdateQueue()
    agents = ["agent-a", "agent-b"]

    # Memory is forced on and the timer reset is stubbed out.
    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
        patch.object(queue, "_reset_timer"),
    ):
        for agent in agents:
            queue.add(thread_id="thread-1", messages=[agent], agent_name=agent)

    assert queue.pending_count == 2
    queued_agents = [context.agent_name for context in queue._queue]
    assert queued_agents == ["agent-a", "agent-b"]
|
||||
|
||||
|
||||
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
    """Two updates from one agent on one thread collapse into a single entry.

    The surviving entry carries the later messages, while
    ``correction_detected`` stays ``True`` because the earlier update set it.
    """
    queue = MemoryUpdateQueue()
    # (messages, correction_detected) for each successive add.
    updates = (
        (["first"], True),
        (["second"], False),
    )

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
        patch.object(queue, "_reset_timer"),
    ):
        for batch, corrected in updates:
            queue.add(
                thread_id="thread-1",
                messages=batch,
                agent_name="agent-a",
                correction_detected=corrected,
            )

    assert queue.pending_count == 1
    survivor = queue._queue[0]
    assert survivor.agent_name == "agent-a"
    assert survivor.messages == ["second"]
    assert survivor.correction_detected is True
|
||||
|
||||
|
||||
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
    """Flushing the queue calls the updater once per agent, in enqueue order."""
    queue = MemoryUpdateQueue()
    agents = ("agent-a", "agent-b")

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
        patch.object(queue, "_reset_timer"),
    ):
        for agent in agents:
            queue.add(thread_id="thread-1", messages=[agent], agent_name=agent)

    mock_updater = MagicMock()
    mock_updater.update_memory.return_value = True

    # The updater class is swapped for the mock and the queue's sleep is
    # stubbed so the flush does not wait.
    with (
        patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
        patch("deerflow.agents.memory.queue.time.sleep"),
    ):
        queue.flush()

    expected_calls = [
        call(
            messages=[agent],
            thread_id="thread-1",
            agent_name=agent,
            correction_detected=False,
            reinforcement_detected=False,
            user_id=None,
        )
        for agent in agents
    ]
    assert mock_updater.update_memory.call_count == 2
    mock_updater.update_memory.assert_has_calls(expected_calls)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
|
||||
def test_conversation_context_has_user_id():
|
||||
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
|
||||
|
||||
def test_queue_add_stores_user_id():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
assert len(q._queue) == 1
|
||||
assert q._queue[0].user_id == "alice"
|
||||
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
|
||||
|
||||
def test_queue_process_passes_user_id_to_updater():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
|
||||
mock_updater = MagicMock()
|
||||
@@ -37,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
|
||||
mock_updater.update_memory.assert_called_once()
|
||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||
assert call_kwargs["user_id"] == "alice"
|
||||
|
||||
|
||||
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
    """Updates from two users sharing a thread and agent stay as separate entries."""
    q = MemoryUpdateQueue()
    # (user_id, messages) for each successive add.
    submissions = (
        ("alice", ["alice update"]),
        ("bob", ["bob update"]),
    )

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
        patch.object(q, "_reset_timer"),
    ):
        for user, batch in submissions:
            q.add(thread_id="main", messages=batch, agent_name="researcher", user_id=user)

    assert q.pending_count == 2
    queued_users = [context.user_id for context in q._queue]
    queued_messages = [context.messages for context in q._queue]
    assert queued_users == ["alice", "bob"]
    assert queued_messages == [["alice update"], ["bob update"]]
|
||||
|
||||
|
||||
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
    """Repeated updates with identical user, thread and agent collapse into the latest one."""
    q = MemoryUpdateQueue()

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
        patch.object(q, "_reset_timer"),
    ):
        for batch in (["first"], ["second"]):
            q.add(thread_id="main", messages=batch, agent_name="researcher", user_id="alice")

    assert q.pending_count == 1
    survivor = q._queue[0]
    assert survivor.messages == ["second"]
    assert survivor.user_id == "alice"
    assert survivor.agent_name == "researcher"
|
||||
|
||||
|
||||
def test_add_nowait_keeps_different_users_separate():
    """``add_nowait`` must not merge updates that differ only in ``user_id``."""
    q = MemoryUpdateQueue()
    # (user_id, messages) for each successive add_nowait.
    submissions = (
        ("alice", ["alice update"]),
        ("bob", ["bob update"]),
    )

    # Memory is forced on and timer scheduling is stubbed out.
    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
        patch.object(q, "_schedule_timer"),
    ):
        for user, batch in submissions:
            q.add_nowait(thread_id="main", messages=batch, agent_name="researcher", user_id=user)

    assert q.pending_count == 2
    queued_users = [context.user_id for context in q._queue]
    assert queued_users == ["alice", "bob"]
|
||||
|
||||
@@ -268,6 +268,39 @@ class TestEdgeCases:
|
||||
class TestDbRunEventStore:
|
||||
"""Tests for DbRunEventStore with temp SQLite."""
|
||||
|
||||
@pytest.mark.anyio
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
    """On a PostgreSQL dialect, max-seq lookup takes an advisory lock, not ``FOR UPDATE``.

    A hand-rolled session reports the PostgreSQL dialect and records what it
    is asked to run. The first executed statement must mention
    ``pg_advisory_xact_lock`` and be parameterised with the thread id, and
    the scalar query must compile without a ``FOR UPDATE`` clause.
    """
    from sqlalchemy.dialects import postgresql

    from deerflow.runtime.events.store.db import DbRunEventStore

    class FakeSession:
        def __init__(self):
            self.dialect = postgresql.dialect()
            self.execute_calls = []  # (statement, params) per execute() call
            self.scalar_stmt = None  # last statement passed to scalar()

        def get_bind(self):
            # The session doubles as its own bind so ``.dialect`` resolves here.
            return self

        async def execute(self, stmt, params=None):
            self.execute_calls.append((stmt, params))

        async def scalar(self, stmt):
            self.scalar_stmt = stmt
            return 41

    session = FakeSession()

    max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")

    assert max_seq == 41
    assert session.execute_calls
    lock_stmt, lock_params = session.execute_calls[0]
    assert lock_params == {"thread_id": "thread-1"}
    assert "pg_advisory_xact_lock" in str(lock_stmt)
    pg_dialect = postgresql.dialect()
    compiled = str(session.scalar_stmt.compile(dialect=pg_dialect))
    assert "FOR UPDATE" not in compiled
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
|
||||
@@ -278,3 +281,48 @@ class TestRunRepository:
|
||||
assert row4["model_name"] is None
|
||||
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
    """The model-name COALESCE in SELECT and GROUP BY must share one bind parameter.

    A fake session captures the single statement the repository issues and
    returns no rows. The statement is compiled for PostgreSQL, and the bind
    parameter name inside ``coalesce(runs.model_name, ...)`` is compared
    between the SELECT list and the GROUP BY clause.
    """
    captured = []

    class FakeResult:
        def all(self):
            return []

    class FakeSession:
        async def execute(self, stmt):
            captured.append(stmt)
            return FakeResult()

    class FakeSessionContext:
        async def __aenter__(self):
            return FakeSession()

        async def __aexit__(self, exc_type, exc, tb):
            return None

    def _session_factory():
        return FakeSessionContext()

    repo = RunRepository(_session_factory)

    agg = await repo.aggregate_tokens_by_thread("t1")

    # No rows came back, so every total is zero.
    empty_totals = {
        "total_tokens": 0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_runs": 0,
        "by_model": {},
        "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
    }
    assert agg == empty_totals
    assert len(captured) == 1

    (stmt,) = captured
    compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
    select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
    # Group 1 captures the bind-parameter name inside the COALESCE.
    model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"

    select_match = re.compile(model_expr_pattern + r" AS model").search(select_sql)
    group_by_match = re.compile(model_expr_pattern).fullmatch(group_by_sql.strip())

    assert select_match is not None
    assert group_by_match is not None
    assert select_match.group(1) == group_by_match.group(1)
|
||||
|
||||
@@ -0,0 +1,429 @@
|
||||
"""End-to-end verification for issue #2862 (and the regression of #2782).
|
||||
|
||||
Goal: prove — without trusting any single layer's claim — that an authenticated
|
||||
user creating a custom agent through the real ``setup_agent`` tool, driven by a
|
||||
real LangGraph ``create_agent`` graph, ends up with files under
|
||||
``users/<auth_uid>/agents/<name>`` and **not** under ``users/default/agents/...``.
|
||||
|
||||
We intentionally exercise the full pipeline:
|
||||
|
||||
HTTP body shape (mimics LangGraph SDK wire format)
|
||||
-> app.gateway.services.start_run config-assembly chain
|
||||
-> deerflow.runtime.runs.worker._build_runtime_context
|
||||
-> langchain.agents.create_agent graph
|
||||
-> ToolNode dispatch
|
||||
-> setup_agent tool
|
||||
|
||||
The only thing we mock is the LLM (FakeMessagesListChatModel) — every layer
|
||||
that handles ``user_id`` is the real production code path. If the
|
||||
``user_id`` propagation is broken anywhere in this chain, these tests will
|
||||
fail.
|
||||
|
||||
These tests are intentionally marked ``no_auto_user`` so that the ``contextvar``
|
||||
fallback would put files into ``default/`` if propagation breaks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from _agent_e2e_helpers import FakeToolCallingModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.gateway.services import (
|
||||
build_run_config,
|
||||
inject_authenticated_user_context,
|
||||
merge_run_context_overrides,
|
||||
)
|
||||
from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — real production code paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(user_id_str: str | None) -> SimpleNamespace:
|
||||
"""Build a fake FastAPI Request that carries an authenticated user."""
|
||||
if user_id_str is None:
|
||||
user = None
|
||||
else:
|
||||
# User.id is UUID in production; honour that
|
||||
user = SimpleNamespace(id=UUID(user_id_str), email="alice@local")
|
||||
return SimpleNamespace(state=SimpleNamespace(user=user))
|
||||
|
||||
|
||||
def _assemble_config(
    *,
    body_config: dict | None,
    body_context: dict | None,
    request_user_id: str | None,
    thread_id: str = "thread-e2e",
    assistant_id: str = "lead_agent",
) -> dict:
    """Replay the **exact** start_run config-assembly sequence.

    The three gateway helpers are applied in order: build the run config,
    merge the request body's context overrides, then inject the authenticated
    user taken from a fake request built for *request_user_id*.
    """
    request = _make_request(request_user_id)
    assembled = build_run_config(thread_id, body_config, None, assistant_id=assistant_id)
    merge_run_context_overrides(assembled, body_context)
    inject_authenticated_user_context(assembled, request)
    return assembled
|
||||
|
||||
|
||||
def _make_paths_mock(tmp_path: Path):
    """Build a paths stand-in rooted at *tmp_path*.

    The returned mock exposes ``base_dir``, ``agent_dir(name)`` and
    ``user_agent_dir(user_id, name)``, laid out as ``agents/<name>`` and
    ``users/<user_id>/agents/<name>`` under *tmp_path*.
    """
    from unittest.mock import MagicMock

    def agent_dir(name):
        return tmp_path / "agents" / name

    def user_agent_dir(user_id, name):
        return tmp_path / "users" / user_id / "agents" / name

    fake_paths = MagicMock()
    fake_paths.base_dir = tmp_path
    fake_paths.agent_dir = agent_dir
    fake_paths.user_agent_dir = user_agent_dir
    return fake_paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1-L3: HTTP wire format → start_run → worker._build_runtime_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigAssembly:
    """Covers L1-L3: validate that user_id reaches runtime_ctx for every wire shape."""

    @staticmethod
    def _runtime_ctx(assembled: dict) -> dict:
        """Build the worker runtime context from an assembled run config."""
        return _build_runtime_context("thread-e2e", "run-1", assembled.get("context"), None)

    def test_typical_wire_format_user_id_in_runtime_ctx(self):
        """Real frontend: body.config={recursion_limit}, body.context={agent_name,...}."""
        uid = "11111111-2222-3333-4444-555555555555"
        assembled = _assemble_config(
            body_config={"recursion_limit": 1000},
            body_context={"agent_name": "myagent", "is_bootstrap": True, "mode": "flash"},
            request_user_id=uid,
        )
        ctx = self._runtime_ctx(assembled)
        assert ctx["user_id"] == "11111111-2222-3333-4444-555555555555"
        assert ctx["agent_name"] == "myagent"

    def test_body_context_none_still_injects_user_id(self):
        """If frontend omits body.context entirely, inject must still create it."""
        assembled = _assemble_config(
            body_config={"recursion_limit": 1000},
            body_context=None,
            request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
        )
        ctx = self._runtime_ctx(assembled)
        assert ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_body_context_empty_dict_still_injects_user_id(self):
        """body.context={} (falsy) path: inject must still produce user_id."""
        assembled = _assemble_config(
            body_config={"recursion_limit": 1000},
            body_context={},
            request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
        )
        ctx = self._runtime_ctx(assembled)
        assert ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_body_config_already_contains_context_field(self):
        """body.config={'context': {...}} (LG 0.6 alt wire): inject still wins."""
        assembled = _assemble_config(
            body_config={"context": {"agent_name": "myagent"}, "recursion_limit": 1000},
            body_context=None,
            request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
        )
        ctx = self._runtime_ctx(assembled)
        assert ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_client_supplied_user_id_is_overridden(self):
        """Spoofed client user_id must be overwritten by inject (auth-trusted source)."""
        assembled = _assemble_config(
            body_config={"recursion_limit": 1000},
            body_context={"agent_name": "myagent", "user_id": "spoofed"},
            request_user_id="11111111-2222-3333-4444-555555555555",
        )
        ctx = self._runtime_ctx(assembled)
        assert ctx["user_id"] == "11111111-2222-3333-4444-555555555555"

    def test_unauthenticated_request_does_not_inject(self):
        """Without ``request.state.user`` no user_id may be written.

        This cannot happen under fail-closed auth, but is verified defensively:
        the runtime context must lack ``user_id`` so that the tool's fallback
        path reveals itself.
        """
        assembled = _assemble_config(
            body_config={"recursion_limit": 1000},
            body_context={"agent_name": "myagent"},
            request_user_id=None,
        )
        ctx = self._runtime_ctx(assembled)
        assert "user_id" not in ctx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L4-L7: Real LangGraph create_agent driving the real setup_agent tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_real_bootstrap_graph(authenticated_user_id: str):
    """Build a ``create_agent`` graph wired to the real ``setup_agent`` tool.

    Only the LLM is faked (a scripted ``FakeToolCallingModel``), so no API key
    is required. *authenticated_user_id* is used solely in the text of the
    model's final scripted answer.
    """
    from langchain.agents import create_agent

    from deerflow.tools.builtins.setup_agent_tool import setup_agent

    setup_call = {
        "name": "setup_agent",
        "args": {
            "soul": "# My E2E Agent\n\nA SOUL written by the model.",
            "description": "End-to-end test agent",
        },
        "id": "call_setup_1",
        "type": "tool_call",
    }
    scripted_turns = [
        # Turn 1: the model asks for setup_agent to be called.
        AIMessage(content="", tool_calls=[setup_call]),
        # Turn 2 (after the tool result): a final answer that ends the loop.
        AIMessage(content=f"Done. Agent created for user {authenticated_user_id}."),
    ]

    return create_agent(
        model=FakeToolCallingModel(responses=scripted_turns),
        tools=[setup_agent],
        system_prompt="You are a bootstrap agent. Call setup_agent immediately.",
    )
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_real_graph_real_setup_agent_writes_to_authenticated_user_dir(tmp_path: Path):
    """The smoking-gun test for issue #2862.

    Under no_auto_user (contextvar = empty), a break in user_id propagation
    through runtime.context makes setup_agent fall back to DEFAULT_USER_ID and
    write to users/default/agents/... Asserting that this directory does NOT
    exist is what makes the test load-bearing.
    """
    from langgraph.runtime import Runtime

    auth_uid = "abcdef01-2345-6789-abcd-ef0123456789"
    thread_id = "thread-e2e-1"
    config = _assemble_config(
        body_config={"recursion_limit": 50},
        body_context={"agent_name": "e2e-agent", "is_bootstrap": True},
        request_user_id=auth_uid,
        thread_id=thread_id,
    )

    # Replay worker.run_agent's runtime construction: this is the step that
    # puts user_id into ToolRuntime.context by the time the tool fires.
    runtime_ctx = _build_runtime_context(thread_id, "run-1", config.get("context"), None)
    _install_runtime_context(config, runtime_ctx)
    config.setdefault("configurable", {})["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    graph = _build_real_bootstrap_graph(auth_uid)
    user_input = {"messages": [HumanMessage(content="Create an agent named e2e-agent")]}

    # Only the file-system rooting (get_paths) is patched; the graph run goes
    # through the real ToolNode and the real Runtime merge.
    with patch(
        "deerflow.tools.builtins.setup_agent_tool.get_paths",
        return_value=_make_paths_mock(tmp_path),
    ):
        final_state = await graph.ainvoke(user_input, config=config)

    users_root = tmp_path / "users"
    expected_dir = users_root / auth_uid / "agents" / "e2e-agent"
    default_dir = users_root / "default" / "agents" / "e2e-agent"

    # Load-bearing assertions:
    assert expected_dir.exists(), f"Agent directory not found at the authenticated user's path. Expected: {expected_dir}. tmp_path tree: {[str(p) for p in tmp_path.rglob('*')]}"
    assert (expected_dir / "SOUL.md").read_text() == "# My E2E Agent\n\nA SOUL written by the model."
    assert (expected_dir / "config.yaml").exists()
    assert not default_dir.exists(), "REGRESSION: agent landed under users/default/. user_id propagation broke somewhere between HTTP layer and ToolRuntime.context."

    # The last message is the model's scripted final answer.
    final_message = final_state["messages"][-1]
    assert "Done" in str(final_message.content)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_inject_failure_falls_back_to_default_proving_test_is_load_bearing(tmp_path: Path):
    """Negative control for the positive propagation test.

    With no user on the request (inject does nothing) and an empty contextvar
    (no_auto_user), setup_agent must land in default/. This shows the positive
    test is load-bearing, i.e. it would have failed before PR #2784 rather
    than passing by accident.
    """
    from langgraph.runtime import Runtime

    thread_id = "thread-e2e-2"
    config = _assemble_config(
        body_config={"recursion_limit": 50},
        body_context={"agent_name": "fallback-agent", "is_bootstrap": True},
        request_user_id=None,  # no auth — inject is a no-op
        thread_id=thread_id,
    )

    runtime_ctx = _build_runtime_context(thread_id, "run-2", config.get("context"), None)
    _install_runtime_context(config, runtime_ctx)
    config.setdefault("configurable", {})["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    # The user id only affects the scripted final answer text, not the outcome.
    graph = _build_real_bootstrap_graph("does-not-matter")
    user_input = {"messages": [HumanMessage(content="Create fallback-agent")]}

    with patch(
        "deerflow.tools.builtins.setup_agent_tool.get_paths",
        return_value=_make_paths_mock(tmp_path),
    ):
        await graph.ainvoke(user_input, config=config)

    default_dir = tmp_path / "users" / "default" / "agents" / "fallback-agent"
    assert default_dir.exists(), "Negative control failed: even without inject + contextvar, agent did not land in default/. The test infrastructure may not be reproducing the bug condition."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L5: Sub-graph runtime propagation (the task tool case)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
@pytest.mark.asyncio
async def test_subgraph_invocation_preserves_user_id_in_runtime(tmp_path: Path):
    """user_id must survive a separate graph invocation sharing the parent config.

    This mimics the subagent pattern: a distinct graph containing setup_agent
    is invoked with a config that carries the parent's runtime. If the Runtime
    were re-created without user_id at that boundary, the agent would land
    under users/default/ and this test would fail.
    """
    from langchain.agents import create_agent
    from langgraph.runtime import Runtime

    from deerflow.tools.builtins.setup_agent_tool import setup_agent

    auth_uid = "deadbeef-0000-1111-2222-333344445555"
    thread_id = "thread-e2e-3"

    # Inner graph: same shape as the bootstrap flow.
    inner_tool_call = {
        "name": "setup_agent",
        "args": {"soul": "# Inner", "description": "subgraph"},
        "id": "call_inner_1",
        "type": "tool_call",
    }
    inner_model = FakeToolCallingModel(
        responses=[
            AIMessage(content="", tool_calls=[inner_tool_call]),
            AIMessage(content="inner done"),
        ]
    )
    inner_graph = create_agent(model=inner_model, tools=[setup_agent], system_prompt="inner")

    config = _assemble_config(
        body_config={"recursion_limit": 50},
        body_context={"agent_name": "subgraph-agent", "is_bootstrap": True},
        request_user_id=auth_uid,
        thread_id=thread_id,
    )
    runtime_ctx = _build_runtime_context(thread_id, "run-3", config.get("context"), None)
    _install_runtime_context(config, runtime_ctx)
    config.setdefault("configurable", {})["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    user_input = {"messages": [HumanMessage(content="Create subgraph-agent")]}
    with patch(
        "deerflow.tools.builtins.setup_agent_tool.get_paths",
        return_value=_make_paths_mock(tmp_path),
    ):
        # Direct sub-graph invoke: a distinct ainvoke call, but the parent
        # config carries the same runtime.
        await inner_graph.ainvoke(user_input, config=config)

    users_root = tmp_path / "users"
    assert (users_root / auth_uid / "agents" / "subgraph-agent").exists()
    assert not (users_root / "default" / "agents" / "subgraph-agent").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L6: Sync tool path through ContextThreadPoolExecutor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sync_tool_dispatch_through_thread_pool_uses_runtime_context(tmp_path: Path):
    """The synchronous ``graph.invoke`` path must also see user_id.

    setup_agent is a sync function. When it is dispatched through ToolNode's
    ContextThreadPoolExecutor, runtime.context must still carry user_id
    because it is handed to the tool as part of ToolRuntime, not because of a
    thread-local copy_context (which only carries contextvars).
    """
    from langchain.agents import create_agent
    from langgraph.runtime import Runtime

    from deerflow.tools.builtins.setup_agent_tool import setup_agent

    auth_uid = "11112222-3333-4444-5555-666677778888"
    thread_id = "thread-e2e-4"

    sync_tool_call = {
        "name": "setup_agent",
        "args": {"soul": "# Sync", "description": "sync path"},
        "id": "call_sync_1",
        "type": "tool_call",
    }
    fake_model = FakeToolCallingModel(
        responses=[
            AIMessage(content="", tool_calls=[sync_tool_call]),
            AIMessage(content="sync done"),
        ]
    )
    graph = create_agent(model=fake_model, tools=[setup_agent], system_prompt="sync")

    config = _assemble_config(
        body_config={"recursion_limit": 50},
        body_context={"agent_name": "sync-agent", "is_bootstrap": True},
        request_user_id=auth_uid,
        thread_id=thread_id,
    )
    runtime_ctx = _build_runtime_context(thread_id, "run-4", config.get("context"), None)
    _install_runtime_context(config, runtime_ctx)
    config.setdefault("configurable", {})["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    user_input = {"messages": [HumanMessage(content="Create sync-agent")]}
    with patch(
        "deerflow.tools.builtins.setup_agent_tool.get_paths",
        return_value=_make_paths_mock(tmp_path),
    ):
        # Use SYNC invoke to hit the ContextThreadPoolExecutor path.
        graph.invoke(user_input, config=config)

    users_root = tmp_path / "users"
    assert (users_root / auth_uid / "agents" / "sync-agent").exists()
    assert not (users_root / "default" / "agents" / "sync-agent").exists()
|
||||
@@ -0,0 +1,326 @@
|
||||
"""Real HTTP end-to-end verification for issue #2862's setup_agent path.
|
||||
|
||||
This test drives the **entire** FastAPI gateway through ``starlette.testclient.TestClient``:
|
||||
|
||||
starlette.testclient.TestClient (real ASGI stack)
|
||||
-> AuthMiddleware (real cookie parsing, real JWT decode)
|
||||
-> /api/v1/auth/register endpoint (real password hash + sqlite write)
|
||||
-> /api/threads/{id}/runs/stream endpoint (real start_run config-assembly)
|
||||
-> background asyncio.create_task(run_agent) (real worker, real Runtime)
|
||||
-> langchain.agents.create_agent graph (real, with fake LLM)
|
||||
-> ToolNode dispatch (real)
|
||||
-> setup_agent tool (real file I/O)
|
||||
|
||||
The only mock is the LLM (no API key needed). Every layer that participates
|
||||
in ``user_id`` propagation — auth, ContextVar, ``inject_authenticated_user_context``,
|
||||
``worker._build_runtime_context``, ``Runtime.merge`` — is the real production
|
||||
code path. If the chain is broken at any layer, this test fails.
|
||||
|
||||
This is what "real verification" looks like for a server that lives behind authentication:
|
||||
register a user, log in (cookie), POST to /runs/stream, wait for the run to
|
||||
finish, then read the filesystem.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model
|
||||
|
||||
|
||||
def _build_fake_create_chat_model(agent_name: str):
    """Return a drop-in replacement for ``create_chat_model``.

    The replacement accepts and ignores any arguments. Each call returns a
    fresh fake model that emits one ``setup_agent`` tool_call on its first
    turn and a benign final answer on its second.
    """

    def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel:
        soul = f"# Real HTTP E2E SOUL for {agent_name}"
        closing_text = f"Agent {agent_name} created via real HTTP e2e."
        return build_single_tool_call_model(
            tool_name="setup_agent",
            tool_args={"soul": soul, "description": "real-http-e2e agent"},
            tool_call_id="call_real_http_1",
            final_text=closing_text,
        )

    return fake_create_chat_model
|
||||
|
||||
|
||||
@pytest.fixture
def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Stand up an isolated DeerFlow data root + config under tmp_path.

    - Creates ``tmp_path/deer-flow-home`` and points ``DEER_FLOW_HOME`` at it,
      so paths land under tmp_path, not the real ``.deer-flow`` directory.
    - Sets placeholder ``OPENAI_API_KEY`` / ``OPENAI_API_BASE`` values: the
      config references ``$OPENAI_API_KEY``, which gets resolved at parse
      time; the LLM itself is mocked, so any non-empty value works.
    - Writes ``_MINIMAL_CONFIG_YAML`` to ``tmp_path/config.yaml`` and pins
      ``DEER_FLOW_CONFIG_PATH`` to it.

    Returns the home directory.
    """
    home = tmp_path / "deer-flow-home"
    home.mkdir()

    env_overrides = {
        "DEER_FLOW_HOME": str(home),
        "OPENAI_API_KEY": "sk-fake-key-not-used-because-llm-is-mocked",
        "OPENAI_API_BASE": "https://example.invalid",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)

    # Hermetic config: do not depend on whether the dev machine has a real
    # ``config.yaml`` at the repo root. CI's ``actions/checkout`` only ships
    # ``config.example.yaml`` (and its ``models:`` list is commented out, so
    # AppConfig validation would reject it). A minimal, self-sufficient config
    # is written to tmp_path instead.
    config_file = tmp_path / "config.yaml"
    config_file.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8")
    monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))

    return home
|
||||
|
||||
|
||||
# Minimal config that satisfies AppConfig + LeadAgent's _resolve_model_name.
# The model `use` path must resolve to a real class for config parsing to
# succeed; the test patches ``create_chat_model`` on the lead agent module,
# so the model is never actually instantiated. SandboxConfig.use is required
# at schema level; LocalSandboxProvider is the only sandbox that runs without
# Docker.
#
# NOTE: the nesting below is significant YAML structure. ``models`` is a list
# with one mapping entry, and ``sandbox`` / ``agents_api`` / ``database`` are
# mappings; flattening the indentation would turn the nested keys into
# duplicate top-level keys.
_MINIMAL_CONFIG_YAML = """\
log_level: info
models:
  - name: fake-test-model
    display_name: Fake Test Model
    use: langchain_openai:ChatOpenAI
    model: gpt-4o-mini
    api_key: $OPENAI_API_KEY
    base_url: $OPENAI_API_BASE
sandbox:
  use: deerflow.sandbox.local:LocalSandboxProvider
agents_api:
  enabled: true
database:
  backend: sqlite
"""
|
||||
|
||||
|
||||
def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None:
    """Reset every process-wide cache that would survive across tests.

    The calling fixture stands up a full FastAPI app + sqlite DB + LangGraph
    runtime inside ``tmp_path``. For per-test isolation, a handful of
    module-level caches that production normally never resets are set to
    ``None`` so they pick up the test-only ``DEER_FLOW_HOME`` and sqlite path:

    - ``deerflow.config.app_config``: the parsed ``config.yaml`` plus its
      recorded path and mtime.
    - ``deerflow.config.paths``: the ``Paths`` singleton derived from
      ``DEER_FLOW_HOME`` at first access.
    - ``deerflow.persistence.engine``: the SQLAlchemy engine and session
      factory.

    ``raising=False`` keeps this helper from failing with a confusing
    AttributeError if upstream renames or drops one of these attributes; the
    next test that calls ``get_app_config()``/``get_paths()`` is expected to
    surface any real incompatibility.
    """
    from deerflow.config import app_config as app_config_module
    from deerflow.config import paths as paths_module
    from deerflow.persistence import engine as engine_module

    cached_attrs_by_module = {
        app_config_module: ("_app_config", "_app_config_path", "_app_config_mtime"),
        paths_module: ("_paths_singleton",),
        engine_module: ("_engine", "_session_factory"),
    }
    for module, attr_names in cached_attrs_by_module.items():
        for attr_name in attr_names:
            monkeypatch.setattr(module, attr_name, None, raising=False)
|
||||
|
||||
|
||||
@pytest.fixture
def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch):
    """Build a fresh FastAPI app inside a clean DEER_FLOW_HOME.

    Process-wide caches are reset first, then the freshly loaded config's
    sqlite directory is pointed at ``<home>/db`` so that each test gets its
    own database under ``tmp_path`` with no cross-test contamination.
    """
    _reset_process_singletons(monkeypatch)

    from deerflow.config import app_config as app_config_module

    # Re-resolve the config from the test-only environment and pin its sqlite
    # path into tmp_path so the lifespan-time engine init lands there.
    fresh_config = app_config_module.get_app_config()
    fresh_config.database.sqlite_dir = str(isolated_deer_flow_home / "db")

    # Imported only after the config has been redirected, as in the original.
    from app.gateway.app import create_app

    app = create_app()
    return app
|
||||
|
||||
|
||||
def _drain_stream(response, *, timeout: float = 30.0, max_bytes: int = 4 * 1024 * 1024) -> str:
    """Consume an SSE response body until the run terminates and return the text.

    Reading stops at the first of these conditions, checked after each chunk:

    - An ``event: end`` SSE frame has been observed (the gateway sends this
      when the background run finishes — see ``services.format_sse`` and
      ``StreamBridge.publish_end``).
    - At least ``max_bytes`` have been read, so a runaway producer can't OOM
      the test process.
    - ``timeout`` seconds of wall-clock time have elapsed, so a runaway
      heartbeat loop cannot keep the test reading forever.

    The deadline is only evaluated when a chunk arrives: it bounds a stream
    that keeps producing data, but cannot interrupt an ``iter_bytes()`` call
    that blocks without yielding. Hitting a limit does not raise; whatever was
    read so far is returned and the caller's assertions decide the outcome.

    Args:
        response: Object exposing ``iter_bytes()`` that yields ``bytes`` chunks.
        timeout: Wall-clock budget in seconds.
        max_bytes: Soft cap on the number of bytes accumulated.

    Returns:
        The accumulated body decoded as UTF-8, with undecodable bytes replaced.
    """
    import time as _time

    end_marker = b"event: end"
    deadline = _time.monotonic() + timeout
    buffer = bytearray()
    for chunk in response.iter_bytes():
        # Earlier bytes were already searched, so only rescan the tail that
        # could hold a marker straddling the chunk boundary. This keeps the
        # scan linear in the body size instead of quadratic.
        scan_from = max(0, len(buffer) - len(end_marker) + 1)
        buffer += chunk
        if buffer.find(end_marker, scan_from) != -1:
            break
        if len(buffer) >= max_bytes:
            break
        if _time.monotonic() >= deadline:
            break
    return buffer.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _wait_for_file(path: Path, *, timeout: float = 10.0) -> bool:
    """Block until *path* exists or *timeout* elapses.

    The run completes inside ``asyncio.create_task`` after start_run returns,
    so the test must wait for the background task to flush its writes.

    The path is always checked at least once, even with a zero or negative
    *timeout*, and once more after the final sleep, so a file that appears
    right at the deadline is still reported as present.

    Args:
        path: File or directory to wait for.
        timeout: Maximum time to wait, in seconds.

    Returns:
        ``True`` if *path* was seen to exist, ``False`` otherwise.
    """
    import time as _time

    deadline = _time.monotonic() + timeout
    while True:
        if path.exists():
            return True
        # The deadline is checked only after the existence test so that the
        # last poll is never skipped.
        if _time.monotonic() >= deadline:
            return False
        _time.sleep(0.05)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_real_http_create_agent_lands_in_authenticated_user_dir(
    isolated_app: Any,
    isolated_deer_flow_home: Path,
    monkeypatch: pytest.MonkeyPatch,
):
    """The full real-server contract test.

    1. Register a real user via POST /api/v1/auth/register (also auto-logs in).
    2. Create a thread, then POST to /api/threads/{tid}/runs/stream with the
       **exact** body shape the frontend (LangGraph SDK) sends during the
       bootstrap flow.
    3. Wait for the background run to finish.
    4. Assert SOUL.md exists under users/<authenticated_uid>/agents/<name>/.
    5. Assert NOTHING exists under users/default/agents/<name>/.
    """
    import uuid as _uuid

    from starlette.testclient import TestClient

    agent_name = "real-http-agent"

    # ``deerflow.agents.lead_agent.agent`` imports ``create_chat_model`` with
    # ``from deerflow.models import create_chat_model`` at module load time,
    # rebinding the symbol into its own namespace. So the only patch that
    # intercepts the call is the bound name on ``lead_agent.agent`` — patching
    # ``deerflow.models.create_chat_model`` would be too late.
    model_patch = patch(
        "deerflow.agents.lead_agent.agent.create_chat_model",
        new=_build_fake_create_chat_model(agent_name),
    )

    with model_patch, TestClient(isolated_app) as client:
        # --- 1. Register & auto-login ---
        register = client.post(
            "/api/v1/auth/register",
            json={"email": "e2e-user@example.com", "password": "very-strong-password-123"},
        )
        assert register.status_code == 201, register.text
        auth_uid = register.json()["id"]

        # The endpoint sets both access_token (auth) and csrf_token (CSRF Double
        # Submit Cookie) cookies; the TestClient cookie jar propagates them.
        assert client.cookies.get("access_token"), "register endpoint must set session cookie"
        csrf_token = client.cookies.get("csrf_token")
        assert csrf_token, "register endpoint must set csrf_token cookie"
        csrf_headers = {"X-CSRF-Token": csrf_token}

        # --- 2. Create a thread (require_existing=True on /runs/stream means
        # we must call POST /api/threads first; the React frontend does the
        # same via the LangGraph SDK's threads.create) ---
        thread_id = str(_uuid.uuid4())
        created = client.post(
            "/api/threads",
            json={"thread_id": thread_id, "metadata": {}},
            headers=csrf_headers,
        )
        assert created.status_code == 200, created.text

        # --- 3. POST /runs/stream with the bootstrap wire format ---
        # This is the EXACT shape the React frontend sends after PR #2784:
        #   thread.submit(input, {config, context}) ->
        #   POST /api/threads/{id}/runs/stream body =
        #   {assistant_id, input, config, context}
        user_message = {
            "role": "user",
            "content": (f"The new custom agent name is {agent_name}. Help me design its SOUL.md before saving it."),
        }
        body = {
            "assistant_id": "lead_agent",
            "input": {"messages": [user_message]},
            "config": {"recursion_limit": 50},
            "context": {
                "agent_name": agent_name,
                "is_bootstrap": True,
                "mode": "flash",
                "thinking_enabled": False,
                "is_plan_mode": False,
                "subagent_enabled": False,
            },
            "stream_mode": ["values"],
        }

        # The /stream endpoint returns SSE; draining it lets the server-side
        # background task (run_agent) finish before the disk is inspected.
        with client.stream(
            "POST",
            f"/api/threads/{thread_id}/runs/stream",
            json=body,
            headers=csrf_headers,
        ) as resp:
            assert resp.status_code == 200, resp.read().decode()
            transcript = _drain_stream(resp)

        # Sanity: the stream should have produced at least one event
        assert "event:" in transcript, f"no SSE events in response: {transcript[:500]!r}"

        # --- 4. Verify filesystem outcome ---
        users_root = isolated_deer_flow_home / "users"
        expected_dir = users_root / auth_uid / "agents" / agent_name
        default_dir = users_root / "default" / "agents" / agent_name

        # The setup_agent tool runs inside the background asyncio task spawned
        # by start_run; SSE-drain typically waits for it, but a bounded poll
        # makes the check robust against scheduler jitter.
        assert _wait_for_file(expected_dir / "SOUL.md", timeout=15.0), (
            "SOUL.md did not appear under users/<auth_uid>/agents/. "
            f"Expected: {expected_dir / 'SOUL.md'}. "
            f"tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}. "
            f"SSE transcript tail: {transcript[-1000:]!r}"
        )

        soul_text = (expected_dir / "SOUL.md").read_text()
        assert agent_name in soul_text, f"unexpected SOUL content: {soul_text!r}"

        # The smoking-gun assertion: the agent must NOT have landed in default/
        assert not default_dir.exists(), f"REGRESSION: agent landed under users/default/{agent_name} instead of the authenticated user. Default-dir contents: {list(default_dir.rglob('*')) if default_dir.exists() else 'n/a'}"
|
||||
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||
def _runtime(
|
||||
thread_id: str | None = "thread-1",
|
||||
agent_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> SimpleNamespace:
|
||||
context = {}
|
||||
if thread_id is not None:
|
||||
context["thread_id"] = thread_id
|
||||
if agent_name is not None:
|
||||
context["agent_name"] = agent_name
|
||||
if user_id is not None:
|
||||
context["user_id"] = user_id
|
||||
return SimpleNamespace(context=context)
|
||||
|
||||
|
||||
@@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||
|
||||
|
||||
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """The ``user_id`` held in the runtime context is forwarded to the memory queue."""
    memory_queue = MagicMock()
    hook_module = "deerflow.agents.memory.summarization_hook"
    monkeypatch.setattr(f"{hook_module}.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr(f"{hook_module}.get_memory_queue", lambda: memory_queue)

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id="main",
        agent_name="researcher",
        runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
    )
    memory_flush_hook(event)

    memory_queue.add_nowait.assert_called_once()
    forwarded = memory_queue.add_nowait.call_args.kwargs
    assert forwarded["user_id"] == "alice"
|
||||
|
||||
@@ -59,12 +59,15 @@ def _make_result(
|
||||
ai_messages: list[dict] | None = None,
|
||||
result: str | None = None,
|
||||
error: str | None = None,
|
||||
token_usage_records: list[dict] | None = None,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
status=status,
|
||||
ai_messages=ai_messages or [],
|
||||
result=result,
|
||||
error=error,
|
||||
token_usage_records=token_usage_records or [],
|
||||
usage_reported=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
|
||||
assert len(report_calls) == 1
|
||||
assert report_calls[0][1] is cancel_result
|
||||
assert cleanup_calls == ["tc-cancel-report"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "status, expected_type",
    [
        (FakeSubagentStatus.COMPLETED, "task_completed"),
        (FakeSubagentStatus.FAILED, "task_failed"),
        (FakeSubagentStatus.CANCELLED, "task_cancelled"),
        (FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
    ],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
    """Every terminal task event carries usage totals summed over the usage records."""
    subagent_config = _make_subagent_config()
    runtime = _make_runtime()
    emitted = []

    usage_records = [
        {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
        {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
    ]
    # Only the COMPLETED status gets a result payload; every other status gets an error.
    succeeded = status == FakeSubagentStatus.COMPLETED
    outcome = _make_result(
        status,
        result="ok" if succeeded else None,
        error=None if succeeded else "err",
        token_usage_records=usage_records,
    )

    module_patches = {
        "SubagentStatus": FakeSubagentStatus,
        "get_subagent_config": lambda _: subagent_config,
        "get_background_task_result": lambda _: outcome,
        "get_stream_writer": lambda: emitted.append,
        "_report_subagent_usage": lambda *_: None,
        "cleanup_background_task": lambda _: None,
    }
    for attr_name, replacement in module_patches.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    _run_task_tool(
        runtime=runtime,
        description="test",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-usage",
    )

    matching = [event for event in emitted if event["type"] == expected_type]
    assert len(matching) == 1
    assert matching[0]["usage"] == {
        "input_tokens": 300,
        "output_tokens": 130,
        "total_tokens": 430,
    }
|
||||
|
||||
|
||||
def test_terminal_event_usage_none_when_no_records(monkeypatch):
    """With an empty ``token_usage_records`` list the terminal event reports ``usage=None``."""
    subagent_config = _make_subagent_config()
    runtime = _make_runtime()
    emitted = []

    outcome = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])

    module_patches = {
        "SubagentStatus": FakeSubagentStatus,
        "get_subagent_config": lambda _: subagent_config,
        "get_background_task_result": lambda _: outcome,
        "get_stream_writer": lambda: emitted.append,
        "_report_subagent_usage": lambda *_: None,
        "cleanup_background_task": lambda _: None,
    }
    for attr_name, replacement in module_patches.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    _run_task_tool(
        runtime=runtime,
        description="test",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-no-records",
    )

    completed_events = [event for event in emitted if event["type"] == "task_completed"]
    assert len(completed_events) == 1
    assert completed_events[0]["usage"] is None
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
    """``_token_usage_cache_enabled(None)`` is False when ``get_app_config`` raises FileNotFoundError."""
    failing_loader = MagicMock(side_effect=FileNotFoundError("missing config"))
    monkeypatch.setattr(task_tool_module, "get_app_config", failing_loader)

    enabled = task_tool_module._token_usage_cache_enabled(None)
    assert enabled is False
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
    """No usage entry is cached for the tool call when ``token_usage.enabled`` is False."""

    class DummyExecutor:
        """Stand-in executor that accepts any kwargs and echoes the task id back."""

        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    subagent_config = _make_subagent_config()
    disabled_app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
    runtime = _make_runtime(app_config=disabled_app_config)
    usage_records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
    outcome = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=usage_records)

    # Start from an empty cache so the final lookup reflects only this run.
    task_tool_module._subagent_usage_cache.clear()
    module_patches = {
        "SubagentStatus": FakeSubagentStatus,
        "get_available_subagent_names": lambda *, app_config: ["general-purpose"],
        "get_subagent_config": lambda _, *, app_config: subagent_config,
        "SubagentExecutor": DummyExecutor,
        "get_background_task_result": lambda _: outcome,
        "get_stream_writer": lambda: lambda _: None,
        "_report_subagent_usage": lambda *_: None,
        "cleanup_background_task": lambda _: None,
    }
    for attr_name, replacement in module_patches.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    _run_task_tool(
        runtime=runtime,
        description="test",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-disabled-cache",
    )

    cached = task_tool_module.pop_cached_subagent_usage("tc-disabled-cache")
    assert cached is None
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
    """A pre-existing cache entry for the tool call is gone after result polling raises."""

    class DummyExecutor:
        """Stand-in executor that accepts any kwargs and echoes the task id back."""

        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    subagent_config = _make_subagent_config()
    enabled_app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
    runtime = _make_runtime(app_config=enabled_app_config)

    # Seed an entry under the same tool_call_id so the test can observe its removal.
    task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
    module_patches = {
        "SubagentStatus": FakeSubagentStatus,
        "get_available_subagent_names": lambda *, app_config: ["general-purpose"],
        "get_subagent_config": lambda _, *, app_config: subagent_config,
        "SubagentExecutor": DummyExecutor,
        "get_background_task_result": MagicMock(side_effect=RuntimeError("poll failed")),
        "get_stream_writer": lambda: lambda _: None,
    }
    for attr_name, replacement in module_patches.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    with pytest.raises(RuntimeError, match="poll failed"):
        _run_task_tool(
            runtime=runtime,
            description="test",
            prompt="do work",
            subagent_type="general-purpose",
            tool_call_id="tc-error",
        )

    leftover = task_tool_module.pop_cached_subagent_usage("tc-error")
    assert leftover is None
|
||||
|
||||
@@ -1,28 +1,25 @@
|
||||
"""Tests for ThreadMetaRepository (SQLAlchemy-backed)."""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
@pytest.fixture
async def repo(tmp_path):
    """Yield a ``ThreadMetaRepository`` on a SQLite database under ``tmp_path``.

    The engine is initialised before the repository is handed to the test and
    closed again once the test has finished with it.
    """
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

    db_file = tmp_path / "test.db"
    await init_engine("sqlite", url=f"sqlite+aiosqlite:///{db_file}", sqlite_dir=str(tmp_path))
    repository = ThreadMetaRepository(get_session_factory())
    yield repository
    await close_engine()
|
||||
|
||||
|
||||
class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_and_get(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_create_and_get(self, repo):
|
||||
record = await repo.create("t1")
|
||||
assert record["thread_id"] == "t1"
|
||||
assert record["status"] == "idle"
|
||||
@@ -31,148 +28,523 @@ class TestThreadMetaRepository:
|
||||
fetched = await repo.get("t1")
|
||||
assert fetched is not None
|
||||
assert fetched["thread_id"] == "t1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_assistant_id(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_create_with_assistant_id(self, repo):
    """The ``assistant_id`` given to ``create`` is present on the returned record."""
    created = await repo.create("t1", assistant_id="agent1")
    assert created["assistant_id"] == "agent1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner_and_display_name(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_create_with_owner_and_display_name(self, repo):
    """``user_id`` and ``display_name`` given to ``create`` are present on the returned record."""
    created = await repo.create("t1", user_id="user1", display_name="My Thread")
    expected_fields = (("user_id", "user1"), ("display_name", "My Thread"))
    for field, expected in expected_fields:
        assert created[field] == expected
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_metadata(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_create_with_metadata(self, repo):
    """Metadata given to ``create`` is present unchanged on the returned record."""
    created = await repo.create("t1", metadata={"key": "value"})
    stored = created["metadata"]
    assert stored == {"key": "value"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_nonexistent(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_get_nonexistent(self, repo):
    """``get`` returns ``None`` for a thread id that was never created."""
    fetched = await repo.get("nonexistent")
    assert fetched is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_record_allows(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_check_access_no_record_allows(self, repo):
    """By default, ``check_access`` returns True for a thread that has no row."""
    allowed = await repo.check_access("unknown", "user1")
    assert allowed is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_matches(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_check_access_owner_matches(self, repo):
    """The user recorded as owner of a thread is granted access to it."""
    await repo.create("t1", user_id="user1")
    allowed = await repo.check_access("t1", "user1")
    assert allowed is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_mismatch(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_check_access_owner_mismatch(self, repo):
    """A user other than the recorded owner is refused access."""
    await repo.create("t1", user_id="user1")
    allowed = await repo.check_access("t1", "user2")
    assert allowed is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_check_access_no_owner_allows_all(self, repo):
    """A row created with ``user_id=None`` is accessible to any caller."""
    # user_id=None is passed explicitly: relying on the AUTO default would
    # instead pick up the test user installed by the autouse fixture.
    await repo.create("t1", user_id=None)
    allowed = await repo.check_access("t1", "anyone")
    assert allowed is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_missing_row_denied(self, tmp_path):
|
||||
async def test_check_access_strict_missing_row_denied(self, repo):
    """With ``require_existing=True`` a thread that has no row is *denied*.

    This closes the delete-idempotence cross-user gap: once a thread is
    deleted its row is gone, and the permissive default would let any caller
    "claim" it as untracked. Strict mode insists that a row exists.
    """
    allowed = await repo.check_access("never-existed", "user1", require_existing=True)
    assert allowed is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_check_access_strict_owner_match_allowed(self, repo):
    """Strict mode still grants access to the recorded owner."""
    await repo.create("t1", user_id="user1")
    allowed = await repo.check_access("t1", "user1", require_existing=True)
    assert allowed is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_check_access_strict_owner_mismatch_denied(self, repo):
    """Strict mode refuses a user other than the recorded owner."""
    await repo.create("t1", user_id="user1")
    allowed = await repo.check_access("t1", "user2", require_existing=True)
    assert allowed is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
||||
async def test_check_access_strict_null_owner_still_allowed(self, repo):
    """A row whose ``user_id`` is NULL remains shared even in strict mode.

    The strict flag tightens only the *missing row* case, not the *shared
    row* case: legacy pre-auth rows that came through a clean migration
    without an owner still belong to everyone.
    """
    await repo.create("t1", user_id=None)
    allowed = await repo.check_access("t1", "anyone", require_existing=True)
    assert allowed is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_update_status(self, repo):
    """A status written by ``update_status`` is visible on the next ``get``."""
    await repo.create("t1")
    await repo.update_status("t1", "busy")
    refreshed = await repo.get("t1")
    assert refreshed["status"] == "busy"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_delete(self, repo):
    """After ``delete`` the thread can no longer be fetched."""
    await repo.create("t1")
    await repo.delete("t1")
    remaining = await repo.get("t1")
    assert remaining is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_delete_nonexistent_is_noop(self, repo):
    """Deleting an unknown thread id completes without raising."""
    unknown_id = "nonexistent"
    await repo.delete(unknown_id)  # should not raise
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_merges(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_update_metadata_merges(self, repo):
    """``update_metadata`` merges the new keys into the stored metadata."""
    await repo.create("t1", metadata={"a": 1, "b": 2})
    await repo.update_metadata("t1", {"b": 99, "c": 3})
    refreshed = await repo.get("t1")
    # "a" is untouched, the overlapping "b" is overwritten, and "c" is new.
    assert refreshed["metadata"] == {"a": 1, "b": 99, "c": 3}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_on_empty(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_update_metadata_on_empty(self, repo):
    """``update_metadata`` works on a thread that was created without metadata."""
    await repo.create("t1")
    await repo.update_metadata("t1", {"k": "v"})
    refreshed = await repo.get("t1")
    assert refreshed["metadata"] == {"k": "v"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
async def test_update_metadata_nonexistent_is_noop(self, repo):
    """Updating metadata for an unknown thread id completes without raising."""
    unknown_id = "nonexistent"
    await repo.update_metadata(unknown_id, {"k": "v"})  # should not raise
|
||||
await _cleanup()
|
||||
|
||||
# --- search with metadata filter (SQL push-down) ---
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_filter_string(self, repo):
    """A string-valued metadata filter returns exactly the rows holding that value."""
    seeded = {
        "t1": {"env": "prod"},
        "t2": {"env": "staging"},
        "t3": {"env": "prod", "region": "us"},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"env": "prod"})
    assert {row["thread_id"] for row in matched} == {"t1", "t3"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_filter_numeric(self, repo):
    """An integer-valued metadata filter returns exactly the rows holding that value."""
    seeded = {
        "t1": {"priority": 1},
        "t2": {"priority": 2},
        "t3": {"priority": 1, "extra": "x"},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"priority": 1})
    assert {row["thread_id"] for row in matched} == {"t1", "t3"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_filter_multiple_keys(self, repo):
    """Several filter keys combine, so a row must satisfy all of them to match."""
    seeded = {
        "t1": {"env": "prod", "region": "us"},
        "t2": {"env": "prod", "region": "eu"},
        "t3": {"env": "staging", "region": "us"},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"env": "prod", "region": "us"})
    assert len(matched) == 1
    only_row = matched[0]
    assert only_row["thread_id"] == "t1"
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_no_match(self, repo):
    """A filter value that no row holds produces an empty result list."""
    await repo.create("t1", metadata={"env": "prod"})

    matched = await repo.search(metadata={"env": "dev"})
    assert matched == []
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_pagination_correct(self, repo):
    """Regression: SQL push-down makes limit/offset exact even when most rows don't match."""

    async def fetch(**paging):
        # Every query uses the same filter; only the paging arguments vary.
        return await repo.search(metadata={"target": "yes"}, **paging)

    for index in range(30):
        flag = "no" if index % 3 else "yes"
        await repo.create(f"t{index:03d}", metadata={"target": flag})

    # Matching rows are the multiples of three below 30: 0, 3, ..., 27 (10 rows).
    everything = await fetch(limit=100)
    assert len(everything) == 10

    # First page
    first_page = await fetch(limit=3, offset=0)
    assert len(first_page) == 3

    # Second page
    second_page = await fetch(limit=3, offset=3)
    assert len(second_page) == 3

    # The two pages must not share any thread
    first_ids = {row["thread_id"] for row in first_page}
    second_ids = {row["thread_id"] for row in second_page}
    assert first_ids.isdisjoint(second_ids)

    # Last page holds the single remaining row
    final_page = await fetch(limit=3, offset=9)
    assert len(final_page) == 1
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_with_status_filter(self, repo):
    """A metadata filter and a status filter can be applied in the same search."""
    for thread_id in ("t1", "t2"):
        await repo.create(thread_id, metadata={"env": "prod"})
    await repo.update_status("t1", "busy")

    matched = await repo.search(metadata={"env": "prod"}, status="busy")
    assert len(matched) == 1
    only_row = matched[0]
    assert only_row["thread_id"] == "t1"
|
||||
|
||||
@pytest.mark.anyio
async def test_search_without_metadata_still_works(self, repo):
    """Searching with no metadata filter returns rows with and without metadata."""
    await repo.create("t1", metadata={"env": "prod"})
    await repo.create("t2")

    found = await repo.search(limit=10)
    assert len(found) == 2
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_missing_key_no_match(self, repo):
    """A row whose metadata lacks the requested key is not returned."""
    seeded = {
        "t1": {"other": "val"},
        "t2": {"env": "prod"},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"env": "prod"})
    assert len(matched) == 1
    only_row = matched[0]
    assert only_row["thread_id"] == "t2"
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog):
    """If every filter key is unsafe, ``search`` raises InvalidMetadataFilterError."""
    for thread_id, env in (("t1", "prod"), ("t2", "staging")):
        await repo.create(thread_id, metadata={"env": env})

    with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"), pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info:
        await repo.search(metadata={"bad;key": "x"})

    # The rejected key is named in a captured log record.
    assert any("bad;key" in rec.message for rec in caplog.records)
    # The error is a ValueError subclass, kept for backward compatibility.
    assert isinstance(exc_info.value, ValueError)
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog):
    """The valid key still filters rows; only the invalid key is logged and dropped."""
    for thread_id, env in (("t1", "prod"), ("t2", "staging")):
        await repo.create(thread_id, metadata={"env": env})

    with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"):
        matched = await repo.search(metadata={"env": "prod", "bad;key": "x"})

    assert {row["thread_id"] for row in matched} == {"t1"}
    assert any("bad;key" in rec.message for rec in caplog.records)
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_filter_boolean(self, repo):
    """Filtering by ``True`` matches boolean true only, never the integer 1."""
    seeded = {
        "t1": {"active": True},
        "t2": {"active": False},
        "t3": {"active": True, "extra": "x"},
        "t4": {"active": 1},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"active": True})
    assert {row["thread_id"] for row in matched} == {"t1", "t3"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_filter_none(self, repo):
    """Filtering by ``None`` matches an explicit JSON null, not an absent key."""
    seeded = {
        "t1": {"tag": None},
        "t2": {"tag": "present"},
        "t3": {"other": "val"},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"tag": None})
    assert {row["thread_id"] for row in matched} == {"t1"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_non_string_key_skipped(self, repo, caplog):
    """A filter whose only key is a non-string is rejected and the key is logged.

    Because no usable key remains, ``search`` raises InvalidMetadataFilterError.
    """
    for thread_id, env in (("t1", "prod"), ("t2", "staging")):
        await repo.create(thread_id, metadata={"env": env})

    with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"), pytest.raises(InvalidMetadataFilterError, match="rejected"):
        await repo.search(metadata={1: "x"})

    assert any("1" in rec.message for rec in caplog.records)
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog):
    """A filter whose only value is a list is rejected.

    Because no usable key remains, ``search`` raises InvalidMetadataFilterError.
    """
    for thread_id, env in (("t1", "prod"), ("t2", "staging")):
        await repo.create(thread_id, metadata={"env": env})

    with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"), pytest.raises(InvalidMetadataFilterError, match="rejected"):
        await repo.search(metadata={"env": ["prod", "staging"]})
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_dotted_key_raises(self, repo, caplog):
    """A dotted key is rejected; when it is the only key, ``search`` raises."""
    for thread_id, env in (("t1", "prod"), ("t2", "staging")):
        await repo.create(thread_id, metadata={"env": env})

    with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"), pytest.raises(InvalidMetadataFilterError, match="rejected"):
        await repo.search(metadata={"a.b": "anything"})

    assert any("a.b" in rec.message for rec in caplog.records)
|
||||
|
||||
# --- dialect-aware type-safe filtering edge cases ---
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_bool_vs_int_distinction(self, repo):
    """``True`` must not match 1, and ``False`` must not match 0."""
    seeded = {
        "bool_true": True,
        "bool_false": False,
        "int_one": 1,
        "int_zero": 0,
    }
    for thread_id, flag in seeded.items():
        await repo.create(thread_id, metadata={"flag": flag})

    expectations = ((True, {"bool_true"}), (False, {"bool_false"}))
    for wanted, expected_ids in expectations:
        matched = await repo.search(metadata={"flag": wanted})
        assert {row["thread_id"] for row in matched} == expected_ids
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_int_does_not_match_bool(self, repo):
    """Filtering by the integer 1 must not return a row that stores boolean True."""
    await repo.create("bool_true", metadata={"val": True})
    await repo.create("int_one", metadata={"val": 1})

    matched = await repo.search(metadata={"val": 1})
    assert {row["thread_id"] for row in matched} == {"int_one"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_none_excludes_missing_key(self, repo):
    """A ``None`` filter matches explicit JSON null only, not a missing key or empty {}."""
    seeded = {
        "explicit_null": {"k": None},
        "missing_key": {"other": "x"},
        "empty_obj": {},
    }
    for thread_id, meta in seeded.items():
        await repo.create(thread_id, metadata=meta)

    matched = await repo.search(metadata={"k": None})
    assert {row["thread_id"] for row in matched} == {"explicit_null"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_float_value(self, repo):
    """A float-valued metadata filter returns exactly the rows holding that value."""
    for thread_id, score in (("t1", 3.14), ("t2", 2.71), ("t3", 3.14)):
        await repo.create(thread_id, metadata={"score": score})

    matched = await repo.search(metadata={"score": 3.14})
    assert {row["thread_id"] for row in matched} == {"t1", "t3"}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_mixed_types_same_key(self, repo):
    """Each query matches only rows of its own value type, even under a shared key."""
    typed_rows = (
        ("str_row", "hello"),
        ("int_row", 42),
        ("bool_row", True),
        ("null_row", None),
    )
    for thread_id, value in typed_rows:
        await repo.create(thread_id, metadata={"x": value})

    # Query each stored value in turn; only the row of the same type may come back.
    for thread_id, value in typed_rows:
        matched = await repo.search(metadata={"x": value})
        assert {row["thread_id"] for row in matched} == {thread_id}
|
||||
|
||||
@pytest.mark.anyio
async def test_search_metadata_large_int_precision(self, repo):
    """Integers beyond float precision (> 2**53) must match exactly."""
    big_value = 2**53 + 1
    neighbour = big_value - 1
    await repo.create("t1", metadata={"id": big_value})
    await repo.create("t2", metadata={"id": neighbour})

    matched = await repo.search(metadata={"id": big_value})
    assert {row["thread_id"] for row in matched} == {"t1"}
|
||||
|
||||
|
||||
class TestJsonMatchCompilation:
|
||||
"""Verify compiled SQL for both SQLite and PostgreSQL dialects."""
|
||||
|
||||
def test_json_match_compiles_sqlite(self):
    """``json_match`` compiles to the expected type-aware SQL for the SQLite dialect."""
    from sqlalchemy import Column, MetaData, String, Table, create_engine
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))
    sqlite_dialect = create_engine("sqlite://").dialect

    def render(value):
        # Compile with literal binds so the filter value appears inline in the SQL text.
        expr = json_match(table.c.data, "k", value)
        return str(expr.compile(dialect=sqlite_dialect, compile_kwargs={"literal_binds": True}))

    # null and booleans compile to a single json_type comparison
    exact_cases = [
        (None, "json_type(t.data, '$.\"k\"') = 'null'"),
        (True, "json_type(t.data, '$.\"k\"') = 'true'"),
        (False, "json_type(t.data, '$.\"k\"') = 'false'"),
    ]
    for value, expected_fragment in exact_cases:
        sql = render(value)
        assert sql == expected_fragment, f"value={value!r}: {sql}"

    # int: uses INTEGER cast for precision, type-check narrows to 'integer' only
    int_sql = render(42)
    for fragment in ("json_type", "= 'integer'", "INTEGER", "CAST"):
        assert fragment in int_sql

    # float: uses REAL cast, type-check spans 'integer' and 'real'
    float_sql = render(3.14)
    for fragment in ("json_type", "IN ('integer', 'real')", "REAL"):
        assert fragment in float_sql

    # str: the type check targets 'text'
    str_sql = render("hello")
    for fragment in ("json_type", "'text'"):
        assert fragment in str_sql
|
||||
|
||||
def test_json_match_compiles_pg(self):
    """``json_match`` compiles to the expected type-aware SQL for the PostgreSQL dialect."""
    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))
    pg_dialect = postgresql.dialect()

    def render(value):
        # Compile with literal binds so the filter value appears inline in the SQL text.
        expr = json_match(table.c.data, "k", value)
        return str(expr.compile(dialect=pg_dialect, compile_kwargs={"literal_binds": True}))

    # null and booleans have fully pinned SQL
    exact_cases = [
        (None, "json_typeof(t.data -> 'k') = 'null'"),
        (True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"),
        (False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"),
    ]
    for value, expected_fragment in exact_cases:
        sql = render(value)
        assert sql == expected_fragment, f"value={value!r}: {sql}"

    # int: CASE guard prevents CAST error when 'number' also matches floats
    int_sql = render(42)
    for fragment in ("json_typeof", "'number'", "BIGINT", "CASE WHEN", "'^-?[0-9]+$'"):
        assert fragment in int_sql

    # float: uses DOUBLE PRECISION cast
    float_sql = render(3.14)
    for fragment in ("json_typeof", "'number'", "DOUBLE PRECISION"):
        assert fragment in float_sql

    # str: the type check targets 'string'
    str_sql = render("hello")
    for fragment in ("json_typeof", "'string'"):
        assert fragment in str_sql
|
||||
|
||||
def test_json_match_rejects_unsafe_key(self):
    """``json_match`` must raise ``ValueError`` for every unacceptable key.

    Covers string keys containing dots, spaces, quotes, backslashes,
    semicolons or nothing at all, followed by keys that are not strings.
    """
    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))

    malformed_strings = ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]
    # Non-string keys must also raise ValueError (not TypeError from re.match)
    non_strings = [42, None, ("k",)]

    for candidate in malformed_strings + non_strings:
        with pytest.raises(ValueError, match="JsonMatch key must match"):
            json_match(table.c.data, candidate, "x")
|
||||
|
||||
def test_json_match_rejects_unsupported_value_type(self):
    """``json_match`` must raise ``TypeError`` for list, dict and arbitrary object values."""
    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))
    data_column = table.c.data

    unsupported_values = ([], {}, object())
    for candidate in unsupported_values:
        with pytest.raises(TypeError, match="JsonMatch value must be"):
            json_match(data_column, "k", candidate)
|
||||
|
||||
def test_json_match_unsupported_dialect_raises(self):
    """Compiling a ``json_match`` expression for MySQL must raise ``NotImplementedError``.

    The expression itself is built without error; only compilation against
    the MySQL dialect is expected to fail, with a message mentioning "mysql".
    """
    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects import mysql
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))
    expression = json_match(table.c.data, "k", "v")

    with pytest.raises(NotImplementedError, match="mysql"):
        compiled = expression.compile(
            dialect=mysql.dialect(),
            compile_kwargs={"literal_binds": True},
        )
        str(compiled)
|
||||
|
||||
def test_json_match_rejects_out_of_range_int(self):
    """Integers outside the signed 64-bit range must raise ``TypeError``.

    The two extreme in-range values are constructed first to show they are
    accepted; values one past each bound, plus a very large one, must fail.
    """
    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))
    data_column = table.c.data

    int64_max = 2**63 - 1
    int64_min = -(2**63)

    # boundary values must be accepted
    for accepted in (int64_max, int64_min):
        json_match(data_column, "k", accepted)

    # one beyond each boundary must be rejected
    for rejected in (2**63, -(2**63) - 1, 10**30):
        with pytest.raises(TypeError, match="out of signed 64-bit range"):
            json_match(data_column, "k", rejected)
|
||||
|
||||
def test_compiler_raises_on_escaped_key(self):
    """Compilation must raise ``ValueError`` when an invalid key slips past construction.

    A valid element is built first and its ``key`` attribute is then
    overwritten, so only the compile-time check can catch the bad key. Both
    the SQLite and the PostgreSQL compilers are exercised.
    """
    from sqlalchemy import Column, MetaData, String, Table, create_engine
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.types import JSON

    from deerflow.persistence.json_compat import json_match

    table = Table("t", MetaData(), Column("data", JSON), Column("id", String))
    sqlite_dialect = create_engine("sqlite://").dialect

    element = json_match(table.c.data, "k", "v")
    # bypass __init__ to simulate -O stripping assert
    element.key = "bad.key"

    for dialect in (sqlite_dialect, postgresql.dialect()):
        with pytest.raises(ValueError, match="Key escaped validation"):
            str(element.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
|
||||
@@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from app.gateway.routers import threads
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.persistence.thread_meta import InvalidMetadataFilterError
|
||||
from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore
|
||||
|
||||
_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
|
||||
@@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None
|
||||
assert entries, "expected at least one history entry"
|
||||
for entry in entries:
|
||||
assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
|
||||
|
||||
|
||||
# ── Metadata filter validation at API boundary ────────────────────────────────
|
||||
|
||||
|
||||
def test_search_threads_rejects_invalid_key_at_api_boundary() -> None:
    """A metadata filter key containing ``;`` must be answered with HTTP 422.

    Per the original note, keys outside ``[A-Za-z0-9_-]+`` are rejected by
    the Pydantic validator on ``ThreadSearchRequest.metadata``.
    """
    app, _store, _checkpointer = _build_thread_app()
    payload = {"metadata": {"bad;key": "x"}}

    with TestClient(app) as client:
        response = client.post("/api/threads/search", json=payload)

    assert response.status_code == 422
|
||||
|
||||
|
||||
def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None:
    """A list-valued metadata filter must be answered with HTTP 422.

    Per the original note, only None, bool, int, float and str values are accepted.
    """
    app, _store, _checkpointer = _build_thread_app()
    payload = {"metadata": {"env": ["a", "b"]}}

    with TestClient(app) as client:
        response = client.post("/api/threads/search", json=payload)

    assert response.status_code == 422
|
||||
|
||||
|
||||
def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None:
    """An ``InvalidMetadataFilterError`` from the store must surface as HTTP 400.

    The store's ``search`` is patched to raise even though the request
    payload is valid, and the response detail must carry the error text.
    """
    app, _store, _checkpointer = _build_thread_app()
    # Captured before the client starts, exactly as the original test does.
    thread_store = app.state.thread_store

    async def _always_reject(**kwargs):
        raise InvalidMetadataFilterError("rejected")

    payload = {"metadata": {"valid_key": "x"}}
    with TestClient(app) as client, patch.object(thread_store, "search", side_effect=_always_reject):
        response = client.post("/api/threads/search", json=payload)

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "rejected" in detail
|
||||
|
||||
|
||||
def test_search_threads_succeeds_with_valid_metadata() -> None:
    """Sanity check: a well-formed metadata filter is answered with HTTP 200."""
    app, _store, _checkpointer = _build_thread_app()
    payload = {"metadata": {"env": "prod"}}

    with TestClient(app) as client:
        response = client.post("/api/threads/search", json=payload)

    assert response.status_code == 200
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.middlewares.token_usage_middleware import (
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY,
|
||||
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
|
||||
"tool_call_id": "write_todos:remove",
|
||||
}
|
||||
]
|
||||
|
||||
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
    """Cached subagent usage is merged onto the right dispatch message when no ids are set.

    Two ``task`` dispatch messages are built without explicit ids. Usage is
    cached only for the second dispatch's two tool calls, so exactly one
    returned message may carry ``usage_metadata``: it must have the second
    dispatch's tool calls and the summed token counts.
    """
    middleware = TokenUsageMiddleware()

    dispatch_one = AIMessage(
        content="",
        tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
    )
    dispatch_two = AIMessage(
        content="",
        tool_calls=[
            {"id": "task:second-a", "name": "task", "args": {}},
            {"id": "task:second-b", "name": "task", "args": {}},
        ],
    )
    conversation = [
        dispatch_one,
        ToolMessage(content="first", tool_call_id="task:first"),
        dispatch_two,
        ToolMessage(content="second-a", tool_call_id="task:second-a"),
        ToolMessage(content="second-b", tool_call_id="task:second-b"),
        AIMessage(content="done"),
    ]
    usage_by_tool_call = {
        "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
        "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
    }

    def _fake_pop_cached_subagent_usage(tool_call_id):
        # Stand-in for the task tool's cache lookup: each entry is consumed once.
        return usage_by_tool_call.pop(tool_call_id, None)

    task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
    monkeypatch.setattr(task_tool_module, "pop_cached_subagent_usage", _fake_pop_cached_subagent_usage)

    result = middleware.after_model({"messages": conversation}, _make_runtime())

    assert result is not None
    with_usage = [msg for msg in result["messages"] if getattr(msg, "usage_metadata", None)]
    assert len(with_usage) == 1

    merged = with_usage[0]
    assert merged.tool_calls == dispatch_two.tool_calls
    expected_usage = {
        "input_tokens": 30,
        "output_tokens": 12,
        "total_tokens": 42,
    }
    assert merged.usage_metadata == expected_usage
|
||||
|
||||
@@ -65,8 +65,7 @@ def _make_minimal_config(tools):
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
|
||||
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||
"""Config-loaded async-only tools can still be invoked by sync clients."""
|
||||
|
||||
async def async_tool_impl(x: int) -> str:
|
||||
@@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash,
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
|
||||
def test_no_duplicates_returned(mock_bash, mock_cfg):
|
||||
"""get_available_tools() never returns two tools with the same name."""
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
|
||||
@@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||
def test_first_occurrence_wins(mock_bash, mock_cfg):
|
||||
"""When duplicates exist, the first occurrence is kept."""
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
|
||||
@@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
|
||||
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
|
||||
"""A warning is logged for every skipped duplicate."""
|
||||
import logging
|
||||
|
||||
|
||||
@@ -0,0 +1,253 @@
|
||||
"""End-to-end verification for update_agent's user_id resolution.
|
||||
|
||||
PR #2784 hardened setup_agent to prefer runtime.context["user_id"] over the
|
||||
contextvar. update_agent had the same latent gap: it unconditionally called
|
||||
get_effective_user_id() at module level, so any scenario where the contextvar
|
||||
was unavailable while runtime.context carried user_id (a background task
|
||||
scheduled outside the request task, a worker pool that doesn't copy_context,
|
||||
checkpoint resume on a different task) would silently route writes to
|
||||
users/default/agents/...
|
||||
|
||||
These tests are load-bearing under @no_auto_user (contextvar empty):
|
||||
|
||||
- The negative-control test confirms the fixture actually puts the tool in
|
||||
the regime where the contextvar fallback would land in users/default/.
|
||||
Without that, the positive test would be vacuously satisfied.
|
||||
- The positive test verifies update_agent honours runtime.context["user_id"]
|
||||
injected by inject_authenticated_user_context in the gateway. Before the
|
||||
fix in this PR, this test failed; now it passes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from _agent_e2e_helpers import build_single_tool_call_model
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.services import (
|
||||
build_run_config,
|
||||
inject_authenticated_user_context,
|
||||
merge_run_context_overrides,
|
||||
)
|
||||
from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context
|
||||
|
||||
|
||||
def _make_request(user_id_str: str | None) -> SimpleNamespace:
    """Build a minimal request stand-in exposing ``state.user``.

    A truthy *user_id_str* is parsed as a UUID and wrapped in a user object
    with a fixed e-mail; otherwise ``state.user`` is ``None``.
    """
    if user_id_str:
        user = SimpleNamespace(id=UUID(user_id_str), email="alice@local")
    else:
        user = None

    state = SimpleNamespace(user=user)
    return SimpleNamespace(state=state)
|
||||
|
||||
|
||||
def _assemble_config(*, body_context: dict | None, request_user_id: str | None, thread_id: str) -> dict:
    """Build a run config by chaining the three gateway helpers in order.

    The base config comes from ``build_run_config``, then *body_context* is
    merged in, then the authenticated user (if any) from a fake request is
    injected. The resulting config dict is returned.
    """
    run_config = build_run_config(
        thread_id,
        {"recursion_limit": 50},
        None,
        assistant_id="lead_agent",
    )
    merge_run_context_overrides(run_config, body_context)

    fake_request = _make_request(request_user_id)
    inject_authenticated_user_context(run_config, fake_request)
    return run_config
|
||||
|
||||
|
||||
def _seed_existing_agent(tmp_path: Path, user_id: str, agent_name: str, soul: str = "# Original"):
    """Create an agent directory on disk for ``update_agent`` to overwrite.

    Writes ``config.yaml`` and ``SOUL.md`` under
    ``tmp_path/users/<user_id>/agents/<agent_name>`` and returns that directory.
    """
    target_dir = tmp_path / "users" / user_id / "agents" / agent_name
    target_dir.mkdir(parents=True, exist_ok=True)

    config_text = yaml.dump({"name": agent_name, "description": "old"}, allow_unicode=True)
    seeded_files = {"config.yaml": config_text, "SOUL.md": soul}
    for filename, content in seeded_files.items():
        (target_dir / filename).write_text(content, encoding="utf-8")

    return target_dir
|
||||
|
||||
|
||||
def _make_paths_mock(tmp_path: Path):
    """Return a ``MagicMock`` paths object whose directories live under *tmp_path*.

    ``base_dir`` is *tmp_path* itself; ``agent_dir`` and ``user_agent_dir``
    are plain functions resolving to ``agents/<name>`` and
    ``users/<user_id>/agents/<name>`` beneath it.
    """

    def _agent_dir(name):
        return tmp_path / "agents" / name

    def _user_agent_dir(user_id, name):
        return tmp_path / "users" / user_id / "agents" / name

    fake_paths = MagicMock()
    fake_paths.base_dir = tmp_path
    fake_paths.agent_dir = _agent_dir
    fake_paths.user_agent_dir = _user_agent_dir
    return fake_paths
|
||||
|
||||
|
||||
def _patch_update_agent_dependencies(tmp_path: Path):
    """Return the (not yet started) patchers the update_agent tests need.

    Stubs ``get_paths`` and ``get_app_config`` in the update_agent tool
    module, plus ``get_paths`` in ``deerflow.config.agents_config``, so the
    tool can run against *tmp_path* without a real config file or LLM.
    """

    def _get_model_config(name):
        # Only the single fake model name resolves; anything else is unknown.
        return known_model if name == "fake-model" else None

    known_model = SimpleNamespace(name="fake-model")
    app_config = MagicMock()
    app_config.get_model_config = _get_model_config

    tool_paths_patch = patch(
        "deerflow.tools.builtins.update_agent_tool.get_paths",
        return_value=_make_paths_mock(tmp_path),
    )
    app_config_patch = patch(
        "deerflow.tools.builtins.update_agent_tool.get_app_config",
        return_value=app_config,
    )
    # load_agent_config (used by update_agent to read the existing config)
    # resolves paths through its own module-level get_paths reference. Without
    # this patch the tool reports "Agent does not exist" before touching disk.
    agents_config_paths_patch = patch(
        "deerflow.config.agents_config.get_paths",
        return_value=_make_paths_mock(tmp_path),
    )

    return [tool_paths_patch, app_config_patch, agents_config_paths_patch]
|
||||
|
||||
|
||||
def _build_update_graph(*, soul_payload: str):
    """Build an agent graph whose fake model issues one ``update_agent`` tool call.

    The tool call carries *soul_payload* as ``soul`` and a fixed
    ``description``; the model's closing message is ``"updated"``.
    """
    from langchain.agents import create_agent

    from deerflow.tools.builtins.update_agent_tool import update_agent

    tool_arguments = {"soul": soul_payload, "description": "refined"}
    scripted_model = build_single_tool_call_model(
        tool_name="update_agent",
        tool_args=tool_arguments,
        tool_call_id="call_update_1",
        final_text="updated",
    )
    return create_agent(
        model=scripted_model,
        tools=[update_agent],
        system_prompt="updater",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Negative control — proves the test environment puts update_agent in the
|
||||
# regime where the contextvar fallback would land in default/.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_update_agent_falls_back_to_default_when_no_inject_and_no_contextvar(tmp_path: Path):
    """Negative control: with no authenticated user the tool writes under ``users/default``.

    The agent is seeded only under ``users/default/agents/``; the test then
    checks that this SOUL.md was rewritten, showing which directory the tool
    consulted.
    """
    from langgraph.runtime import Runtime

    thread_id = "thread-update-1"
    seeded_dir = _seed_existing_agent(tmp_path, "default", "fallback-target")

    config = _assemble_config(
        body_context={"agent_name": "fallback-target"},
        request_user_id=None,  # no auth, inject is no-op
        thread_id=thread_id,
    )
    runtime_ctx = _build_runtime_context(thread_id, "run-1", config.get("context"), None)
    _install_runtime_context(config, runtime_ctx)
    configurable = config.setdefault("configurable", {})
    configurable["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    graph = _build_update_graph(soul_payload="# Fallback Updated")
    graph_input = {"messages": [HumanMessage(content="update fallback-target")]}

    with ExitStack() as stack:
        for patcher in _patch_update_agent_dependencies(tmp_path):
            stack.enter_context(patcher)
        graph.invoke(graph_input, config=config)

    soul = (seeded_dir / "SOUL.md").read_text()
    assert soul == "# Fallback Updated", "Sanity: tool should have written under default/"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression guard — passes on this branch, would fail on main before the fix.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_update_agent_should_use_runtime_context_user_id_when_contextvar_missing(tmp_path: Path):
    """Regression guard: the write must follow ``runtime.context["user_id"]``.

    An agent with the same name is seeded under both the authenticated user
    and ``default``. After the graph runs, only the authenticated user's
    SOUL.md may have changed; the default copy must be untouched.
    """
    from langgraph.runtime import Runtime

    auth_uid = "abcdef01-2345-6789-abcd-ef0123456789"
    thread_id = "thread-update-2"

    # Seed the agent in BOTH locations so we can prove which one was opened.
    auth_dir = _seed_existing_agent(tmp_path, auth_uid, "shared-name", soul="# Auth Original")
    default_dir = _seed_existing_agent(tmp_path, "default", "shared-name", soul="# Default Original")

    config = _assemble_config(
        body_context={"agent_name": "shared-name"},
        request_user_id=auth_uid,
        thread_id=thread_id,
    )
    runtime_ctx = _build_runtime_context(thread_id, "run-2", config.get("context"), None)
    assert runtime_ctx["user_id"] == auth_uid, "Pre-condition: inject must have placed user_id into runtime_ctx"

    _install_runtime_context(config, runtime_ctx)
    configurable = config.setdefault("configurable", {})
    configurable["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    graph = _build_update_graph(soul_payload="# Auth Updated")
    graph_input = {"messages": [HumanMessage(content="update shared-name")]}

    with ExitStack() as stack:
        for patcher in _patch_update_agent_dependencies(tmp_path):
            stack.enter_context(patcher)
        graph.invoke(graph_input, config=config)

    auth_soul = (auth_dir / "SOUL.md").read_text()
    default_soul = (default_dir / "SOUL.md").read_text()

    assert auth_soul == "# Auth Updated", f"REGRESSION: update_agent ignored runtime.context['user_id']={auth_uid!r} and routed the write to users/default/ instead. auth_soul={auth_soul!r}, default_soul={default_soul!r}"
    assert default_soul == "# Default Original", "REGRESSION: update_agent corrupted the shared default-user agent. It should have written under the authenticated user's path."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Positive — when contextvar IS the auth user (the normal HTTP case), things
|
||||
# already work. Pin it as a regression guard so future refactors don't
|
||||
# accidentally break the contextvar path in pursuit of the runtime-context fix.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_agent_uses_contextvar_when_present(tmp_path: Path, monkeypatch):
    """With the current-user contextvar set, the write lands under that user's directory.

    The same user id is supplied through the request and through
    ``set_current_user``; the contextvar is reset again once the graph has
    run, whether or not the invocation succeeded.
    """
    from types import SimpleNamespace as _SN

    from langgraph.runtime import Runtime

    from deerflow.runtime.user_context import reset_current_user, set_current_user

    auth_uid = "11112222-3333-4444-5555-666677778888"
    thread_id = "thread-update-3"
    user = _SN(id=auth_uid, email="ctxvar@local")

    agent_dir = _seed_existing_agent(tmp_path, auth_uid, "ctxvar-agent", soul="# Original")

    config = _assemble_config(
        body_context={"agent_name": "ctxvar-agent"},
        request_user_id=auth_uid,
        thread_id=thread_id,
    )
    runtime_ctx = _build_runtime_context(thread_id, "run-3", config.get("context"), None)
    _install_runtime_context(config, runtime_ctx)
    configurable = config.setdefault("configurable", {})
    configurable["__pregel_runtime"] = Runtime(context=runtime_ctx, store=None)

    graph = _build_update_graph(soul_payload="# CtxVar Updated")
    graph_input = {"messages": [HumanMessage(content="update ctxvar-agent")]}

    with ExitStack() as stack:
        for patcher in _patch_update_agent_dependencies(tmp_path):
            stack.enter_context(patcher)
        token = set_current_user(user)
        try:
            final = graph.invoke(graph_input, config=config)
        finally:
            reset_current_user(token)

    # surface the tool's reply for debug if it errored
    tool_replies = [msg.content for msg in final["messages"] if getattr(msg, "type", "") == "tool"]
    soul = (agent_dir / "SOUL.md").read_text()
    assert soul == "# CtxVar Updated", f"tool replies: {tool_replies}"
|
||||
Reference in New Issue
Block a user