Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e5c7328cf5 | |||
| ba864112a3 | |||
| 6e8e6a969b | |||
| eab7ae3d62 | |||
| f1a0ab699a | |||
| 2a1ac06bf4 |
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
|
||||
|
||||
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
||||
|
||||
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
|
||||
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
|
||||
|
||||
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
||||
|
||||
|
||||
+1
-1
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
|
||||
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
||||
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
|
||||
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
|
||||
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
||||
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, override
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
|
||||
"""Return True if the AIMessage contains a tool_call with the given id."""
|
||||
for tc in message.tool_calls or []:
|
||||
if isinstance(tc, dict):
|
||||
if tc.get("id") == tool_call_id:
|
||||
return True
|
||||
elif hasattr(tc, "id") and tc.id == tool_call_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Annotate subagent token usage onto the AIMessage that dispatched it.
|
||||
# When a task tool completes, its usage is cached by tool_call_id. Detect
|
||||
# the ToolMessage → search backward for the corresponding AIMessage → merge.
|
||||
# Walk backward through consecutive ToolMessages before the new AIMessage
|
||||
# so that multiple concurrent task tool calls all get their subagent tokens
|
||||
# written back to the same dispatch message (merging into one update).
|
||||
state_updates: dict[int, AIMessage] = {}
|
||||
if len(messages) >= 2:
|
||||
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
|
||||
|
||||
idx = len(messages) - 2
|
||||
while idx >= 0:
|
||||
tool_msg = messages[idx]
|
||||
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
|
||||
break
|
||||
|
||||
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
|
||||
if subagent_usage:
|
||||
# Search backward from the ToolMessage to find the AIMessage
|
||||
# that dispatched it. A single model response can dispatch
|
||||
# multiple task tool calls, so we can't assume a fixed offset.
|
||||
dispatch_idx = idx - 1
|
||||
while dispatch_idx >= 0:
|
||||
candidate = messages[dispatch_idx]
|
||||
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
|
||||
# Accumulate into an existing update for the same
|
||||
# AIMessage (multiple task calls in one response),
|
||||
# or merge fresh from the original message.
|
||||
existing_update = state_updates.get(dispatch_idx)
|
||||
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
|
||||
merged = {
|
||||
**prev,
|
||||
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
|
||||
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
|
||||
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
|
||||
}
|
||||
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
|
||||
break
|
||||
dispatch_idx -= 1
|
||||
idx -= 1
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
if state_updates:
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
state_updates[len(messages) - 1] = updated_msg
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
|
||||
@@ -223,10 +223,11 @@ class RunRepository(RunStore):
|
||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||
_completed = RunRow.status.in_(("success", "error"))
|
||||
_thread = RunRow.thread_id == thread_id
|
||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunRow.model_name, "unknown").label("model"),
|
||||
model_name.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
||||
@@ -236,7 +237,7 @@ class RunRepository(RunStore):
|
||||
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(_thread, _completed)
|
||||
.group_by(func.coalesce(RunRow.model_name, "unknown"))
|
||||
.group_by(model_name)
|
||||
)
|
||||
|
||||
async with self._sf() as session:
|
||||
|
||||
@@ -26,6 +26,28 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
|
||||
# write it back to the triggering AIMessage's usage_metadata.
|
||||
_subagent_usage_cache: dict[str, dict[str, int]] = {}
|
||||
|
||||
|
||||
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
|
||||
if app_config is None:
|
||||
try:
|
||||
app_config = get_app_config()
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
|
||||
|
||||
|
||||
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
    """Store a subagent usage summary in the module cache under ``tool_call_id``.

    Nothing is stored when ``enabled`` is falsy or ``usage`` is empty/None.
    """
    if not enabled or not usage:
        return
    _subagent_usage_cache[tool_call_id] = usage
|
||||
|
||||
|
||||
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
    """Remove and return the usage cached for ``tool_call_id``.

    Returns None when no entry exists for that id.
    """
    cached_usage = _subagent_usage_cache.pop(tool_call_id, None)
    return cached_usage
|
||||
|
||||
|
||||
def _is_subagent_terminal(result: Any) -> bool:
|
||||
"""Return whether a background subagent result is safe to clean up."""
|
||||
@@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
|
||||
return None
|
||||
|
||||
|
||||
def _summarize_usage(records: list[dict] | None) -> dict | None:
|
||||
"""Summarize token usage records into a compact dict for SSE events."""
|
||||
if not records:
|
||||
return None
|
||||
return {
|
||||
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
|
||||
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
|
||||
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
|
||||
}
|
||||
|
||||
|
||||
def _report_subagent_usage(runtime: Any, result: Any) -> None:
|
||||
"""Report subagent token usage to the parent RunJournal, if available.
|
||||
|
||||
@@ -177,6 +210,7 @@ async def task_tool(
|
||||
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
"""
|
||||
runtime_app_config = _get_runtime_app_config(runtime)
|
||||
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
|
||||
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
|
||||
|
||||
# Get subagent configuration
|
||||
@@ -312,27 +346,32 @@ async def task_tool(
|
||||
last_message_count = current_message_count
|
||||
|
||||
# Check if task completed, failed, or timed out
|
||||
usage = _summarize_usage(getattr(result, "token_usage_records", None))
|
||||
if result.status == SubagentStatus.COMPLETED:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task Succeeded. Result: {result.result}"
|
||||
elif result.status == SubagentStatus.FAILED:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task failed. Error: {result.error}"
|
||||
elif result.status == SubagentStatus.CANCELLED:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return "Task cancelled by user."
|
||||
elif result.status == SubagentStatus.TIMED_OUT:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
|
||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task timed out. Error: {result.error}"
|
||||
@@ -351,7 +390,9 @@ async def task_tool(
|
||||
timeout_minutes = config.timeout_seconds // 60
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
usage = _summarize_usage(getattr(result, "token_usage_records", None))
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
except asyncio.CancelledError:
|
||||
# Signal the background subagent thread to stop cooperatively.
|
||||
@@ -374,4 +415,8 @@ async def task_tool(
|
||||
cleanup_background_task(task_id)
|
||||
else:
|
||||
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
|
||||
_subagent_usage_cache.pop(tool_call_id, None)
|
||||
raise
|
||||
except Exception:
|
||||
_subagent_usage_cache.pop(tool_call_id, None)
|
||||
raise
|
||||
|
||||
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -116,8 +116,6 @@ def get_available_tools(
|
||||
# made through the Gateway API (which runs in a separate process) are immediately
|
||||
# reflected when loading MCP tools.
|
||||
mcp_tools = []
|
||||
# Reset deferred registry upfront to prevent stale state from previous calls
|
||||
reset_deferred_registry()
|
||||
if include_mcp:
|
||||
try:
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
@@ -135,12 +133,51 @@ def get_available_tools(
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
|
||||
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
|
||||
|
||||
registry = DeferredToolRegistry()
|
||||
for t in mcp_tools:
|
||||
registry.register(t)
|
||||
set_deferred_registry(registry)
|
||||
# Reuse the existing registry if one is already set for
|
||||
# this async context. ``get_available_tools`` is
|
||||
# re-entered whenever a subagent is spawned
|
||||
# (``task_tool`` calls it to build the child agent's
|
||||
# toolset), and we used to unconditionally
|
||||
# rebuild the registry — wiping out the parent agent's
|
||||
# tool_search promotions. The
|
||||
# ``DeferredToolFilterMiddleware`` then re-hid those
|
||||
# tools from subsequent model calls, leaving the agent
|
||||
# able to see a tool's name but unable to invoke it
|
||||
# (issue #2884). ``contextvars`` already gives us the
|
||||
# lifetime semantics we want: a fresh request / graph
|
||||
# run starts in a new asyncio task with the
|
||||
# ContextVar at its default of ``None``, so reuse is
|
||||
# only triggered for re-entrant calls inside one run.
|
||||
#
|
||||
# Intentionally NOT reconciling against the current
|
||||
# ``mcp_tools`` snapshot. The MCP cache only refreshes
|
||||
# on ``extensions_config.json`` mtime changes, which
|
||||
# in practice happens between graph runs — not inside
|
||||
# one. And even if a refresh did happen mid-run, the
|
||||
# already-built lead agent's ``ToolNode`` still holds
|
||||
# the *previous* tool set (LangGraph binds tools at
|
||||
# graph construction time), so a brand-new MCP tool
|
||||
# couldn't actually be invoked anyway. The
|
||||
# ``DeferredToolRegistry`` doesn't retain the names
|
||||
# of previously-promoted tools (``promote()`` drops
|
||||
# the entry entirely), so re-syncing the registry
|
||||
# against a fresh ``mcp_tools`` list would
|
||||
# mis-classify those promotions as new tools and
|
||||
# re-register them as deferred — exactly the bug
|
||||
# this fix exists to prevent.
|
||||
existing_registry = get_deferred_registry()
|
||||
if existing_registry is None:
|
||||
registry = DeferredToolRegistry()
|
||||
for t in mcp_tools:
|
||||
registry.register(t)
|
||||
set_deferred_registry(registry)
|
||||
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
|
||||
else:
|
||||
mcp_tool_names = {t.name for t in mcp_tools}
|
||||
still_deferred = len(existing_registry)
|
||||
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
|
||||
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
|
||||
builtin_tools.append(tool_search_tool)
|
||||
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
|
||||
except ImportError:
|
||||
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
|
||||
issues when unit-testing lightweight config/registry code in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -11,11 +13,16 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
|
||||
|
||||
# Make 'app' and 'deerflow' importable from any working directory
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
|
||||
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
|
||||
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
|
||||
|
||||
# Break the circular import chain that exists in production code:
|
||||
# deerflow.subagents.__init__
|
||||
# -> .executor (SubagentExecutor, SubagentResult)
|
||||
@@ -56,6 +63,92 @@ def provisioner_module():
|
||||
return module
|
||||
|
||||
|
||||
@pytest.fixture()
def blocking_io_detector():
    """Yield a detector built with ``fail_on_exit=True`` while the test runs."""
    detector_context = detect_blocking_io(fail_on_exit=True)
    with detector_context as active_detector:
        yield active_detector
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the two blocking-IO probe flags in the ``blocking-io`` group."""
    flag_help = (
        (
            "--detect-blocking-io",
            "Collect blocking calls made while an asyncio event loop is running and report a summary.",
        ),
        (
            "--detect-blocking-io-fail",
            "Set a failing exit status when --detect-blocking-io records violations.",
        ),
    )
    option_group = parser.getgroup("blocking-io")
    # Both flags are plain on/off switches that default to off.
    for flag, help_text in flag_help:
        option_group.addoption(flag, action="store_true", default=False, help=help_text)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
    """Declare the ``no_blocking_io_probe`` marker."""
    marker_line = "no_blocking_io_probe: skip the optional blocking IO probe"
    config.addinivalue_line("markers", marker_line)
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
    """Clear the shared probe at session start when the probe is enabled."""
    if not _blocking_io_probe_enabled(session.config):
        return
    _blocking_io_probe.clear()
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
    """Enter a non-failing detector before the test call when the probe applies.

    The entered detector is stored on ``item`` under
    ``_BLOCKING_IO_DETECTOR_ATTR``; this hook does not exit it.
    """
    probe_this_test = _blocking_io_probe_enabled(item.config) and not _blocking_io_probe_skipped(item)
    if probe_this_test:
        detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
        detector.__enter__()
        setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
    yield
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
    """After teardown, exit the item's detector and record its violations.

    Does nothing when no detector is stored on ``item``. The stored attribute
    is always removed, even if exiting or recording raises.
    """
    yield

    active_detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
    if active_detector is not None:
        try:
            active_detector.__exit__(None, None, None)
            _blocking_io_probe.record(item.nodeid, active_detector.violations)
        finally:
            delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session) -> None:
    """Turn an OK exit status into TESTS_FAILED when fail mode saw violations."""
    if not _blocking_io_fail_enabled(session.config):
        return
    if not _blocking_io_probe.violation_count:
        return
    # Only an otherwise successful run is downgraded; other statuses are kept.
    if session.exitstatus == pytest.ExitCode.OK:
        session.exitstatus = pytest.ExitCode.TESTS_FAILED
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
    """Write the probe summary to the terminal report when the probe is enabled.

    The first summary line becomes a ``=`` separator title; the remaining
    lines are written one by one.
    """
    if _blocking_io_probe_enabled(terminalreporter.config):
        summary_lines = _blocking_io_probe.format_summary().splitlines()
        header, *details = summary_lines
        terminalreporter.write_sep("=", header)
        for detail in details:
            terminalreporter.write_line(detail)
|
||||
|
||||
|
||||
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
|
||||
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
|
||||
|
||||
|
||||
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
|
||||
return bool(config.getoption("--detect-blocking-io-fail"))
|
||||
|
||||
|
||||
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
|
||||
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Shared test support helpers."""
|
||||
@@ -0,0 +1 @@
|
||||
"""Runtime and static detectors used by tests."""
|
||||
@@ -0,0 +1,287 @@
|
||||
"""Test helper for detecting blocking calls on an asyncio event loop.
|
||||
|
||||
The detector is intentionally test-only. It monkeypatches a small set of
|
||||
well-known blocking entry points and their already-loaded module-level aliases,
|
||||
then records calls only when they happen on a thread that is currently running
|
||||
an asyncio event loop. Aliases captured in closures or default arguments remain
|
||||
out of scope.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from collections.abc import Callable, Iterable, Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
BlockingCallable = Callable[..., Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockingCallSpec:
|
||||
"""Describes one blocking callable to wrap during a detector run."""
|
||||
|
||||
# Label copied into BlockingCall.name when a call is recorded.
name: str
|
||||
# "module:attr.path" locator; _resolve_target imports the module part and
# walks the dotted attribute path after the colon.
target: str
|
||||
# When True the wrapper calls the original immediately and records while the
# returned iterable is being consumed, rather than at call time.
record_on_iteration: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockingCall:
|
||||
"""One blocking call observed on an asyncio event loop thread."""
|
||||
|
||||
# ``name`` of the BlockingCallSpec whose wrapper recorded the call.
name: str
|
||||
# ``target`` of that same spec.
target: str
|
||||
# Call stack captured with traceback.extract_stack, with frames from the
# detector module itself filtered out.
stack: tuple[traceback.FrameSummary, ...]
|
||||
|
||||
|
||||
# Specs used when a caller does not pass its own ``specs``: this tuple is the
# default argument of both BlockingIODetector and detect_blocking_io.
# os.walk is the only entry recorded on iteration instead of at call time.
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
|
||||
BlockingCallSpec("time.sleep", "time:sleep"),
|
||||
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
|
||||
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
|
||||
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
|
||||
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
|
||||
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
|
||||
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
|
||||
)
|
||||
|
||||
|
||||
def _is_event_loop_thread() -> bool:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return False
|
||||
return loop.is_running()
|
||||
|
||||
|
||||
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
|
||||
module_name, attr_path = target.split(":", maxsplit=1)
|
||||
owner: object = importlib.import_module(module_name)
|
||||
parts = attr_path.split(".")
|
||||
for part in parts[:-1]:
|
||||
owner = getattr(owner, part)
|
||||
|
||||
attr_name = parts[-1]
|
||||
original = getattr(owner, attr_name)
|
||||
return owner, attr_name, original
|
||||
|
||||
|
||||
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
|
||||
return tuple(frame for frame in stack if frame.filename != __file__)
|
||||
|
||||
|
||||
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
    """Context manager that records blocking calls made on an event loop thread.

    While the context is active, each spec's target is replaced by a wrapper
    that notes a ``BlockingCall`` in ``violations`` whenever it runs on a
    thread with a running asyncio loop. Leaving the context restores every
    patched attribute. Violations only raise when ``fail_on_exit=True`` was
    given or ``assert_no_blocking_calls()`` is called explicitly.
    """

    def __init__(
        self,
        specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
        *,
        fail_on_exit: bool = False,
        patch_loaded_aliases: bool = True,
        stack_limit: int = 12,
    ) -> None:
        """Store the configuration; nothing is patched until ``__enter__``."""
        self.violations: list[BlockingCall] = []
        self._active = False
        self._specs = tuple(specs)
        self._stack_limit = stack_limit
        self._fail_on_exit = fail_on_exit
        self._patch_loaded_aliases_enabled = patch_loaded_aliases
        # (owner, attribute name, original value) for each applied patch.
        self._patches: list[tuple[object, str, BlockingCallable]] = []
        # (id(owner), attribute name) pairs, so a location is patched only once.
        self._patch_keys: set[tuple[int, str]] = set()

    def __enter__(self) -> BlockingIODetector:
        """Install the wrappers; on any failure undo them and re-raise."""
        try:
            self._active = True
            wrappers_by_original_id: dict[int, BlockingCallable] = {}
            for spec in self._specs:
                owner, attr_name, original = _resolve_target(spec.target)
                replacement = self._wrap(spec, original)
                self._patch_attribute(owner, attr_name, original, replacement)
                wrappers_by_original_id[id(original)] = replacement

            if self._patch_loaded_aliases_enabled:
                self._patch_loaded_module_aliases(wrappers_by_original_id)
        except Exception:
            self._restore()
            self._active = False
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback_value: TracebackType | None,
    ) -> bool | None:
        """Undo all patches and, if configured, fail on recorded violations.

        The assertion only runs when the body finished without an exception.
        Exceptions from the body are never suppressed.
        """
        self._restore()
        self._active = False
        should_assert = exc_type is None and self._fail_on_exit
        if should_assert:
            self.assert_no_blocking_calls()
        return None

    def _restore(self) -> None:
        """Put back every original attribute, newest patch first."""
        for patch in self._patches[::-1]:
            setattr(*patch)
        self._patches.clear()
        self._patch_keys.clear()

    def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
        """Set ``replacement`` on ``owner`` unless that location is already patched."""
        patch_key = (id(owner), attr_name)
        if patch_key not in self._patch_keys:
            setattr(owner, attr_name, replacement)
            self._patches.append((owner, attr_name, original))
            self._patch_keys.add(patch_key)

    def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
        """Patch module-level names in ``sys.modules`` bound to a wrapped original."""
        loaded_modules = tuple(sys.modules.values())
        for module in loaded_modules:
            module_vars = getattr(module, "__dict__", None)
            if isinstance(module_vars, dict):
                # Iterate over a snapshot of the namespace items.
                for alias, current in tuple(module_vars.items()):
                    wrapper = replacements_by_id.get(id(current))
                    if wrapper is None:
                        continue
                    self._patch_attribute(module, alias, current, wrapper)

    def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
        """Build the recording wrapper that stands in for ``original``."""

        @wraps(original)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not spec.record_on_iteration:
                self._record_if_blocking(spec)
                return original(*args, **kwargs)
            # Iteration-recorded specs: call through now, record on consumption.
            return self._wrap_iteration(spec, original(*args, **kwargs))

        return wrapper

    def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
        """Yield from ``iterable``, recording at most one violation for it.

        A recording attempt is made before each item is fetched until one
        succeeds.
        """
        source = iter(iterable)
        already_recorded = False

        while True:
            already_recorded = already_recorded or self._record_if_blocking(spec)
            try:
                yield next(source)
            except StopIteration:
                return

    def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
        """Record a violation if active and on an event loop thread.

        Returns True when a violation was appended, False otherwise.
        """
        if not (self._active and _is_event_loop_thread()):
            return False
        frames = traceback.extract_stack(limit=self._stack_limit)
        self.violations.append(BlockingCall(spec.name, spec.target, _trim_detector_frames(frames)))
        return True

    def assert_no_blocking_calls(self) -> None:
        """Raise ``AssertionError`` describing the violations, if there are any."""
        if not self.violations:
            return
        raise AssertionError(format_blocking_calls(self.violations))
|
||||
|
||||
|
||||
class BlockingIOProbe:
    """Collect detector output across tests and format a compact summary.

    Violations are stored together with the id of the test that produced
    them. ``format_summary`` groups them by the most relevant project call
    site found in each violation's stack.
    """

    def __init__(self, project_root: Path) -> None:
        """Remember the resolved ``project_root`` and start with no violations."""
        self._project_root = project_root.resolve()
        self._observed: list[tuple[str, BlockingCall]] = []

    @property
    def violation_count(self) -> int:
        """Number of recorded violations."""
        return len(self._observed)

    @property
    def test_count(self) -> int:
        """Number of distinct test ids that have at least one violation."""
        return len({nodeid for nodeid, _violation in self._observed})

    def clear(self) -> None:
        """Forget every recorded violation."""
        self._observed.clear()

    def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
        """Store each violation under the test id ``nodeid``."""
        self._observed.extend((nodeid, violation) for violation in violations)

    def format_summary(self, *, limit: int = 30) -> str:
        """Return a text summary of the recorded violations.

        Args:
            limit: Maximum number of call sites listed, most frequent first.

        Returns:
            A single line when nothing was recorded; otherwise a header, a
            ``Top call sites:`` line and one line per call site.
        """
        if not self._observed:
            return "blocking io probe: no violations"

        call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
        for _nodeid, violation in self._observed:
            frame = self._local_call_site(violation.stack)
            if frame is None:
                key = (violation.name, "<unknown>", 0, "<unknown>", "")
            else:
                key = (
                    violation.name,
                    self._relative(frame.filename),
                    frame.lineno,
                    frame.name,
                    (frame.line or "").strip(),
                )
            call_sites[key] += 1

        lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
        lines.extend(
            f"{count:4d} {name} {filename}:{lineno} {function} | {line}"
            for (name, filename, lineno, function, line), count in call_sites.most_common(limit)
        )
        return "\n".join(lines)

    def _relative(self, filename: str) -> str:
        """Return ``filename`` relative to the project root, or unchanged if outside it."""
        try:
            return str(Path(filename).resolve().relative_to(self._project_root))
        except ValueError:
            return filename

    def _project_relative_path(self, filename: str) -> Path | None:
        """Return the project-relative path of a frame file, or None if it is not project code.

        None is returned for non-absolute names (such as ``<string>``), for
        paths with a ``.venv`` component, and for paths outside the project
        root. Containment is checked on path components, so a sibling
        directory whose name merely starts with the root's name is rejected.
        """
        path = Path(filename)
        if not path.is_absolute() or ".venv" in path.parts:
            return None
        try:
            return path.resolve().relative_to(self._project_root)
        except ValueError:
            return None

    def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
        """Pick the frame that best identifies where the blocking call came from.

        Prefers the last project frame outside the top-level ``tests``
        directory, falls back to the last project frame of any kind, and
        returns None when the stack holds no project frame.
        """
        project_frames: list[tuple[traceback.FrameSummary, Path]] = []
        for frame in stack:
            relative = self._project_relative_path(frame.filename)
            if relative is not None:
                project_frames.append((frame, relative))

        # Compare path components so the check does not depend on the separator.
        non_test_frames = [frame for frame, relative in project_frames if relative.parts[:1] != ("tests",)]
        if non_test_frames:
            return non_test_frames[-1]

        return project_frames[-1][0] if project_frames else None
|
||||
|
||||
|
||||
def detect_blocking_io(
    specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
    *,
    fail_on_exit: bool = False,
    patch_loaded_aliases: bool = True,
    stack_limit: int = 12,
) -> BlockingIODetector:
    """Build a ``BlockingIODetector`` for use as a ``with`` block around a narrow test region.

    All arguments are forwarded unchanged to the detector's constructor.
    """
    detector = BlockingIODetector(
        specs,
        fail_on_exit=fail_on_exit,
        patch_loaded_aliases=patch_loaded_aliases,
        stack_limit=stack_limit,
    )
    return detector
|
||||
|
||||
|
||||
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
    """Render violations as a numbered list, each entry followed by its stack frames."""
    report = ["Blocking calls were executed on an asyncio event loop thread:"]
    position = 0
    for violation in violations:
        position += 1
        report.append(f"{position}. {violation.name} ({violation.target})")
        report.extend(_format_stack(violation.stack))
    return "\n".join(report)
|
||||
|
||||
|
||||
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
    """Yield an ``at <function> (<file>:<line>)`` line per frame, plus its source text when available."""
    for frame in stack:
        yield f" at {frame.name} ({frame.filename}:{frame.lineno})"
        source_text = frame.line
        if source_text:
            yield f" {source_text.strip()}"
|
||||
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from os import walk as imported_walk
|
||||
from pathlib import Path
|
||||
from time import sleep as imported_sleep
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from support.detectors.blocking_io import (
|
||||
BlockingCallSpec,
|
||||
BlockingIOProbe,
|
||||
detect_blocking_io,
|
||||
)
|
||||
|
||||
# Module-wide marker: every test in this module carries the asyncio mark.
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# One-element spec tuples so each test watches only the call it exercises.
TIME_SLEEP_ONLY = (
    BlockingCallSpec("time.sleep", "time:sleep"),
)
REQUESTS_ONLY = (
    BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
)
HTTPX_ONLY = (
    BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
)
# os.walk is recorded when the walker is iterated rather than when it is created.
OS_WALK_ONLY = (
    BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
)
PATH_READ_TEXT_ONLY = (
    BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
)
|
||||
|
||||
|
||||
async def test_records_time_sleep_on_event_loop() -> None:
    """A direct ``time.sleep`` call inside a coroutine is recorded exactly once."""
    with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
        time.sleep(0)

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["time.sleep"]
|
||||
|
||||
|
||||
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
    """A ``from time import sleep`` alias is detected and left as the same object afterwards."""
    alias_before = imported_sleep

    with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
        imported_sleep(0)

    assert imported_sleep is alias_before
    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["time.sleep"]
|
||||
|
||||
|
||||
async def test_can_disable_loaded_alias_patching() -> None:
    """With ``patch_loaded_aliases=False`` a call through the imported alias is not recorded."""
    with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
        imported_sleep(0)

    recorded = detector.violations
    assert recorded == []
|
||||
|
||||
|
||||
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
    """``time.sleep`` run through ``asyncio.to_thread`` produces no violation."""
    with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
        await asyncio.to_thread(time.sleep, 0)

    recorded = detector.violations
    assert recorded == []
|
||||
|
||||
|
||||
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
    """The ``blocking_io_detector`` fixture stays empty when the sleep runs in a worker thread."""
    await asyncio.to_thread(time.sleep, 0)

    recorded = blocking_io_detector.violations
    assert recorded == []
|
||||
|
||||
|
||||
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
    """A detector entered and exercised entirely inside a worker thread records nothing."""

    def sleep_in_worker() -> list[str]:
        with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
            time.sleep(0)
        return [call.name for call in detector.violations]

    recorded_names = await asyncio.to_thread(sleep_in_worker)
    assert recorded_names == []
|
||||
|
||||
|
||||
async def test_fail_on_exit_includes_call_site() -> None:
    """``fail_on_exit=True`` raises an ``AssertionError`` naming the call and this test function."""
    with pytest.raises(AssertionError) as exc_info:
        with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
            time.sleep(0)

    rendered = str(exc_info.value)
    for fragment in ("time.sleep", "test_fail_on_exit_includes_call_site"):
        assert fragment in rendered
|
||||
|
||||
|
||||
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
    """``requests.get`` is recorded as ``requests.Session.request``; the transport is stubbed out."""

    def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
        # Echo the call instead of touching the network.
        return f"{method}:{url}"

    monkeypatch.setattr(requests.sessions.Session, "request", fake_request)

    with detect_blocking_io(REQUESTS_ONLY) as detector:
        reply = requests.get("https://example.invalid")
        assert reply == "get:https://example.invalid"

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["requests.Session.request"]
|
||||
|
||||
|
||||
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
    """A sync ``httpx.Client`` GET is recorded as ``httpx.Client.request``; the transport is stubbed out."""

    def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
        # Return a canned 200 response instead of touching the network.
        canned_request = httpx.Request(method, url)
        return httpx.Response(200, request=canned_request)

    monkeypatch.setattr(httpx.Client, "request", fake_request)

    with detect_blocking_io(HTTPX_ONLY) as detector:
        with httpx.Client() as client:
            response = client.get("https://example.invalid")

    assert response.status_code == 200
    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["httpx.Client.request"]
|
||||
|
||||
|
||||
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
    """Iterating ``os.walk`` inside a coroutine is recorded once."""
    (tmp_path / "nested").mkdir()

    with detect_blocking_io(OS_WALK_ONLY) as detector:
        walked = list(os.walk(tmp_path))
        assert walked

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["os.walk"]
|
||||
|
||||
|
||||
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
    """A ``from os import walk`` alias is detected on iteration and is the same object afterwards."""
    (tmp_path / "nested").mkdir()
    alias_before = imported_walk

    with detect_blocking_io(OS_WALK_ONLY) as detector:
        walked = list(imported_walk(tmp_path))
        assert walked

    assert imported_walk is alias_before
    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["os.walk"]
|
||||
|
||||
|
||||
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
    """Creating the walker inside the detector but iterating it afterwards records nothing."""
    with detect_blocking_io(OS_WALK_ONLY) as detector:
        walker = os.walk(tmp_path)

    # Iteration happens only after the detector scope has been left.
    walked = list(walker)
    assert walked
    assert detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
    """A walker created on the loop but consumed via ``asyncio.to_thread`` records nothing."""
    (tmp_path / "nested").mkdir()

    with detect_blocking_io(OS_WALK_ONLY) as detector:
        walker = os.walk(tmp_path)
        walked = await asyncio.to_thread(lambda: list(walker))
        assert walked

    recorded = detector.violations
    assert recorded == []
|
||||
|
||||
|
||||
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
    """``Path.read_text`` inside a coroutine is recorded once."""
    data_file = tmp_path / "data.txt"
    data_file.write_text("content", encoding="utf-8")

    with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
        loaded = data_file.read_text(encoding="utf-8")
        assert loaded == "content"

    recorded_names = [call.name for call in detector.violations]
    assert recorded_names == ["pathlib.Path.read_text"]
|
||||
|
||||
|
||||
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
    """A probe fed one violation reports the totals and the violation name in its summary."""
    project_root = Path(__file__).resolve().parents[1]
    probe = BlockingIOProbe(project_root)
    data_file = tmp_path / "data.txt"
    data_file.write_text("content", encoding="utf-8")

    with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
        loaded = data_file.read_text(encoding="utf-8")
        assert loaded == "content"

    probe.record("tests/test_example.py::test_example", detector.violations)
    summary = probe.format_summary()

    for fragment in ("blocking io probe: 1 violations across 1 tests", "pathlib.Path.read_text"):
        assert fragment in summary
|
||||
|
||||
|
||||
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
    """An empty probe reports no violations, and ``clear()`` returns a used probe to that state."""
    empty_summary = "blocking io probe: no violations"
    project_root = Path(__file__).resolve().parents[1]
    probe = BlockingIOProbe(project_root)

    assert probe.format_summary() == empty_summary

    data_file = tmp_path / "data.txt"
    data_file.write_text("content", encoding="utf-8")
    with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
        loaded = data_file.read_text(encoding="utf-8")
        assert loaded == "content"

    probe.record("tests/test_example.py::test_example", detector.violations)
    assert probe.violation_count == 1

    probe.clear()

    assert probe.violation_count == 0
    assert probe.format_summary() == empty_summary
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
# Reference to time.sleep captured at import time, used by the tests in this module for identity checks.
ORIGINAL_SLEEP = time.sleep
|
||||
|
||||
|
||||
def replacement_sleep(seconds: float) -> None:
    """No-op stand-in for ``time.sleep``; the requested duration is ignored."""
    return
|
||||
|
||||
|
||||
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap ``time.sleep`` for ``replacement_sleep`` via monkeypatch and confirm the swap took effect."""
    monkeypatch.setattr(time, "sleep", replacement_sleep)

    patched = time.sleep
    assert patched is replacement_sleep
|
||||
|
||||
|
||||
@pytest.mark.no_blocking_io_probe
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
    """``time.sleep`` is the import-time original again and carries no ``__wrapped__`` attribute."""
    current_sleep = time.sleep
    assert current_sleep is ORIGINAL_SLEEP
    assert getattr(current_sleep, "__wrapped__", None) is None
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Real-LLM end-to-end verification for issue #2884.
|
||||
|
||||
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
|
||||
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
|
||||
and the production ``get_available_tools`` pipeline. The only thing we mock is
|
||||
the MCP tool source — we hand-roll two ``@tool``s and inject them through
|
||||
``deerflow.mcp.cache.get_cached_mcp_tools``.
|
||||
|
||||
The flow exercised:
|
||||
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
|
||||
that re-enters ``get_available_tools`` on the same task — this is the
|
||||
code path issue #2884 reports). It must call ``tool_search`` to
|
||||
discover the deferred ``fake_calculator`` tool.
|
||||
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
|
||||
``fake_subagent_trigger`` re-enters ``get_available_tools``.
|
||||
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
|
||||
so it can actually call it. Without this PR's fix, the re-entry wipes
|
||||
the promotion and the model can no longer invoke the tool.
|
||||
|
||||
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
|
||||
test run. Run with::
|
||||
|
||||
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
|
||||
PYTHONPATH=. uv run pytest \
|
||||
tests/test_deferred_tool_promotion_real_llm.py -v -s
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skip control: only run when explicitly opted in.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Skip the whole module unless ONEAPI_E2E is exactly "1".
pytestmark = pytest.mark.skipif(
    os.environ.get("ONEAPI_E2E") != "1",
    reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake "MCP" tools the agent should discover via tool_search.
|
||||
# Keep them obviously synthetic so the model can pattern-match the search.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Ordered log of tool invocations ("<tool name>:<arguments>") appended by the fake tools in this module.
_calls: list[str] = []
|
||||
|
||||
|
||||
@as_tool
def fake_calculator(expression: str) -> str:
    """Evaluate a tiny arithmetic expression like '2 + 2'.

    Reserved for the user — only call this if the user asks for arithmetic.
    """
    # NOTE(review): the docstring above is presumably what the tool decorator
    # exposes to the model as the tool description, so it is kept verbatim.
    #
    # The expression comes from the model, so it is parsed with ``ast`` and
    # evaluated by a small interpreter instead of ``eval``. The character
    # allowlist alone still admitted ``**`` and therefore inputs such as
    # ``9**9**9**9`` that would hang the test or exhaust memory.
    import ast
    import operator

    _calls.append(f"fake_calculator:{expression}")

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    max_power_bits = 4096  # upper bound on the size of an integer power result

    def evaluate(node: ast.AST) -> int | float | complex:
        """Evaluate a numeric literal, unary +/- or binary arithmetic node."""
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > max_power_bits
            ):
                # Only int ** positive-int can grow without bound; floats overflow on their own.
                raise ValueError("result too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("unsupported expression")

    try:
        allowed = set("0123456789+-*/() .")
        if not set(expression) <= allowed:
            return "expression contains disallowed characters"
        # ``eval`` tolerated surrounding blanks; ``ast.parse`` rejects leading ones.
        parsed = ast.parse(expression.strip(), mode="eval")
        return str(evaluate(parsed.body))
    except Exception as e:
        # Best effort: report any parse or arithmetic failure back as tool output.
        return f"error: {e}"
|
||||
|
||||
|
||||
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
    """Translate text into the given language code. Decorative — not used."""
    # Log the invocation, then return the text tagged with the language code.
    call_record = f"fake_translator:{text}:{target_lang}"
    _calls.append(call_record)
    tagged = f"[{target_lang}] {text}"
    return tagged
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline wiring (same shape as the in-process tests).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
    """Reset the deferred-tool registry before and after every test in this module."""
    from deerflow.tools.builtins.tool_search import reset_deferred_registry as clear_registry

    clear_registry()
    yield
    clear_registry()
|
||||
|
||||
|
||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
    """Patch the extensions config and MCP tool cache so ``mcp_tools`` look like MCP-provided tools.

    ``ExtensionsConfig.from_file`` is replaced with a classmethod returning a
    config holding one enabled stdio server, and
    ``deerflow.mcp.cache.get_cached_mcp_tools`` returns a fresh copy of
    ``mcp_tools`` on every call.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    fake_server = McpServerConfig(type="stdio", command="echo", enabled=True)
    real_ext = ExtensionsConfig(mcpServers={"fake-server": fake_server})

    def _from_file(cls):
        return real_ext

    def _cached_tools():
        return list(mcp_tools)

    monkeypatch.setattr(
        "deerflow.config.extensions_config.ExtensionsConfig.from_file",
        classmethod(_from_file),
    )
    monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", _cached_tools)
|
||||
|
||||
|
||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``deerflow.tools.tools.get_app_config`` to return a minimal config with tool search on.

    The config is assembled with ``model_construct`` rather than the real
    loader, which per the original author's note would trigger
    ``_apply_singleton_configs`` and permanently mutate cross-test singletons.
    """
    from deerflow.config.app_config import AppConfig
    from deerflow.config.tool_search_config import ToolSearchConfig

    sandbox_model = AppConfig.model_fields["sandbox"].annotation
    mock_cfg = AppConfig.model_construct(
        log_level="info",
        models=[],
        tools=[],
        tool_groups=[],
        sandbox=sandbox_model.model_construct(use="x"),
        tool_search=ToolSearchConfig(enabled=True),
    )

    def _get_app_config():
        return mock_cfg

    monkeypatch.setattr("deerflow.tools.tools.get_app_config", _get_app_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real-LLM e2e test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
    """Run the promote-then-invoke flow against a live OpenAI-compatible model.

    Expected conversation:
      * Turn 1 — the model is instructed to call ``tool_search`` (selecting
        ``fake_calculator``) and ``fake_subagent_trigger`` in a single batch.
      * Turn 2 — the model calls the now-promoted ``fake_calculator``.

    The test passes when a ``fake_calculator`` call appears in ``_calls`` and
    the final message contains ``391``.
    """
    from langchain.agents import create_agent
    from langchain_openai import ChatOpenAI

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
    _force_tool_search_enabled(monkeypatch)
    _calls.clear()

    @as_tool
    async def fake_subagent_trigger(prompt: str) -> str:
        """Pretend to spawn a subagent. Internally rebuilds the toolset.

        Use this whenever the user asks you to delegate work — pass a short
        description as ``prompt``.
        """
        # Mirrors what ``task_tool`` does internally. Whether a registry reset
        # here would leak back to the parent task depends on asyncio's
        # context-copying semantics; the fix under test is what guarantees the
        # promotion survives any re-entrant ``get_available_tools`` call.
        get_available_tools(subagent_enabled=False)
        _calls.append(f"fake_subagent_trigger:{prompt}")
        return "subagent completed"

    tools = [*get_available_tools(), fake_subagent_trigger]

    llm_settings = {
        "model": os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
        "api_key": os.environ["OPENAI_API_KEY"],
        "base_url": os.environ["OPENAI_API_BASE"],
        "temperature": 0,
        "max_retries": 1,
    }
    model = ChatOpenAI(**llm_settings)

    system_prompt = "".join(
        [
            "You are a meticulous assistant. Available deferred tools include a ",
            "calculator and a translator — their schemas are hidden until you ",
            "search for them via tool_search.\n\n",
            "Procedure for the user's request:\n",
            " 1. Call tool_search with query 'select:fake_calculator' AND ",
            "in the SAME tool batch also call fake_subagent_trigger(prompt='go') ",
            "to delegate the side work. Put both tool_calls in your first response.\n",
            " 2. After both tool messages come back, call fake_calculator with ",
            "the user's expression.\n",
            " 3. Reply with just the numeric result.",
        ]
    )

    graph = create_agent(
        model=model,
        tools=tools,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt=system_prompt,
    )

    question = HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")
    result = await graph.ainvoke({"messages": [question]}, config={"recursion_limit": 12})

    print("\n=== tool calls recorded ===")
    for recorded_call in _calls:
        print(f" {recorded_call}")
    print("\n=== final message ===")
    messages = result["messages"]
    final_text = messages[-1].content if messages else "(none)"
    print(f" {final_text!r}")

    # Decisive check: fake_calculator ran at the tool layer, which requires the
    # promoted schema to have reached the model on turn 2 despite the re-entry.
    calc_calls = list(filter(lambda entry: entry.startswith("fake_calculator:"), _calls))
    assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"

    # Sanity check that the model surfaced the computed result.
    assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
|
||||
@@ -0,0 +1,390 @@
|
||||
"""Reproduce + regression-guard issue #2884.
|
||||
|
||||
Hypothesis from the issue:
|
||||
``tools.tools.get_available_tools`` unconditionally calls
|
||||
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
|
||||
every time it is invoked. If anything calls ``get_available_tools`` again
|
||||
during the same async context (after the agent has promoted tools via
|
||||
``tool_search``), the promotion is wiped and the next model call hides the
|
||||
tool's schema again.
|
||||
|
||||
These tests pin two things:
|
||||
|
||||
A. **At the unit boundary** — verify the failure mode directly. Promote a
|
||||
tool in the registry, then call ``get_available_tools`` again and observe
|
||||
that the ContextVar registry is reset and the promotion is lost.
|
||||
|
||||
B. **At the graph-execution boundary** — drive a real ``create_agent`` graph
|
||||
with the real ``DeferredToolFilterMiddleware`` through two model turns.
|
||||
The first turn calls ``tool_search`` which promotes a tool. The second
|
||||
turn must see that tool's schema in ``request.tools``. If
|
||||
``get_available_tools`` were to run again between the two turns and reset
|
||||
the registry, the second turn's filter would strip the tool.
|
||||
|
||||
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
|
||||
unmodified; mock only the LLM and the MCP tool source. Patch
|
||||
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
|
||||
``get_available_tools`` resolves via lazy import) to return our fixture
|
||||
tools so we don't need a real MCP server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
|
||||
class FakeToolCallingModel(FakeMessagesListChatModel):
    """Scripted chat model whose ``bind_tools`` is a no-op, so ``create_agent`` accepts it."""

    def bind_tools(  # type: ignore[override]
        self, tools: Any, *, tool_choice: Any = None, **kwargs: Any
    ) -> Runnable:
        # The tools are ignored; the model itself is returned unchanged.
        return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@as_tool
def fake_mcp_search(query: str) -> str:
    """Pretend to search a knowledge base for the given query."""
    canned_reply = f"results for {query}"
    return canned_reply
|
||||
|
||||
|
||||
@as_tool
def fake_mcp_fetch(url: str) -> str:
    """Pretend to fetch a page at the given URL."""
    canned_page = f"content of {url}"
    return canned_page
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
    """Set placeholder OpenAI credentials, which the original note says config.yaml needs at parse time."""
    placeholders = {
        "OPENAI_API_KEY": "sk-fake-not-used",
        "OPENAI_API_BASE": "https://example.invalid",
    }
    for variable, value in placeholders.items():
        monkeypatch.setenv(variable, value)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
    """Reset the deferred-tool registry around every test.

    Per the original author's note, the registry lives in a module-level
    ContextVar, so a promotion made by one test could otherwise leak into the
    next and silently break its filter assertions.
    """
    from deerflow.tools.builtins.tool_search import reset_deferred_registry as clear_registry

    clear_registry()
    yield
    clear_registry()
|
||||
|
||||
|
||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
    """Make ``get_available_tools`` see one enabled MCP server backed by ``mcp_tools``.

    ``ExtensionsConfig.from_file`` is swapped for a classmethod that returns a
    real ``ExtensionsConfig`` with a single enabled stdio server. The MCP tool
    cache accessor is then replaced so it hands out a fresh copy of
    ``mcp_tools`` on every call.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    servers = {"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}
    real_ext = ExtensionsConfig(mcpServers=servers)

    replacements = (
        ("deerflow.config.extensions_config.ExtensionsConfig.from_file", classmethod(lambda cls: real_ext)),
        ("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)
|
||||
|
||||
|
||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Make ``deerflow.tools.tools.get_app_config`` return a config with tool search enabled.

    A minimal ``AppConfig`` is assembled with ``model_construct`` and the real
    loader is never called. Per the original author's note, the real
    ``get_app_config()`` triggers ``_apply_singleton_configs``, which mutates
    module-level singletons to match the developer's ``config.yaml`` and would
    leak into later tests that rely on those singletons' defaults.
    """
    from deerflow.config.app_config import AppConfig
    from deerflow.config.tool_search_config import ToolSearchConfig

    sandbox_cls = AppConfig.model_fields["sandbox"].annotation
    construct_kwargs = {
        "log_level": "info",
        "models": [],
        "tools": [],
        "tool_groups": [],
        "sandbox": sandbox_cls.model_construct(use="x"),
        "tool_search": ToolSearchConfig(enabled=True),
    }
    mock_cfg = AppConfig.model_construct(**construct_kwargs)
    monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section A — direct unit-level reproduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
    """A second ``get_available_tools()`` call must not undo an earlier promotion.

    Sequence:
      1. ``get_available_tools()`` registers both fake MCP tools as deferred.
      2. One tool is promoted directly on the registry (standing in for
         ``tool_search``).
      3. ``get_available_tools()`` runs again, as ``task_tool`` does mid-run.

    After step 3 the promoted tool must still be absent from the deferred
    entries; before the fix the registry was rebuilt and it reappeared.
    """
    from deerflow.tools.builtins.tool_search import get_deferred_registry
    from deerflow.tools.tools import get_available_tools

    def deferred_names(registry) -> set:
        return {entry.name for entry in registry.entries}

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    # Step 1: both MCP tools start out deferred.
    get_available_tools()
    reg1 = get_deferred_registry()
    assert reg1 is not None
    assert deferred_names(reg1) == {"fake_mcp_search", "fake_mcp_fetch"}

    # Step 2: promote one of them, as tool_search would.
    reg1.promote({"fake_mcp_search"})
    assert deferred_names(reg1) == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"

    # Step 3: the re-entrant call must leave the promotion in place.
    get_available_tools()
    reg2 = get_deferred_registry()
    assert reg2 is not None
    deferred_after = deferred_names(reg2)
    assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section B — graph-execution reproduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ToolSearchPromotingModel(FakeToolCallingModel):
    """Scripted model that also records which tools were bound on each turn.

    The intended script is: turn 1 calls ``tool_search``, turn 2 calls the
    promoted ``fake_mcp_search``. Every ``bind_tools`` call appends the list of
    tool names it received to ``bound_tools_per_turn`` so tests can inspect
    what reached the model.
    """

    bound_tools_per_turn: list[list[str]] = []

    def bind_tools(  # type: ignore[override]
        self, tools: Any, *, tool_choice: Any = None, **kwargs: Any
    ) -> Runnable:
        # Prefer ``.name``, then ``__name__``, then the repr of the tool.
        seen_names = []
        for candidate in tools:
            fallback = getattr(candidate, "__name__", repr(candidate))
            seen_names.append(getattr(candidate, "name", fallback))
        self.bound_tools_per_turn.append(seen_names)
        return self
|
||||
|
||||
|
||||
def _build_promoting_model() -> _ToolSearchPromotingModel:
    """Return a model scripted to call ``tool_search``, then ``fake_mcp_search``, then finish."""

    def _tool_call_turn(tool_name: str, tool_args: dict, call_id: str) -> AIMessage:
        # One assistant turn with empty content and a single tool call.
        return AIMessage(
            content="",
            tool_calls=[{"name": tool_name, "args": tool_args, "id": call_id, "type": "tool_call"}],
        )

    scripted_turns = [
        _tool_call_turn("tool_search", {"query": "select:fake_mcp_search"}, "call_search_1"),
        _tool_call_turn("fake_mcp_search", {"query": "hello"}, "call_mcp_1"),
        AIMessage(content="all done"),
    ]
    return _ToolSearchPromotingModel(responses=scripted_turns)
|
||||
|
||||
|
||||
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
    """Drive a real ``create_agent`` graph and check what each turn's ``bind_tools`` received.

    Turn 1 must hide ``fake_mcp_search`` (still deferred) but expose
    ``tool_search``. After ``tool_search`` promotes it, turn 2 must include
    ``fake_mcp_search`` among the bound tools.
    """
    from langchain.agents import create_agent

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    tools = get_available_tools()
    # Sanity: the assembled list carries the deferred tools too; hiding them
    # from the model is the middleware's job.
    tool_names = {getattr(candidate, "name", "") for candidate in tools}
    assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"}.issubset(tool_names)

    model = _build_promoting_model()
    model.bound_tools_per_turn = []  # start from an empty recorder

    graph = create_agent(
        model=model,
        tools=tools,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt="bug-2884-repro",
    )
    request = {"messages": [HumanMessage(content="use the search tool")]}
    graph.invoke(request)

    # Turn 1: the deferred tool is hidden, tool_search is visible.
    turn1 = set(model.bound_tools_per_turn[0])
    assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
    assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"

    # Turn 2: the promoted tool must now be visible. This is the key
    # assertion for issue #2884.
    assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
    turn2 = set(model.bound_tools_per_turn[1])
    assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section C — the actual issue #2884 trigger: a re-entrant
|
||||
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
|
||||
# wipe the parent's promotion.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
    """Regression test for issue #2884: re-entrant toolset builds keep promotions.

    The first model turn issues two tool calls in one batch: ``tool_search``
    (which promotes ``fake_mcp_search``) and ``fake_subagent_trigger`` (which
    calls ``get_available_tools`` again, mimicking what ``task_tool`` does
    when it assembles a subagent's toolset mid-run). The test passes only if
    the tools bound for the second model turn still include the promoted
    ``fake_mcp_search``; if the re-entrant call wiped the registry, the
    deferred-tool filter would hide it again.
    """
    from langchain.agents import create_agent

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    # Stand-in for task_tool: rebuilds the toolset via get_available_tools
    # while the deferred-tool registry of the parent agent is still live.
    @as_tool
    def fake_subagent_trigger(prompt: str) -> str:
        """Pretend to spawn a subagent. Internally rebuilds the toolset."""
        get_available_tools(subagent_enabled=False)
        return f"spawned subagent for: {prompt}"

    tools = get_available_tools() + [fake_subagent_trigger]

    # One entry per bind_tools call: the tool names the model was offered.
    recorded_bindings: list[list[str]] = []

    class _Model(FakeToolCallingModel):
        def bind_tools(self, tools_arg, **kwargs):  # type: ignore[override]
            recorded_bindings.append([getattr(t, "name", repr(t)) for t in tools_arg])
            return self

    def _tool_call(name: str, args: dict, call_id: str) -> dict:
        # Builds one scripted tool-call entry for an AIMessage.
        return {"name": name, "args": args, "id": call_id, "type": "tool_call"}

    scripted_responses = [
        # Turn 1: promote AND trigger the subagent-style rebuild in a single
        # batch, so both tools run within the same agent step.
        AIMessage(
            content="",
            tool_calls=[
                _tool_call("tool_search", {"query": "select:fake_mcp_search"}, "call_search_1"),
                _tool_call("fake_subagent_trigger", {"prompt": "go"}, "call_trigger_1"),
            ],
        ),
        # Turn 2: try to invoke the tool that turn 1 promoted.
        AIMessage(
            content="",
            tool_calls=[_tool_call("fake_mcp_search", {"query": "hello"}, "call_mcp_1")],
        ),
        AIMessage(content="all done"),
    ]
    model = _Model(responses=scripted_responses)

    graph = create_agent(
        model=model,
        tools=tools,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt="bug-2884-subagent-repro",
    )
    graph.invoke({"messages": [HumanMessage(content="use the search tool")]})

    # Check the turn count before indexing, so a run that stopped early fails
    # with this message rather than an IndexError on the lookups below.
    assert len(recorded_bindings) >= 2, f"Expected ≥2 turns, got {len(recorded_bindings)}"

    # Turn 1 sanity: the deferred tool is not visible yet.
    first_turn_tools = recorded_bindings[0]
    assert "fake_mcp_search" not in set(first_turn_tools), first_turn_tools

    # The load-bearing assertion: turn 2 sees the promoted tool DESPITE the
    # re-entrant get_available_tools call made during turn 1's tool batch.
    second_turn_tools = set(recorded_bindings[1])
    assert "fake_mcp_search" in second_turn_tools, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {second_turn_tools!r}"
|
||||
@@ -3,6 +3,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
|
||||
def test_conversation_context_has_user_id():
|
||||
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
|
||||
|
||||
def test_queue_add_stores_user_id():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
assert len(q._queue) == 1
|
||||
assert q._queue[0].user_id == "alice"
|
||||
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
|
||||
|
||||
def test_queue_process_passes_user_id_to_updater():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
|
||||
mock_updater = MagicMock()
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
|
||||
@@ -278,3 +281,48 @@ class TestRunRepository:
|
||||
assert row4["model_name"] is None
|
||||
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
    """The aggregate query selects and groups by the same model-name expression.

    A recording fake session captures the single statement issued by
    ``aggregate_tokens_by_thread``. The statement is compiled for the
    PostgreSQL dialect and the ``coalesce(runs.model_name, ...)`` expression
    in the SELECT list must use the same bind parameter as the GROUP BY.
    """
    executed_statements = []

    class _EmptyResult:
        def all(self):
            return []

    class _RecordingSession:
        async def execute(self, stmt):
            executed_statements.append(stmt)
            return _EmptyResult()

    class _SessionContext:
        async def __aenter__(self):
            return _RecordingSession()

        async def __aexit__(self, exc_type, exc, tb):
            return None

    def _session_factory():
        return _SessionContext()

    repo = RunRepository(_session_factory)

    # With no rows returned, every aggregate is expected to be zero/empty.
    expected_totals = {
        "total_tokens": 0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_runs": 0,
        "by_model": {},
        "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
    }
    totals = await repo.aggregate_tokens_by_thread("t1")
    assert totals == expected_totals
    assert len(executed_statements) == 1

    rendered_sql = str(executed_statements[0].compile(dialect=postgresql.dialect()))
    select_sql, group_by_sql = rendered_sql.split(" GROUP BY ", maxsplit=1)

    # Group 1 captures the bind-parameter name used as the coalesce fallback.
    model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
    in_select = re.search(model_expr_pattern + r" AS model", select_sql)
    in_group_by = re.fullmatch(model_expr_pattern, group_by_sql.strip())

    assert in_select is not None
    assert in_group_by is not None
    assert in_select.group(1) == in_group_by.group(1)
|
||||
|
||||
@@ -59,12 +59,15 @@ def _make_result(
|
||||
ai_messages: list[dict] | None = None,
|
||||
result: str | None = None,
|
||||
error: str | None = None,
|
||||
token_usage_records: list[dict] | None = None,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
status=status,
|
||||
ai_messages=ai_messages or [],
|
||||
result=result,
|
||||
error=error,
|
||||
token_usage_records=token_usage_records or [],
|
||||
usage_reported=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
|
||||
assert len(report_calls) == 1
|
||||
assert report_calls[0][1] is cancel_result
|
||||
assert cleanup_calls == ["tc-cancel-report"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "status, expected_type",
    [
        (FakeSubagentStatus.COMPLETED, "task_completed"),
        (FakeSubagentStatus.FAILED, "task_failed"),
        (FakeSubagentStatus.CANCELLED, "task_cancelled"),
        (FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
    ],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
    """Each terminal task event carries the summed usage of its token records."""
    subagent_config = _make_subagent_config()
    runtime = _make_runtime()
    emitted = []

    usage_records = [
        {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
        {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
    ]
    # Only a completed task carries a result; every other status carries an error.
    is_completed = status == FakeSubagentStatus.COMPLETED
    task_result = _make_result(
        status,
        result="ok" if is_completed else None,
        error=None if is_completed else "err",
        token_usage_records=usage_records,
    )

    module_overrides = {
        "SubagentStatus": FakeSubagentStatus,
        "get_subagent_config": lambda _: subagent_config,
        "get_background_task_result": lambda _: task_result,
        "get_stream_writer": lambda: emitted.append,
        "_report_subagent_usage": lambda *_: None,
        "cleanup_background_task": lambda _: None,
    }
    for attr_name, replacement in module_overrides.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    _run_task_tool(
        runtime=runtime,
        description="test",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-usage",
    )

    matching = [event for event in emitted if event["type"] == expected_type]
    assert len(matching) == 1
    # 100 + 200 input, 50 + 80 output, 150 + 280 total.
    expected_usage = {
        "input_tokens": 300,
        "output_tokens": 130,
        "total_tokens": 430,
    }
    assert matching[0]["usage"] == expected_usage
|
||||
|
||||
|
||||
def test_terminal_event_usage_none_when_no_records(monkeypatch):
    """A completed task with no token usage records emits ``usage=None``."""
    subagent_config = _make_subagent_config()
    runtime = _make_runtime()
    emitted = []

    finished = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])

    module_overrides = {
        "SubagentStatus": FakeSubagentStatus,
        "get_subagent_config": lambda _: subagent_config,
        "get_background_task_result": lambda _: finished,
        "get_stream_writer": lambda: emitted.append,
        "_report_subagent_usage": lambda *_: None,
        "cleanup_background_task": lambda _: None,
    }
    for attr_name, replacement in module_overrides.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    _run_task_tool(
        runtime=runtime,
        description="test",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-no-records",
    )

    completed_events = [event for event in emitted if event["type"] == "task_completed"]
    assert len(completed_events) == 1
    assert completed_events[0]["usage"] is None
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
    """``_token_usage_cache_enabled(None)`` is False when loading the app config raises FileNotFoundError."""
    failing_loader = MagicMock(side_effect=FileNotFoundError("missing config"))
    monkeypatch.setattr(task_tool_module, "get_app_config", failing_loader)

    cache_enabled = task_tool_module._token_usage_cache_enabled(None)

    assert cache_enabled is False
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
    """No usage is cached for the tool call when ``token_usage.enabled`` is False."""
    subagent_config = _make_subagent_config()
    disabled_app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
    runtime = _make_runtime(app_config=disabled_app_config)
    usage_records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
    finished = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=usage_records)

    class DummyExecutor:
        # Accepts any constructor kwargs and hands back the given task id.
        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    # Start from an empty cache so the final lookup reflects this run only.
    task_tool_module._subagent_usage_cache.clear()

    module_overrides = {
        "SubagentStatus": FakeSubagentStatus,
        "get_available_subagent_names": lambda *, app_config: ["general-purpose"],
        "get_subagent_config": lambda _, *, app_config: subagent_config,
        "SubagentExecutor": DummyExecutor,
        "get_background_task_result": lambda _: finished,
        "get_stream_writer": lambda: lambda _: None,
        "_report_subagent_usage": lambda *_: None,
        "cleanup_background_task": lambda _: None,
    }
    for attr_name, replacement in module_overrides.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    _run_task_tool(
        runtime=runtime,
        description="test",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-disabled-cache",
    )

    cached = task_tool_module.pop_cached_subagent_usage("tc-disabled-cache")
    assert cached is None
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
    """A cache entry for the tool call does not survive a polling failure.

    The cache is pre-seeded for ``tc-error``, polling the background task
    raises ``RuntimeError``, and afterwards no cached usage may remain for
    that tool call id.
    """
    subagent_config = _make_subagent_config()
    enabled_app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
    runtime = _make_runtime(app_config=enabled_app_config)

    class DummyExecutor:
        # Accepts any constructor kwargs and hands back the given task id.
        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    # Seed an entry that the error path is expected to remove.
    seeded_usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
    task_tool_module._subagent_usage_cache["tc-error"] = seeded_usage

    failing_poll = MagicMock(side_effect=RuntimeError("poll failed"))
    module_overrides = {
        "SubagentStatus": FakeSubagentStatus,
        "get_available_subagent_names": lambda *, app_config: ["general-purpose"],
        "get_subagent_config": lambda _, *, app_config: subagent_config,
        "SubagentExecutor": DummyExecutor,
        "get_background_task_result": failing_poll,
        "get_stream_writer": lambda: lambda _: None,
    }
    for attr_name, replacement in module_overrides.items():
        monkeypatch.setattr(task_tool_module, attr_name, replacement)
    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))

    with pytest.raises(RuntimeError, match="poll failed"):
        _run_task_tool(
            runtime=runtime,
            description="test",
            prompt="do work",
            subagent_type="general-purpose",
            tool_call_id="tc-error",
        )

    leftover = task_tool_module.pop_cached_subagent_usage("tc-error")
    assert leftover is None
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.middlewares.token_usage_middleware import (
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY,
|
||||
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
|
||||
"tool_call_id": "write_todos:remove",
|
||||
}
|
||||
]
|
||||
|
||||
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
    """Cached subagent usage lands on the dispatching AIMessage even without ids.

    None of the AIMessages carries an id. Usage is cached only for the two
    ``task`` calls of the second dispatch, so exactly one updated message is
    expected: the second dispatch, carrying the sum of both cached entries.
    """
    middleware = TokenUsageMiddleware()

    first_dispatch = AIMessage(
        content="",
        tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
    )
    second_dispatch = AIMessage(
        content="",
        tool_calls=[
            {"id": "task:second-a", "name": "task", "args": {}},
            {"id": "task:second-b", "name": "task", "args": {}},
        ],
    )
    history = [
        first_dispatch,
        ToolMessage(content="first", tool_call_id="task:first"),
        second_dispatch,
        ToolMessage(content="second-a", tool_call_id="task:second-a"),
        ToolMessage(content="second-b", tool_call_id="task:second-b"),
        AIMessage(content="done"),
    ]

    # Only the second dispatch has cached usage; "task:first" has none.
    pending_usage = {
        "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
        "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
    }

    def _pop_from_pending(tool_call_id):
        return pending_usage.pop(tool_call_id, None)

    task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
    monkeypatch.setattr(task_tool_module, "pop_cached_subagent_usage", _pop_from_pending)

    result = middleware.after_model({"messages": history}, _make_runtime())

    assert result is not None
    with_usage = [msg for msg in result["messages"] if getattr(msg, "usage_metadata", None)]
    assert len(with_usage) == 1

    merged = with_usage[0]
    assert merged.tool_calls == second_dispatch.tool_calls
    # 10 + 20 input, 5 + 7 output, 15 + 27 total.
    assert merged.usage_metadata == {
        "input_tokens": 30,
        "output_tokens": 12,
        "total_tokens": 42,
    }
|
||||
|
||||
@@ -65,8 +65,7 @@ def _make_minimal_config(tools):
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
|
||||
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||
"""Config-loaded async-only tools can still be invoked by sync clients."""
|
||||
|
||||
async def async_tool_impl(x: int) -> str:
|
||||
@@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash,
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
|
||||
def test_no_duplicates_returned(mock_bash, mock_cfg):
|
||||
"""get_available_tools() never returns two tools with the same name."""
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
|
||||
@@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||
def test_first_occurrence_wins(mock_bash, mock_cfg):
|
||||
"""When duplicates exist, the first occurrence is kept."""
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
|
||||
@@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
|
||||
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
|
||||
"""A warning is logged for every skipped duplicate."""
|
||||
import logging
|
||||
|
||||
|
||||
Generated
+3
-3
@@ -2005,7 +2005,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.7.36"
|
||||
version = "0.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
@@ -2018,9 +2018,9 @@ dependencies = [
|
||||
{ name = "xxhash" },
|
||||
{ name = "zstandard" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
||||
@@ -10,7 +10,6 @@ import { FlickeringGrid } from "@/components/ui/flickering-grid";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { useAuth } from "@/core/auth/AuthProvider";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
/**
|
||||
* Validate next parameter
|
||||
@@ -72,7 +71,7 @@ export default function LoginPage() {
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
|
||||
void fetch("/api/v1/auth/setup-status")
|
||||
.then((r) => r.json())
|
||||
.then((data: { needs_setup?: boolean }) => {
|
||||
if (!cancelled && data.needs_setup) {
|
||||
@@ -95,8 +94,8 @@ export default function LoginPage() {
|
||||
|
||||
try {
|
||||
const endpoint = isLogin
|
||||
? `${getBackendBaseURL()}/api/v1/auth/login/local`
|
||||
: `${getBackendBaseURL()}/api/v1/auth/register`;
|
||||
? "/api/v1/auth/login/local"
|
||||
: "/api/v1/auth/register";
|
||||
const body = isLogin
|
||||
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
|
||||
: JSON.stringify({ email, password });
|
||||
|
||||
@@ -10,7 +10,6 @@ import { Input } from "@/components/ui/input";
|
||||
import { getCsrfHeaders } from "@/core/api/fetcher";
|
||||
import { useAuth } from "@/core/auth/AuthProvider";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
type SetupMode = "loading" | "init_admin" | "change_password";
|
||||
|
||||
@@ -37,7 +36,7 @@ export default function SetupPage() {
|
||||
setMode("change_password");
|
||||
} else if (!isAuthenticated) {
|
||||
// Check if the system has no users yet
|
||||
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
|
||||
void fetch("/api/v1/auth/setup-status")
|
||||
.then((r) => r.json())
|
||||
.then((data: { needs_setup?: boolean }) => {
|
||||
if (cancelled) return;
|
||||
@@ -73,7 +72,7 @@ export default function SetupPage() {
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/initialize`, {
|
||||
const res = await fetch("/api/v1/auth/initialize", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
credentials: "include",
|
||||
@@ -114,22 +113,19 @@ export default function SetupPage() {
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${getBackendBaseURL()}/api/v1/auth/change-password`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...getCsrfHeaders(),
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify({
|
||||
current_password: currentPassword,
|
||||
new_password: newPassword,
|
||||
new_email: email || undefined,
|
||||
}),
|
||||
const res = await fetch("/api/v1/auth/change-password", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...getCsrfHeaders(),
|
||||
},
|
||||
);
|
||||
credentials: "include",
|
||||
body: JSON.stringify({
|
||||
current_password: currentPassword,
|
||||
new_password: newPassword,
|
||||
new_email: email || undefined,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
|
||||
@@ -4,7 +4,6 @@ import { redirect } from "next/navigation";
|
||||
import { AuthProvider } from "@/core/auth/AuthProvider";
|
||||
import { getServerSideUser } from "@/core/auth/server";
|
||||
import { assertNever } from "@/core/auth/types";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import { WorkspaceContent } from "./workspace-content";
|
||||
|
||||
@@ -45,7 +44,7 @@ export default async function WorkspaceLayout({
|
||||
Retry
|
||||
</Link>
|
||||
<Link
|
||||
href={`${getBackendBaseURL()}/api/v1/auth/logout`}
|
||||
href="/api/v1/auth/logout"
|
||||
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
|
||||
>
|
||||
Logout & Reset
|
||||
|
||||
@@ -12,13 +12,11 @@ function TokenUsageSummary({
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
totalTokens,
|
||||
unavailable = false,
|
||||
}: {
|
||||
className?: string;
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
totalTokens?: number;
|
||||
unavailable?: boolean;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
|
||||
@@ -33,21 +31,15 @@ function TokenUsageSummary({
|
||||
<CoinsIcon className="size-3" />
|
||||
{t.tokenUsage.label}
|
||||
</span>
|
||||
{!unavailable ? (
|
||||
<>
|
||||
<span>
|
||||
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
|
||||
</span>
|
||||
<span>
|
||||
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
|
||||
</span>
|
||||
<span className="font-medium">
|
||||
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
|
||||
</span>
|
||||
</>
|
||||
) : (
|
||||
<span>{t.tokenUsage.unavailableShort}</span>
|
||||
)}
|
||||
<span>
|
||||
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
|
||||
</span>
|
||||
<span>
|
||||
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
|
||||
</span>
|
||||
<span className="font-medium">
|
||||
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -55,7 +47,7 @@ function TokenUsageSummary({
|
||||
export function MessageTokenUsageList({
|
||||
className,
|
||||
enabled = false,
|
||||
isLoading = false,
|
||||
isLoading: _isLoading = false,
|
||||
messages,
|
||||
}: {
|
||||
className?: string;
|
||||
@@ -63,7 +55,7 @@ export function MessageTokenUsageList({
|
||||
isLoading?: boolean;
|
||||
messages: Message[];
|
||||
}) {
|
||||
if (!enabled || isLoading) {
|
||||
if (!enabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -75,13 +67,16 @@ export function MessageTokenUsageList({
|
||||
|
||||
const usage = accumulateUsage(aiMessages);
|
||||
|
||||
if (!usage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<TokenUsageSummary
|
||||
className={className}
|
||||
inputTokens={usage?.inputTokens}
|
||||
outputTokens={usage?.outputTokens}
|
||||
totalTokens={usage?.totalTokens}
|
||||
unavailable={!usage}
|
||||
inputTokens={usage.inputTokens}
|
||||
outputTokens={usage.outputTokens}
|
||||
totalTokens={usage.totalTokens}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import { Input } from "@/components/ui/input";
|
||||
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
|
||||
import { useAuth } from "@/core/auth/AuthProvider";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
|
||||
import { SettingsSection } from "./settings-section";
|
||||
@@ -39,20 +38,17 @@ export function AccountSettingsPage() {
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${getBackendBaseURL()}/api/v1/auth/change-password`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...getCsrfHeaders(),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
current_password: currentPassword,
|
||||
new_password: newPassword,
|
||||
}),
|
||||
const res = await fetch("/api/v1/auth/change-password", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...getCsrfHeaders(),
|
||||
},
|
||||
);
|
||||
body: JSON.stringify({
|
||||
current_password: currentPassword,
|
||||
new_password: newPassword,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
|
||||
@@ -10,8 +10,6 @@ import React, {
|
||||
type ReactNode,
|
||||
} from "react";
|
||||
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import { type User, buildLoginUrl } from "./types";
|
||||
|
||||
// Re-export for consumers
|
||||
@@ -58,7 +56,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
|
||||
const refreshUser = useCallback(async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/me`, {
|
||||
const res = await fetch("/api/v1/auth/me", {
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
@@ -90,7 +88,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
|
||||
setUser(null);
|
||||
|
||||
try {
|
||||
await fetch(`${getBackendBaseURL()}/api/v1/auth/logout`, {
|
||||
await fetch("/api/v1/auth/logout", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
|
||||
return hasUsage ? cumulative : null;
|
||||
}
|
||||
|
||||
function hasNonZeroUsage(
|
||||
export function hasNonZeroUsage(
|
||||
usage: TokenUsage | null | undefined,
|
||||
): usage is TokenUsage {
|
||||
return (
|
||||
@@ -75,7 +75,7 @@ function hasNonZeroUsage(
|
||||
);
|
||||
}
|
||||
|
||||
function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
|
||||
export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
|
||||
return {
|
||||
inputTokens: base.inputTokens + delta.inputTokens,
|
||||
outputTokens: base.outputTokens + delta.outputTokens,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
|
||||
/**
 * Deduplicate incoming messages against an existing history.
 *
 * An incoming message is dropped when its `id`, or (for tool messages) its
 * `tool_call_id`, already appears among the identifiers of the existing
 * messages. Both identifiers of every existing message are recorded, so a
 * tool message whose `tool_call_id` is empty is still matched by its `id`.
 *
 * @param existing Messages already present in the history.
 * @param incoming Candidate messages to append.
 * @returns The incoming messages not already present, in their original order.
 */
export function deduplicateHistoryMessages(
  existing: Message[],
  incoming: Message[],
): Message[] {
  const knownIds = new Set<string>();
  for (const message of existing) {
    // Record the message id for every message, tool messages included, so
    // id-based matching works regardless of message type.
    if (message.id) {
      knownIds.add(message.id);
    }
    if ("tool_call_id" in message && message.tool_call_id) {
      knownIds.add(message.tool_call_id);
    }
  }

  return incoming.filter((message) => {
    if (message.id && knownIds.has(message.id)) {
      return false;
    }
    if (
      "tool_call_id" in message &&
      message.tool_call_id &&
      knownIds.has(message.tool_call_id)
    ) {
      return false;
    }
    return true;
  });
}
|
||||
|
||||
/**
 * Recompute the history-loading index after the runs list changes length.
 *
 * - When the list did not grow, the index is returned unchanged.
 * - A negative `currentIndex` (everything known so far was loaded) is reset
 *   to the last run, so scrolling up can load the newly added runs.
 * - A non-negative `currentIndex` (some runs still unloaded) is shifted by
 *   the number of runs that were added.
 *
 * @param currentIndex Index of the next run to load, or negative if none remain.
 * @param prevRunsLength Length of the runs list before the update.
 * @param newRunsLength Length of the runs list after the update.
 * @returns The adjusted history-loading index.
 */
export function adjustHistoryIndex(
  currentIndex: number,
  prevRunsLength: number,
  newRunsLength: number,
): number {
  const grewBy = newRunsLength - prevRunsLength;
  if (grewBy <= 0) {
    return currentIndex;
  }
  return currentIndex < 0 ? newRunsLength - 1 : currentIndex + grewBy;
}
|
||||
@@ -18,6 +18,10 @@ import type { UploadedFileInfo } from "../uploads";
|
||||
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
||||
|
||||
import { fetchThreadTokenUsage } from "./api";
|
||||
import {
|
||||
adjustHistoryIndex,
|
||||
deduplicateHistoryMessages,
|
||||
} from "./history-utils";
|
||||
import { threadTokenUsageQueryKey } from "./token-usage";
|
||||
import type {
|
||||
AgentThread,
|
||||
@@ -296,7 +300,11 @@ export function useThreadStream({
|
||||
onError(error) {
|
||||
setOptimisticMessages([]);
|
||||
toast.error(getStreamErrorMessage(error));
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
||||
messagesRef.current
|
||||
.map(messageIdentity)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
if (threadIdRef.current && !isMock) {
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||
@@ -305,9 +313,16 @@ export function useThreadStream({
|
||||
},
|
||||
onFinish(state) {
|
||||
listeners.current.onFinish?.(state.values);
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
||||
messagesRef.current
|
||||
.map(messageIdentity)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||
if (threadIdRef.current && !isMock) {
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: ["thread", threadIdRef.current],
|
||||
});
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||
});
|
||||
@@ -339,7 +354,11 @@ export function useThreadStream({
|
||||
useEffect(() => {
|
||||
startedRef.current = false;
|
||||
sendInFlightRef.current = false;
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
||||
messagesRef.current
|
||||
.map(messageIdentity)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
prevHumanMsgCountRef.current =
|
||||
latestMessageCountsRef.current.humanMessageCount;
|
||||
}, [threadId]);
|
||||
@@ -617,6 +636,7 @@ export function useThreadHistory(threadId: string) {
|
||||
const loadingRef = useRef(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const initialLoadDoneRef = useRef(false);
|
||||
|
||||
loadingRef.current = loading;
|
||||
const loadMessages = useCallback(async () => {
|
||||
@@ -644,7 +664,10 @@ export function useThreadHistory(threadId: string) {
|
||||
const _messages = result.data
|
||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||
.map((m) => m.content);
|
||||
setMessages((prev) => [..._messages, ...prev]);
|
||||
setMessages((prev) => {
|
||||
const deduped = deduplicateHistoryMessages(prev, _messages);
|
||||
return [...deduped, ...prev];
|
||||
});
|
||||
indexRef.current -= 1;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
@@ -652,15 +675,39 @@ export function useThreadHistory(threadId: string) {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Reset state when threadId changes
|
||||
useEffect(() => {
|
||||
threadIdRef.current = threadId;
|
||||
runsRef.current = [];
|
||||
indexRef.current = -1;
|
||||
initialLoadDoneRef.current = false;
|
||||
setMessages([]);
|
||||
}, [threadId]);
|
||||
|
||||
// Load/update history when runs data changes
|
||||
useEffect(() => {
|
||||
if (runs.data && runs.data.length > 0) {
|
||||
runsRef.current = runs.data ?? [];
|
||||
indexRef.current = runs.data.length - 1;
|
||||
const prevLength = runsRef.current.length;
|
||||
runsRef.current = runs.data;
|
||||
|
||||
if (!initialLoadDoneRef.current) {
|
||||
// Initial load: start from the most recent run
|
||||
initialLoadDoneRef.current = true;
|
||||
indexRef.current = runs.data.length - 1;
|
||||
loadMessages().catch(() => {
|
||||
toast.error("Failed to load thread history.");
|
||||
});
|
||||
} else if (runs.data.length > prevLength) {
|
||||
// New runs added (e.g., after query invalidation): adjust indexRef
|
||||
// so the user can load older history by scrolling up
|
||||
indexRef.current = adjustHistoryIndex(
|
||||
indexRef.current,
|
||||
prevLength,
|
||||
runs.data.length,
|
||||
);
|
||||
}
|
||||
}
|
||||
loadMessages().catch(() => {
|
||||
toast.error("Failed to load thread history.");
|
||||
});
|
||||
}, [threadId, runs.data, loadMessages]);
|
||||
|
||||
const appendMessages = useCallback((_messages: Message[]) => {
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { expect, test } from "vitest";
|
||||
|
||||
import {
|
||||
adjustHistoryIndex,
|
||||
deduplicateHistoryMessages,
|
||||
} from "@/core/threads/history-utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// deduplicateHistoryMessages
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test("returns all incoming messages when existing history is empty", () => {
|
||||
const existing: Message[] = [];
|
||||
const incoming: Message[] = [
|
||||
{ type: "human", id: "m1", content: "hello" },
|
||||
{ type: "ai", id: "m2", content: "hi" },
|
||||
];
|
||||
|
||||
const result = deduplicateHistoryMessages(existing, incoming);
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result.map((m) => m.id)).toEqual(["m1", "m2"]);
|
||||
});
|
||||
|
||||
test("filters out messages whose id already exists in history", () => {
|
||||
const existing: Message[] = [
|
||||
{ type: "human", id: "m1", content: "hello" },
|
||||
{ type: "ai", id: "m2", content: "hi" },
|
||||
];
|
||||
const incoming: Message[] = [
|
||||
{ type: "human", id: "m1", content: "hello" }, // duplicate
|
||||
{ type: "ai", id: "m3", content: "new" },
|
||||
];
|
||||
|
||||
const result = deduplicateHistoryMessages(existing, incoming);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]!.id).toBe("m3");
|
||||
});
|
||||
|
||||
test("filters out tool messages by tool_call_id", () => {
|
||||
const existing: Message[] = [
|
||||
{
|
||||
type: "tool",
|
||||
id: "t1",
|
||||
tool_call_id: "tc-1",
|
||||
content: "tool result",
|
||||
name: "search",
|
||||
} as unknown as Message,
|
||||
];
|
||||
const incoming: Message[] = [
|
||||
{
|
||||
type: "tool",
|
||||
id: "t1-dup",
|
||||
tool_call_id: "tc-1",
|
||||
content: "tool result",
|
||||
name: "search",
|
||||
} as unknown as Message,
|
||||
{
|
||||
type: "tool",
|
||||
id: "t2",
|
||||
tool_call_id: "tc-2",
|
||||
content: "other result",
|
||||
name: "search",
|
||||
} as unknown as Message,
|
||||
];
|
||||
|
||||
const result = deduplicateHistoryMessages(existing, incoming);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]!.id).toBe("t2");
|
||||
});
|
||||
|
||||
test("keeps messages with no id or tool_call_id", () => {
|
||||
const existing: Message[] = [
|
||||
{ type: "human", id: "m1", content: "existing" },
|
||||
];
|
||||
const incoming: Message[] = [
|
||||
// Message without id — should be kept (not considered a duplicate)
|
||||
{ type: "ai", content: "no id" } as Message,
|
||||
];
|
||||
|
||||
const result = deduplicateHistoryMessages(existing, incoming);
|
||||
expect(result).toHaveLength(1);
|
||||
});
|
||||
|
||||
test("deduplicates against tool_call_id from existing messages", () => {
|
||||
// Existing message has tool_call_id stored in the id set
|
||||
const existing: Message[] = [
|
||||
{
|
||||
type: "tool",
|
||||
id: "t0",
|
||||
tool_call_id: "tc-x",
|
||||
content: "result",
|
||||
name: "tool",
|
||||
} as unknown as Message,
|
||||
];
|
||||
// Incoming AI message whose `id` equals the existing `tool_call_id`; ids and tool_call_ids share one lookup set, so it is filtered
|
||||
const incoming: Message[] = [{ type: "ai", id: "tc-x", content: "response" }];
|
||||
|
||||
const result = deduplicateHistoryMessages(existing, incoming);
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// adjustHistoryIndex
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test("returns unchanged index when no new runs were added", () => {
|
||||
expect(adjustHistoryIndex(2, 5, 5)).toBe(2);
|
||||
expect(adjustHistoryIndex(-1, 3, 3)).toBe(-1);
|
||||
expect(adjustHistoryIndex(0, 1, 0)).toBe(0); // shouldn't happen, but safe
|
||||
});
|
||||
|
||||
test("resets to last run when all previous runs were loaded", () => {
|
||||
// 3 runs existed, all loaded (index = -1), now 5 runs
|
||||
const result = adjustHistoryIndex(-1, 3, 5);
|
||||
expect(result).toBe(4); // last index of new runs list
|
||||
});
|
||||
|
||||
test("shifts index by number of added runs when some are unloaded", () => {
|
||||
// 3 runs, currently at index 1 (run at index 2 loaded), now 6 runs
|
||||
const result = adjustHistoryIndex(1, 3, 6);
|
||||
// 3 new runs added, shift: 1 + (6 - 3) = 4
|
||||
expect(result).toBe(4);
|
||||
});
|
||||
|
||||
test("handles single new run when all previous were loaded", () => {
|
||||
// 4 runs, all loaded (index = -1), now 5 runs
|
||||
const result = adjustHistoryIndex(-1, 4, 5);
|
||||
expect(result).toBe(4);
|
||||
});
|
||||
|
||||
test("handles transition from empty runs to populated", () => {
|
||||
// 0 runs → 3 runs, all loaded (index = -1)
|
||||
const result = adjustHistoryIndex(-1, 0, 3);
|
||||
expect(result).toBe(2);
|
||||
});
|
||||
Reference in New Issue
Block a user