Compare commits

..

1 Commits

Author SHA1 Message Date
Willem Jiang 4dc328e460 fix(auth): use getBackendBaseURL() in auth-related fetch calls
Auth pages (login, setup) and components (AuthProvider, account-settings,
  workspace layout) used hardcoded relative paths like /api/v1/auth/...
  instead of the configurable getBackendBaseURL() used by the rest of the
  codebase. This prevented them from reaching the backend when
  NEXT_PUBLIC_BACKEND_BASE_URL was set to a different origin.

  Closes #2859
2026-05-13 15:46:29 +08:00
30 changed files with 110 additions and 1914 deletions
+1 -1
View File
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
Complex tasks rarely fit in a single pass. DeerFlow decomposes them. Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step. The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands. This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
+1 -1
View File
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id 11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
@@ -9,7 +9,7 @@ from typing import Any, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -217,17 +217,6 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
return "thinking" return "thinking"
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
"""Return True if the AIMessage contains a tool_call with the given id."""
for tc in message.tool_calls or []:
if isinstance(tc, dict):
if tc.get("id") == tool_call_id:
return True
elif hasattr(tc, "id") and tc.id == tool_call_id:
return True
return False
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]: def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or [] tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = [] actions: list[dict[str, Any]] = []
@@ -272,51 +261,8 @@ class TokenUsageMiddleware(AgentMiddleware):
if not messages: if not messages:
return None return None
# Annotate subagent token usage onto the AIMessage that dispatched it.
# When a task tool completes, its usage is cached by tool_call_id. Detect
# the ToolMessage → search backward for the corresponding AIMessage → merge.
# Walk backward through consecutive ToolMessages before the new AIMessage
# so that multiple concurrent task tool calls all get their subagent tokens
# written back to the same dispatch message (merging into one update).
state_updates: dict[int, AIMessage] = {}
if len(messages) >= 2:
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
idx = len(messages) - 2
while idx >= 0:
tool_msg = messages[idx]
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
break
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
if subagent_usage:
# Search backward from the ToolMessage to find the AIMessage
# that dispatched it. A single model response can dispatch
# multiple task tool calls, so we can't assume a fixed offset.
dispatch_idx = idx - 1
while dispatch_idx >= 0:
candidate = messages[dispatch_idx]
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
# Accumulate into an existing update for the same
# AIMessage (multiple task calls in one response),
# or merge fresh from the original message.
existing_update = state_updates.get(dispatch_idx)
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
merged = {
**prev,
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
}
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
break
dispatch_idx -= 1
idx -= 1
last = messages[-1] last = messages[-1]
if not isinstance(last, AIMessage): if not isinstance(last, AIMessage):
if state_updates:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return None return None
usage = getattr(last, "usage_metadata", None) usage = getattr(last, "usage_metadata", None)
@@ -342,12 +288,11 @@ class TokenUsageMiddleware(AgentMiddleware):
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {}) additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution: if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs}) updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
state_updates[len(messages) - 1] = updated_msg return {"messages": [updated_msg]}
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
@override @override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
@@ -223,11 +223,10 @@ class RunRepository(RunStore):
"""Aggregate token usage via a single SQL GROUP BY query.""" """Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error")) _completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id _thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
stmt = ( stmt = (
select( select(
model_name.label("model"), func.coalesce(RunRow.model_name, "unknown").label("model"),
func.count().label("runs"), func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"), func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"), func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
@@ -237,7 +236,7 @@ class RunRepository(RunStore):
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"), func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
) )
.where(_thread, _completed) .where(_thread, _completed)
.group_by(model_name) .group_by(func.coalesce(RunRow.model_name, "unknown"))
) )
async with self._sf() as session: async with self._sf() as session:
@@ -26,28 +26,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
# write it back to the triggering AIMessage's usage_metadata.
_subagent_usage_cache: dict[str, dict[str, int]] = {}
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
if app_config is None:
try:
app_config = get_app_config()
except FileNotFoundError:
return False
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
if enabled and usage:
_subagent_usage_cache[tool_call_id] = usage
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
return _subagent_usage_cache.pop(tool_call_id, None)
def _is_subagent_terminal(result: Any) -> bool: def _is_subagent_terminal(result: Any) -> bool:
"""Return whether a background subagent result is safe to clean up.""" """Return whether a background subagent result is safe to clean up."""
@@ -114,17 +92,6 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
return None return None
def _summarize_usage(records: list[dict] | None) -> dict | None:
"""Summarize token usage records into a compact dict for SSE events."""
if not records:
return None
return {
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
}
def _report_subagent_usage(runtime: Any, result: Any) -> None: def _report_subagent_usage(runtime: Any, result: Any) -> None:
"""Report subagent token usage to the parent RunJournal, if available. """Report subagent token usage to the parent RunJournal, if available.
@@ -210,7 +177,6 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
""" """
runtime_app_config = _get_runtime_app_config(runtime) runtime_app_config = _get_runtime_app_config(runtime)
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names() available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration # Get subagent configuration
@@ -346,32 +312,27 @@ async def task_tool(
last_message_count = current_message_count last_message_count = current_message_count
# Check if task completed, failed, or timed out # Check if task completed, failed, or timed out
usage = _summarize_usage(getattr(result, "token_usage_records", None))
if result.status == SubagentStatus.COMPLETED: if result.status == SubagentStatus.COMPLETED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result) _report_subagent_usage(runtime, result)
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage}) writer({"type": "task_completed", "task_id": task_id, "result": result.result})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}" return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED: elif result.status == SubagentStatus.FAILED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result) _report_subagent_usage(runtime, result)
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage}) writer({"type": "task_failed", "task_id": task_id, "error": result.error})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}" return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED: elif result.status == SubagentStatus.CANCELLED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result) _report_subagent_usage(runtime, result)
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage}) writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return "Task cancelled by user." return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT: elif result.status == SubagentStatus.TIMED_OUT:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result) _report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage}) writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}" return f"Task timed out. Error: {result.error}"
@@ -390,9 +351,7 @@ async def task_tool(
timeout_minutes = config.timeout_seconds // 60 timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
_report_subagent_usage(runtime, result) _report_subagent_usage(runtime, result)
usage = _summarize_usage(getattr(result, "token_usage_records", None)) writer({"type": "task_timed_out", "task_id": task_id})
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError: except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively. # Signal the background subagent thread to stop cooperatively.
@@ -415,8 +374,4 @@ async def task_tool(
cleanup_background_task(task_id) cleanup_background_task(task_id)
else: else:
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
_subagent_usage_cache.pop(tool_call_id, None)
raise
except Exception:
_subagent_usage_cache.pop(tool_call_id, None)
raise raise
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.builtins.tool_search import reset_deferred_registry
from deerflow.tools.sync import make_sync_tool_wrapper from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -116,6 +116,8 @@ def get_available_tools(
# made through the Gateway API (which runs in a separate process) are immediately # made through the Gateway API (which runs in a separate process) are immediately
# reflected when loading MCP tools. # reflected when loading MCP tools.
mcp_tools = [] mcp_tools = []
# Reset deferred registry upfront to prevent stale state from previous calls
reset_deferred_registry()
if include_mcp: if include_mcp:
try: try:
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
@@ -133,51 +135,12 @@ def get_available_tools(
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
# Reuse the existing registry if one is already set for registry = DeferredToolRegistry()
# this async context. ``get_available_tools`` is for t in mcp_tools:
# re-entered whenever a subagent is spawned registry.register(t)
# (``task_tool`` calls it to build the child agent's set_deferred_registry(registry)
# toolset), and previously we used to unconditionally
# rebuild the registry — wiping out the parent agent's
# tool_search promotions. The
# ``DeferredToolFilterMiddleware`` then re-hid those
# tools from subsequent model calls, leaving the agent
# able to see a tool's name but unable to invoke it
# (issue #2884). ``contextvars`` already gives us the
# lifetime semantics we want: a fresh request / graph
# run starts in a new asyncio task with the
# ContextVar at its default of ``None``, so reuse is
# only triggered for re-entrant calls inside one run.
#
# Intentionally NOT reconciling against the current
# ``mcp_tools`` snapshot. The MCP cache only refreshes
# on ``extensions_config.json`` mtime changes, which
# in practice happens between graph runs — not inside
# one. And even if a refresh did happen mid-run, the
# already-built lead agent's ``ToolNode`` still holds
# the *previous* tool set (LangGraph binds tools at
# graph construction time), so a brand-new MCP tool
# couldn't actually be invoked anyway. The
# ``DeferredToolRegistry`` doesn't retain the names
# of previously-promoted tools (``promote()`` drops
# the entry entirely), so re-syncing the registry
# against a fresh ``mcp_tools`` list would
# mis-classify those promotions as new tools and
# re-register them as deferred — exactly the bug
# this fix exists to prevent.
existing_registry = get_deferred_registry()
if existing_registry is None:
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
else:
mcp_tool_names = {t.name for t in mcp_tools}
still_deferred = len(existing_registry)
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
builtin_tools.append(tool_search_tool) builtin_tools.append(tool_search_tool)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
except ImportError: except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.") logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e: except Exception as e:
-93
View File
@@ -4,8 +4,6 @@ Sets up sys.path and pre-mocks modules that would cause circular import
issues when unit-testing lightweight config/registry code in isolation. issues when unit-testing lightweight config/registry code in isolation.
""" """
from __future__ import annotations
import importlib.util import importlib.util
import sys import sys
from pathlib import Path from pathlib import Path
@@ -13,16 +11,11 @@ from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
# Make 'app' and 'deerflow' importable from any working directory # Make 'app' and 'deerflow' importable from any working directory
sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
# Break the circular import chain that exists in production code: # Break the circular import chain that exists in production code:
# deerflow.subagents.__init__ # deerflow.subagents.__init__
# -> .executor (SubagentExecutor, SubagentResult) # -> .executor (SubagentExecutor, SubagentResult)
@@ -63,92 +56,6 @@ def provisioner_module():
return module return module
@pytest.fixture()
def blocking_io_detector():
"""Fail a focused test if blocking calls run on the event loop thread."""
with detect_blocking_io(fail_on_exit=True) as detector:
yield detector
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("blocking-io")
group.addoption(
"--detect-blocking-io",
action="store_true",
default=False,
help="Collect blocking calls made while an asyncio event loop is running and report a summary.",
)
group.addoption(
"--detect-blocking-io-fail",
action="store_true",
default=False,
help="Set a failing exit status when --detect-blocking-io records violations.",
)
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe")
def pytest_sessionstart(session: pytest.Session) -> None:
if _blocking_io_probe_enabled(session.config):
_blocking_io_probe.clear()
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item):
yield
return
detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
detector.__enter__()
setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
yield
detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
if detector is None:
return
try:
detector.__exit__(None, None, None)
_blocking_io_probe.record(item.nodeid, detector.violations)
finally:
delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
def pytest_sessionfinish(session: pytest.Session) -> None:
if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK:
session.exitstatus = pytest.ExitCode.TESTS_FAILED
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
if not _blocking_io_probe_enabled(terminalreporter.config):
return
header, *details = _blocking_io_probe.format_summary().splitlines()
terminalreporter.write_sep("=", header)
for line in details:
terminalreporter.write_line(line)
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io-fail"))
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user # Auto-set user context for every test unless marked no_auto_user
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
-1
View File
@@ -1 +0,0 @@
"""Shared test support helpers."""
@@ -1 +0,0 @@
"""Runtime and static detectors used by tests."""
@@ -1,287 +0,0 @@
"""Test helper for detecting blocking calls on an asyncio event loop.
The detector is intentionally test-only. It monkeypatches a small set of
well-known blocking entry points and their already-loaded module-level aliases,
then records calls only when they happen on a thread that is currently running
an asyncio event loop. Aliases captured in closures or default arguments remain
out of scope.
"""
from __future__ import annotations
import asyncio
import importlib
import sys
import traceback
from collections import Counter
from collections.abc import Callable, Iterable, Iterator
from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from types import TracebackType
from typing import Any
BlockingCallable = Callable[..., Any]
@dataclass(frozen=True)
class BlockingCallSpec:
"""Describes one blocking callable to wrap during a detector run."""
name: str
target: str
record_on_iteration: bool = False
@dataclass(frozen=True)
class BlockingCall:
"""One blocking call observed on an asyncio event loop thread."""
name: str
target: str
stack: tuple[traceback.FrameSummary, ...]
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
BlockingCallSpec("time.sleep", "time:sleep"),
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
)
def _is_event_loop_thread() -> bool:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return False
return loop.is_running()
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
module_name, attr_path = target.split(":", maxsplit=1)
owner: object = importlib.import_module(module_name)
parts = attr_path.split(".")
for part in parts[:-1]:
owner = getattr(owner, part)
attr_name = parts[-1]
original = getattr(owner, attr_name)
return owner, attr_name, original
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
return tuple(frame for frame in stack if frame.filename != __file__)
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
"""Record blocking calls made from async runtime code.
By default the detector reports violations but does not fail on context
exit. Tests can set ``fail_on_exit=True`` or call
``assert_no_blocking_calls()`` explicitly.
"""
def __init__(
self,
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> None:
self._specs = tuple(specs)
self._fail_on_exit = fail_on_exit
self._patch_loaded_aliases_enabled = patch_loaded_aliases
self._stack_limit = stack_limit
self._patches: list[tuple[object, str, BlockingCallable]] = []
self._patch_keys: set[tuple[int, str]] = set()
self.violations: list[BlockingCall] = []
self._active = False
def __enter__(self) -> BlockingIODetector:
try:
self._active = True
alias_replacements: dict[int, BlockingCallable] = {}
for spec in self._specs:
owner, attr_name, original = _resolve_target(spec.target)
wrapper = self._wrap(spec, original)
self._patch_attribute(owner, attr_name, original, wrapper)
alias_replacements[id(original)] = wrapper
if self._patch_loaded_aliases_enabled:
self._patch_loaded_module_aliases(alias_replacements)
except Exception:
self._restore()
self._active = False
raise
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback_value: TracebackType | None,
) -> bool | None:
self._restore()
self._active = False
if exc_type is None and self._fail_on_exit:
self.assert_no_blocking_calls()
return None
def _restore(self) -> None:
for owner, attr_name, original in reversed(self._patches):
setattr(owner, attr_name, original)
self._patches.clear()
self._patch_keys.clear()
def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
key = (id(owner), attr_name)
if key in self._patch_keys:
return
setattr(owner, attr_name, replacement)
self._patches.append((owner, attr_name, original))
self._patch_keys.add(key)
def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
for module in tuple(sys.modules.values()):
namespace = getattr(module, "__dict__", None)
if not isinstance(namespace, dict):
continue
for attr_name, value in tuple(namespace.items()):
replacement = replacements_by_id.get(id(value))
if replacement is not None:
self._patch_attribute(module, attr_name, value, replacement)
def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
@wraps(original)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if spec.record_on_iteration:
result = original(*args, **kwargs)
return self._wrap_iteration(spec, result)
self._record_if_blocking(spec)
return original(*args, **kwargs)
return wrapper
def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
iterator = iter(iterable)
reported = False
while True:
if not reported:
reported = self._record_if_blocking(spec)
try:
yield next(iterator)
except StopIteration:
return
def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
if self._active and _is_event_loop_thread():
stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
self.violations.append(BlockingCall(spec.name, spec.target, stack))
return True
return False
def assert_no_blocking_calls(self) -> None:
if self.violations:
raise AssertionError(format_blocking_calls(self.violations))
class BlockingIOProbe:
"""Collect detector output across tests and format a compact summary."""
def __init__(self, project_root: Path) -> None:
self._project_root = project_root.resolve()
self._observed: list[tuple[str, BlockingCall]] = []
@property
def violation_count(self) -> int:
return len(self._observed)
@property
def test_count(self) -> int:
return len({nodeid for nodeid, _violation in self._observed})
def clear(self) -> None:
self._observed.clear()
def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
for violation in violations:
self._observed.append((nodeid, violation))
def format_summary(self, *, limit: int = 30) -> str:
if not self._observed:
return "blocking io probe: no violations"
call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
for _nodeid, violation in self._observed:
frame = self._local_call_site(violation.stack)
if frame is None:
call_sites[(violation.name, "<unknown>", 0, "<unknown>", "")] += 1
continue
call_sites[
(
violation.name,
self._relative(frame.filename),
frame.lineno,
frame.name,
(frame.line or "").strip(),
)
] += 1
lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
return "\n".join(lines)
def _relative(self, filename: str) -> str:
try:
return str(Path(filename).resolve().relative_to(self._project_root))
except ValueError:
return filename
def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")]
if local_frames:
return local_frames[-1]
test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename]
return test_frames[-1] if test_frames else None
def detect_blocking_io(
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> BlockingIODetector:
"""Create a detector context manager for a focused test scope."""
return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
"""Format detector output with enough stack context to locate call sites."""
lines = ["Blocking calls were executed on an asyncio event loop thread:"]
for index, violation in enumerate(violations, start=1):
lines.append(f"{index}. {violation.name} ({violation.target})")
lines.extend(_format_stack(violation.stack))
return "\n".join(lines)
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
for frame in stack:
location = f"{frame.filename}:{frame.lineno}"
lines = [f" at {frame.name} ({location})"]
if frame.line:
lines.append(f" {frame.line.strip()}")
yield from lines
-190
View File
@@ -1,190 +0,0 @@
from __future__ import annotations
import asyncio
import os
import time
from os import walk as imported_walk
from pathlib import Path
from time import sleep as imported_sleep
import httpx
import pytest
import requests
from support.detectors.blocking_io import (
BlockingCallSpec,
BlockingIOProbe,
detect_blocking_io,
)
pytestmark = pytest.mark.asyncio
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
async def test_records_time_sleep_on_event_loop() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
original_alias = imported_sleep
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
imported_sleep(0)
assert imported_sleep is original_alias
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_can_disable_loaded_alias_patching() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
imported_sleep(0)
assert detector.violations == []
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
await asyncio.to_thread(time.sleep, 0)
assert detector.violations == []
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
await asyncio.to_thread(time.sleep, 0)
assert blocking_io_detector.violations == []
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
def call_sleep() -> list[str]:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
return [violation.name for violation in detector.violations]
assert await asyncio.to_thread(call_sleep) == []
async def test_fail_on_exit_includes_call_site() -> None:
with pytest.raises(AssertionError) as exc_info:
with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
time.sleep(0)
message = str(exc_info.value)
assert "time.sleep" in message
assert "test_fail_on_exit_includes_call_site" in message
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
return f"{method}:{url}"
monkeypatch.setattr(requests.sessions.Session, "request", fake_request)
with detect_blocking_io(REQUESTS_ONLY) as detector:
assert requests.get("https://example.invalid") == "get:https://example.invalid"
assert [violation.name for violation in detector.violations] == ["requests.Session.request"]
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
return httpx.Response(200, request=httpx.Request(method, url))
monkeypatch.setattr(httpx.Client, "request", fake_request)
with detect_blocking_io(HTTPX_ONLY) as detector:
with httpx.Client() as client:
response = client.get("https://example.invalid")
assert response.status_code == 200
assert [violation.name for violation in detector.violations] == ["httpx.Client.request"]
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(os.walk(tmp_path))
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
original_alias = imported_walk
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(imported_walk(tmp_path))
assert imported_walk is original_alias
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert list(walker)
assert detector.violations == []
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert await asyncio.to_thread(lambda: list(walker))
assert detector.violations == []
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
assert path.read_text(encoding="utf-8") == "content"
assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"]
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
summary = probe.format_summary()
assert "blocking io probe: 1 violations across 1 tests" in summary
assert "pathlib.Path.read_text" in summary
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
assert probe.format_summary() == "blocking io probe: no violations"
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
assert probe.violation_count == 1
probe.clear()
assert probe.violation_count == 0
assert probe.format_summary() == "blocking io probe: no violations"
@@ -1,22 +0,0 @@
from __future__ import annotations
import time
import pytest
ORIGINAL_SLEEP = time.sleep
def replacement_sleep(seconds: float) -> None:
return None
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(time, "sleep", replacement_sleep)
assert time.sleep is replacement_sleep
@pytest.mark.no_blocking_io_probe
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
assert time.sleep is ORIGINAL_SLEEP
assert getattr(time.sleep, "__wrapped__", None) is None
@@ -1,222 +0,0 @@
"""Real-LLM end-to-end verification for issue #2884.
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
and the production ``get_available_tools`` pipeline. The only thing we mock is
the MCP tool source — we hand-roll two ``@tool``s and inject them through
``deerflow.mcp.cache.get_cached_mcp_tools``.
The flow exercised:
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
that re-enters ``get_available_tools`` on the same task — this is the
code path issue #2884 reports). It must call ``tool_search`` to
discover the deferred ``fake_calculator`` tool.
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
``fake_subagent_trigger`` re-enters ``get_available_tools``.
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
so it can actually call it. Without this PR's fix, the re-entry wipes
the promotion and the model can no longer invoke the tool.
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
test run. Run with::
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
PYTHONPATH=. uv run pytest \
tests/test_deferred_tool_promotion_real_llm.py -v -s
"""
from __future__ import annotations
import os
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool as as_tool
# ---------------------------------------------------------------------------
# Skip control: only run when explicitly opted in.
# ---------------------------------------------------------------------------
pytestmark = pytest.mark.skipif(
os.getenv("ONEAPI_E2E") != "1",
reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
# ---------------------------------------------------------------------------
# Fake "MCP" tools the agent should discover via tool_search.
# Keep them obviously synthetic so the model can pattern-match the search.
# ---------------------------------------------------------------------------
_calls: list[str] = []
@as_tool
def fake_calculator(expression: str) -> str:
"""Evaluate a tiny arithmetic expression like '2 + 2'.
Reserved for the user — only call this if the user asks for arithmetic.
"""
_calls.append(f"fake_calculator:{expression}")
try:
# Trivially safe-eval just for the e2e check
allowed = set("0123456789+-*/() .")
if not set(expression) <= allowed:
return "expression contains disallowed characters"
return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307
except Exception as e:
return f"error: {e}"
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
"""Translate text into the given language code. Decorative — not used."""
_calls.append(f"fake_translator:{text}:{target_lang}")
return f"[{target_lang}] {text}"
# ---------------------------------------------------------------------------
# Pipeline wiring (same shape as the in-process tests).
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Build a minimal mock AppConfig and patch the symbol — never call the
real loader, which would trigger ``_apply_singleton_configs`` and
permanently mutate cross-test singletons (memory, title, …)."""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Real-LLM e2e test
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
"""End-to-end against a real OpenAI-compatible LLM.
The model must:
Turn 1 — see ``tool_search`` (deferred tools aren't bound yet) and
batch-call BOTH ``tool_search(select:fake_calculator)`` AND
``fake_subagent_trigger(...)``.
Turn 2 — call ``fake_calculator`` and finish.
Pass criterion: ``fake_calculator`` actually gets invoked at the tool
layer — recorded in ``_calls`` — which proves the model received the
promoted schema after the re-entrant ``get_available_tools`` call.
"""
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
_force_tool_search_enabled(monkeypatch)
_calls.clear()
@as_tool
async def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset.
Use this whenever the user asks you to delegate work — pass a short
description as ``prompt``.
"""
# ``task_tool`` does this internally. Whether the registry-reset that
# used to happen here actually leaks back to the parent task depends
# on asyncio's implicit context-copying semantics (gather creates
# child tasks with copied contexts, so reset_deferred_registry is
# task-local) — but the fix in this PR is what GUARANTEES the
# promotion sticks regardless of which integration path triggers a
# re-entrant ``get_available_tools`` call.
get_available_tools(subagent_enabled=False)
_calls.append(f"fake_subagent_trigger:{prompt}")
return "subagent completed"
tools = get_available_tools() + [fake_subagent_trigger]
model = ChatOpenAI(
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
temperature=0,
max_retries=1,
)
system_prompt = (
"You are a meticulous assistant. Available deferred tools include a "
"calculator and a translator — their schemas are hidden until you "
"search for them via tool_search.\n\n"
"Procedure for the user's request:\n"
" 1. Call tool_search with query 'select:fake_calculator' AND "
"in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
"to delegate the side work. Put both tool_calls in your first response.\n"
" 2. After both tool messages come back, call fake_calculator with "
"the user's expression.\n"
" 3. Reply with just the numeric result."
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt=system_prompt,
)
result = await graph.ainvoke(
{"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]},
config={"recursion_limit": 12},
)
print("\n=== tool calls recorded ===")
for c in _calls:
print(f" {c}")
print("\n=== final message ===")
final_text = result["messages"][-1].content if result["messages"] else "(none)"
print(f" {final_text!r}")
# The smoking-gun assertion: fake_calculator was actually invoked at the
# tool layer. This is only possible if the promoted schema reached the
# model in turn 2, despite the subagent-style re-entry in turn 1.
calc_calls = [c for c in _calls if c.startswith("fake_calculator:")]
assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"
# And the math should actually be done correctly (sanity that the LLM
# really used the result, not just hallucinated the answer).
assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
@@ -1,390 +0,0 @@
"""Reproduce + regression-guard issue #2884.
Hypothesis from the issue:
``tools.tools.get_available_tools`` unconditionally calls
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
every time it is invoked. If anything calls ``get_available_tools`` again
during the same async context (after the agent has promoted tools via
``tool_search``), the promotion is wiped and the next model call hides the
tool's schema again.
These tests pin two things:
A. **At the unit boundary** — verify the failure mode directly. Promote a
tool in the registry, then call ``get_available_tools`` again and observe
that the ContextVar registry is reset and the promotion is lost.
B. **At the graph-execution boundary** — drive a real ``create_agent`` graph
with the real ``DeferredToolFilterMiddleware`` through two model turns.
The first turn calls ``tool_search`` which promotes a tool. The second
turn must see that tool's schema in ``request.tools``. If
``get_available_tools`` were to run again between the two turns and reset
the registry, the second turn's filter would strip the tool.
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
unmodified; mock only the LLM and the MCP tool source. Patch
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
``get_available_tools`` resolves via lazy import) to return our fixture
tools so we don't need a real MCP server.
"""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
# ---------------------------------------------------------------------------
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
# ---------------------------------------------------------------------------
@as_tool
def fake_mcp_search(query: str) -> str:
"""Pretend to search a knowledge base for the given query."""
return f"results for {query}"
@as_tool
def fake_mcp_fetch(url: str) -> str:
"""Pretend to fetch a page at the given URL."""
return f"content of {url}"
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
"""Each test must start with a clean ContextVar.
The registry lives in a module-level ContextVar with no per-task isolation
in a synchronous test runner, so one test's promotion can leak into the
next and silently break filter assertions.
"""
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
"""Make get_available_tools believe an MCP server is registered.
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
that both ``AppConfig.from_file`` (which calls
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
see a valid instance. Then point the MCP tool cache at our fixture tools.
"""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Force config.tool_search.enabled=True without touching the yaml.
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
which permanently mutates module-level singletons (``_memory_config``,
``_title_config``, …) to match the developer's ``config.yaml`` — even
after pytest restores our patch. That leaks across tests later in the
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
require ``_memory_config.enabled = True``, which is the dataclass default
but FALSE in the actual yaml).
Build a minimal mock AppConfig instead and never call the real loader.
"""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Section A — direct unit-level reproduction
# ---------------------------------------------------------------------------
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
Step 1: call get_available_tools() — registers MCP tools as deferred.
Step 2: simulate the agent calling tool_search by promoting one tool.
Step 3: call get_available_tools() again (the same code path
``task_tool`` exercises mid-run).
Assertion: after step 3, the promoted tool is STILL promoted (not
re-deferred). On ``main`` before the fix, step 3's
``reset_deferred_registry()`` wiped the promotion and re-registered
every MCP tool as deferred — this assertion fired with
``REGRESSION (#2884)``.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# Step 1: first call — both MCP tools start deferred
get_available_tools()
reg1 = get_deferred_registry()
assert reg1 is not None
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
# Step 2: simulate tool_search promoting one of them
reg1.promote({"fake_mcp_search"})
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
# Step 3: second call — registry must NOT silently undo the promotion
get_available_tools()
reg2 = get_deferred_registry()
assert reg2 is not None
deferred_after = {e.name for e in reg2.entries}
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
# ---------------------------------------------------------------------------
# Section B — graph-execution reproduction
# ---------------------------------------------------------------------------
class _ToolSearchPromotingModel(FakeToolCallingModel):
"""Two-turn model that:
Turn 1 → emit a tool_call for ``tool_search`` (the real one)
Turn 2 → emit a tool_call for ``fake_mcp_search`` (the promoted tool)
Records the tools it received on each turn so the test can inspect what
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
"""
bound_tools_per_turn: list[list[str]] = []
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
# Record the tool names the model would see in this turn
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
self.bound_tools_per_turn.append(names)
return self
def _build_promoting_model() -> _ToolSearchPromotingModel:
return _ToolSearchPromotingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
}
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
"""End-to-end: drive a real create_agent graph through two turns.
Without the fix, the second-turn bind_tools call should NOT contain
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
registry and strips it). With the fix, the model sees the schema and can
invoke it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
tools = get_available_tools()
# Sanity: the assembled tool list includes the deferred tools (they're in
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
# they reach the model)
tool_names = {getattr(t, "name", "") for t in tools}
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
model = _build_promoting_model()
model.bound_tools_per_turn = [] # reset class-level recorder
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
turn1 = set(model.bound_tools_per_turn[0])
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
# This is the load-bearing assertion for issue #2884.
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
turn2 = set(model.bound_tools_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
# ---------------------------------------------------------------------------
# Section C — the actual issue #2884 trigger: a re-entrant
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
# wipe the parent's promotion.
# ---------------------------------------------------------------------------
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
(the same pattern that happens when ``task_tool`` builds a subagent's
toolset mid-run) must not wipe the parent agent's tool_search promotions.
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
``get_available_tools`` again — exactly what ``task_tool`` does when it
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
promoted tool. Without the fix, the re-entry wipes the registry and
the filter re-hides it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# The trigger tool simulates what task_tool does internally: rebuild the
# toolset by calling get_available_tools while the registry is live.
@as_tool
def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
get_available_tools(subagent_enabled=False)
return f"spawned subagent for: {prompt}"
tools = get_available_tools() + [fake_subagent_trigger]
bound_per_turn: list[list[str]] = []
class _Model(FakeToolCallingModel):
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
return self
model = _Model(
responses=[
# Turn 1: do both in one batch — promote AND trigger the
# subagent-style rebuild. LangGraph executes them in order in the
# same agent step.
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
},
{
"name": "fake_subagent_trigger",
"args": {"prompt": "go"},
"id": "call_trigger_1",
"type": "tool_call",
},
],
),
# Turn 2: try to invoke the promoted tool. The model gets this
# turn only if turn 1's bind_tools recorded what the filter sent.
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-subagent-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1 sanity: deferred tool not visible yet
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
# re-entrant get_available_tools call that happened in turn 1's tool batch.
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
turn2 = set(bound_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
@@ -3,7 +3,6 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id(): def test_conversation_context_has_user_id():
@@ -18,7 +17,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id(): def test_queue_add_stores_user_id():
q = MemoryUpdateQueue() q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): with patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice") q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1 assert len(q._queue) == 1
assert q._queue[0].user_id == "alice" assert q._queue[0].user_id == "alice"
@@ -27,7 +26,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater(): def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue() q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): with patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice") q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock() mock_updater = MagicMock()
-48
View File
@@ -3,10 +3,7 @@
Uses a temp SQLite DB to test ORM-backed CRUD operations. Uses a temp SQLite DB to test ORM-backed CRUD operations.
""" """
import re
import pytest import pytest
from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository from deerflow.persistence.run import RunRepository
@@ -281,48 +278,3 @@ class TestRunRepository:
assert row4["model_name"] is None assert row4["model_name"] is None
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
captured = []
class FakeResult:
def all(self):
return []
class FakeSession:
async def execute(self, stmt):
captured.append(stmt)
return FakeResult()
class FakeSessionContext:
async def __aenter__(self):
return FakeSession()
async def __aexit__(self, exc_type, exc, tb):
return None
repo = RunRepository(lambda: FakeSessionContext())
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg == {
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_runs": 0,
"by_model": {},
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
}
assert len(captured) == 1
stmt = captured[0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
assert select_match is not None
assert group_by_match is not None
assert select_match.group(1) == group_by_match.group(1)
-153
View File
@@ -59,15 +59,12 @@ def _make_result(
ai_messages: list[dict] | None = None, ai_messages: list[dict] | None = None,
result: str | None = None, result: str | None = None,
error: str | None = None, error: str | None = None,
token_usage_records: list[dict] | None = None,
) -> SimpleNamespace: ) -> SimpleNamespace:
return SimpleNamespace( return SimpleNamespace(
status=status, status=status,
ai_messages=ai_messages or [], ai_messages=ai_messages or [],
result=result, result=result,
error=error, error=error,
token_usage_records=token_usage_records or [],
usage_reported=False,
) )
@@ -1135,153 +1132,3 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
assert len(report_calls) == 1 assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"] assert cleanup_calls == ["tc-cancel-report"]
@pytest.mark.parametrize(
"status, expected_type",
[
(FakeSubagentStatus.COMPLETED, "task_completed"),
(FakeSubagentStatus.FAILED, "task_failed"),
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
"""Terminal task events include a usage summary from token_usage_records."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
records = [
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
]
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-usage",
)
terminal_events = [e for e in events if e["type"] == expected_type]
assert len(terminal_events) == 1
assert terminal_events[0]["usage"] == {
"input_tokens": 300,
"output_tokens": 130,
"total_tokens": 430,
}
def test_terminal_event_usage_none_when_no_records(monkeypatch):
"""Terminal event has usage=None when token_usage_records is empty."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-no-records",
)
completed = [e for e in events if e["type"] == "task_completed"]
assert len(completed) == 1
assert completed[0]["usage"] is None
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=FileNotFoundError("missing config")),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
runtime = _make_runtime(app_config=app_config)
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
task_tool_module._subagent_usage_cache.clear()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-disabled-cache",
)
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
runtime = _make_runtime(app_config=app_config)
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
with pytest.raises(RuntimeError, match="poll failed"):
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-error",
)
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
+1 -48
View File
@@ -1,10 +1,9 @@
"""Tests for TokenUsageMiddleware attribution annotations.""" """Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging import logging
from unittest.mock import MagicMock from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.token_usage_middleware import ( from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY, TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -233,49 +232,3 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove", "tool_call_id": "write_todos:remove",
} }
] ]
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
middleware = TokenUsageMiddleware()
first_dispatch = AIMessage(
content="",
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
)
second_dispatch = AIMessage(
content="",
tool_calls=[
{"id": "task:second-a", "name": "task", "args": {}},
{"id": "task:second-b", "name": "task", "args": {}},
],
)
messages = [
first_dispatch,
ToolMessage(content="first", tool_call_id="task:first"),
second_dispatch,
ToolMessage(content="second-a", tool_call_id="task:second-a"),
ToolMessage(content="second-b", tool_call_id="task:second-b"),
AIMessage(content="done"),
]
cached_usage = {
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
}
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
monkeypatch.setattr(
task_tool_module,
"pop_cached_subagent_usage",
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
)
result = middleware.after_model({"messages": messages}, _make_runtime())
assert result is not None
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
assert len(usage_updates) == 1
updated = usage_updates[0]
assert updated.tool_calls == second_dispatch.tool_calls
assert updated.usage_metadata == {
"input_tokens": 30,
"output_tokens": 12,
"total_tokens": 42,
}
+8 -4
View File
@@ -65,7 +65,8 @@ def _make_minimal_config(tools):
@patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): @patch("deerflow.tools.tools.reset_deferred_registry")
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
"""Config-loaded async-only tools can still be invoked by sync clients.""" """Config-loaded async-only tools can still be invoked by sync clients."""
async def async_tool_impl(x: int) -> str: async def async_tool_impl(x: int) -> str:
@@ -97,7 +98,8 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_no_duplicates_returned(mock_bash, mock_cfg): @patch("deerflow.tools.tools.reset_deferred_registry")
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
"""get_available_tools() never returns two tools with the same name.""" """get_available_tools() never returns two tools with the same name."""
mock_cfg.return_value = _make_minimal_config([]) mock_cfg.return_value = _make_minimal_config([])
@@ -111,7 +113,8 @@ def test_no_duplicates_returned(mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_first_occurrence_wins(mock_bash, mock_cfg): @patch("deerflow.tools.tools.reset_deferred_registry")
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
"""When duplicates exist, the first occurrence is kept.""" """When duplicates exist, the first occurrence is kept."""
mock_cfg.return_value = _make_minimal_config([]) mock_cfg.return_value = _make_minimal_config([])
@@ -129,7 +132,8 @@ def test_first_occurrence_wins(mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog): @patch("deerflow.tools.tools.reset_deferred_registry")
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
"""A warning is logged for every skipped duplicate.""" """A warning is logged for every skipped duplicate."""
import logging import logging
+3 -3
View File
@@ -2005,7 +2005,7 @@ wheels = [
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.8.0" version = "0.7.36"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "httpx" }, { name = "httpx" },
@@ -2018,9 +2018,9 @@ dependencies = [
{ name = "xxhash" }, { name = "xxhash" },
{ name = "zstandard" }, { name = "zstandard" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" } sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" }, { url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" },
] ]
[package.optional-dependencies] [package.optional-dependencies]
+4 -3
View File
@@ -10,6 +10,7 @@ import { FlickeringGrid } from "@/components/ui/flickering-grid";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { useAuth } from "@/core/auth/AuthProvider"; import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types"; import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
/** /**
* Validate next parameter * Validate next parameter
@@ -71,7 +72,7 @@ export default function LoginPage() {
useEffect(() => { useEffect(() => {
let cancelled = false; let cancelled = false;
void fetch("/api/v1/auth/setup-status") void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
.then((r) => r.json()) .then((r) => r.json())
.then((data: { needs_setup?: boolean }) => { .then((data: { needs_setup?: boolean }) => {
if (!cancelled && data.needs_setup) { if (!cancelled && data.needs_setup) {
@@ -94,8 +95,8 @@ export default function LoginPage() {
try { try {
const endpoint = isLogin const endpoint = isLogin
? "/api/v1/auth/login/local" ? `${getBackendBaseURL()}/api/v1/auth/login/local`
: "/api/v1/auth/register"; : `${getBackendBaseURL()}/api/v1/auth/register`;
const body = isLogin const body = isLogin
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}` ? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
: JSON.stringify({ email, password }); : JSON.stringify({ email, password });
+18 -14
View File
@@ -10,6 +10,7 @@ import { Input } from "@/components/ui/input";
import { getCsrfHeaders } from "@/core/api/fetcher"; import { getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider"; import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types"; import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
type SetupMode = "loading" | "init_admin" | "change_password"; type SetupMode = "loading" | "init_admin" | "change_password";
@@ -36,7 +37,7 @@ export default function SetupPage() {
setMode("change_password"); setMode("change_password");
} else if (!isAuthenticated) { } else if (!isAuthenticated) {
// Check if the system has no users yet // Check if the system has no users yet
void fetch("/api/v1/auth/setup-status") void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
.then((r) => r.json()) .then((r) => r.json())
.then((data: { needs_setup?: boolean }) => { .then((data: { needs_setup?: boolean }) => {
if (cancelled) return; if (cancelled) return;
@@ -72,7 +73,7 @@ export default function SetupPage() {
setLoading(true); setLoading(true);
try { try {
const res = await fetch("/api/v1/auth/initialize", { const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/initialize`, {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
credentials: "include", credentials: "include",
@@ -113,19 +114,22 @@ export default function SetupPage() {
setLoading(true); setLoading(true);
try { try {
const res = await fetch("/api/v1/auth/change-password", { const res = await fetch(
method: "POST", `${getBackendBaseURL()}/api/v1/auth/change-password`,
headers: { {
"Content-Type": "application/json", method: "POST",
...getCsrfHeaders(), headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
credentials: "include",
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
}, },
credentials: "include", );
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
});
if (!res.ok) { if (!res.ok) {
const data = await res.json(); const data = await res.json();
+2 -1
View File
@@ -4,6 +4,7 @@ import { redirect } from "next/navigation";
import { AuthProvider } from "@/core/auth/AuthProvider"; import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server"; import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types"; import { assertNever } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
import { WorkspaceContent } from "./workspace-content"; import { WorkspaceContent } from "./workspace-content";
@@ -44,7 +45,7 @@ export default async function WorkspaceLayout({
Retry Retry
</Link> </Link>
<Link <Link
href="/api/v1/auth/logout" href={`${getBackendBaseURL()}/api/v1/auth/logout`}
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm" className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
> >
Logout &amp; Reset Logout &amp; Reset
@@ -12,11 +12,13 @@ function TokenUsageSummary({
inputTokens, inputTokens,
outputTokens, outputTokens,
totalTokens, totalTokens,
unavailable = false,
}: { }: {
className?: string; className?: string;
inputTokens?: number; inputTokens?: number;
outputTokens?: number; outputTokens?: number;
totalTokens?: number; totalTokens?: number;
unavailable?: boolean;
}) { }) {
const { t } = useI18n(); const { t } = useI18n();
@@ -31,15 +33,21 @@ function TokenUsageSummary({
<CoinsIcon className="size-3" /> <CoinsIcon className="size-3" />
{t.tokenUsage.label} {t.tokenUsage.label}
</span> </span>
<span> {!unavailable ? (
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} <>
</span> <span>
<span> {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} </span>
</span> <span>
<span className="font-medium"> {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} </span>
</span> <span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</>
) : (
<span>{t.tokenUsage.unavailableShort}</span>
)}
</div> </div>
); );
} }
@@ -47,7 +55,7 @@ function TokenUsageSummary({
export function MessageTokenUsageList({ export function MessageTokenUsageList({
className, className,
enabled = false, enabled = false,
isLoading: _isLoading = false, isLoading = false,
messages, messages,
}: { }: {
className?: string; className?: string;
@@ -55,7 +63,7 @@ export function MessageTokenUsageList({
isLoading?: boolean; isLoading?: boolean;
messages: Message[]; messages: Message[];
}) { }) {
if (!enabled) { if (!enabled || isLoading) {
return null; return null;
} }
@@ -67,16 +75,13 @@ export function MessageTokenUsageList({
const usage = accumulateUsage(aiMessages); const usage = accumulateUsage(aiMessages);
if (!usage) {
return null;
}
return ( return (
<TokenUsageSummary <TokenUsageSummary
className={className} className={className}
inputTokens={usage.inputTokens} inputTokens={usage?.inputTokens}
outputTokens={usage.outputTokens} outputTokens={usage?.outputTokens}
totalTokens={usage.totalTokens} totalTokens={usage?.totalTokens}
unavailable={!usage}
/> />
); );
} }
@@ -8,6 +8,7 @@ import { Input } from "@/components/ui/input";
import { fetch, getCsrfHeaders } from "@/core/api/fetcher"; import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider"; import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types"; import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { SettingsSection } from "./settings-section"; import { SettingsSection } from "./settings-section";
@@ -38,17 +39,20 @@ export function AccountSettingsPage() {
setLoading(true); setLoading(true);
try { try {
const res = await fetch("/api/v1/auth/change-password", { const res = await fetch(
method: "POST", `${getBackendBaseURL()}/api/v1/auth/change-password`,
headers: { {
"Content-Type": "application/json", method: "POST",
...getCsrfHeaders(), headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
}),
}, },
body: JSON.stringify({ );
current_password: currentPassword,
new_password: newPassword,
}),
});
if (!res.ok) { if (!res.ok) {
const data = await res.json(); const data = await res.json();
+4 -2
View File
@@ -10,6 +10,8 @@ import React, {
type ReactNode, type ReactNode,
} from "react"; } from "react";
import { getBackendBaseURL } from "@/core/config";
import { type User, buildLoginUrl } from "./types"; import { type User, buildLoginUrl } from "./types";
// Re-export for consumers // Re-export for consumers
@@ -56,7 +58,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
const refreshUser = useCallback(async () => { const refreshUser = useCallback(async () => {
try { try {
setIsLoading(true); setIsLoading(true);
const res = await fetch("/api/v1/auth/me", { const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/me`, {
credentials: "include", credentials: "include",
}); });
@@ -88,7 +90,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
setUser(null); setUser(null);
try { try {
await fetch("/api/v1/auth/logout", { await fetch(`${getBackendBaseURL()}/api/v1/auth/logout`, {
method: "POST", method: "POST",
credentials: "include", credentials: "include",
}); });
+2 -2
View File
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
return hasUsage ? cumulative : null; return hasUsage ? cumulative : null;
} }
export function hasNonZeroUsage( function hasNonZeroUsage(
usage: TokenUsage | null | undefined, usage: TokenUsage | null | undefined,
): usage is TokenUsage { ): usage is TokenUsage {
return ( return (
@@ -75,7 +75,7 @@ export function hasNonZeroUsage(
); );
} }
export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
return { return {
inputTokens: base.inputTokens + delta.inputTokens, inputTokens: base.inputTokens + delta.inputTokens,
outputTokens: base.outputTokens + delta.outputTokens, outputTokens: base.outputTokens + delta.outputTokens,
@@ -1,49 +0,0 @@
import type { Message } from "@langchain/langgraph-sdk";
/**
* Deduplicate incoming messages against an existing history.
* A message is considered a duplicate if its `id` or `tool_call_id`
* (for tool messages) already appears in the existing list.
*/
export function deduplicateHistoryMessages(
existing: Message[],
incoming: Message[],
): Message[] {
const existingIds = new Set(
existing
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
.filter(Boolean),
);
return incoming.filter((m) => {
if (m.id && existingIds.has(m.id)) return false;
if (
"tool_call_id" in m &&
m.tool_call_id &&
existingIds.has(m.tool_call_id)
) {
return false;
}
return true;
});
}
/**
* Compute the new history-loading index when the runs list grows.
*
* - `currentIndex < 0` means all previously-known runs have been loaded;
* reset to the last run so the user can scroll up to load new runs.
* - `currentIndex >= 0` means some runs haven't been loaded yet;
* shift the index by the number of newly-added runs.
* - If no new runs were added, return `currentIndex` unchanged.
*/
export function adjustHistoryIndex(
currentIndex: number,
prevRunsLength: number,
newRunsLength: number,
): number {
const added = newRunsLength - prevRunsLength;
if (added <= 0) return currentIndex;
if (currentIndex < 0) return newRunsLength - 1;
return currentIndex + added;
}
+9 -56
View File
@@ -18,10 +18,6 @@ import type { UploadedFileInfo } from "../uploads";
import { promptInputFilePartToFile, uploadFiles } from "../uploads"; import { promptInputFilePartToFile, uploadFiles } from "../uploads";
import { fetchThreadTokenUsage } from "./api"; import { fetchThreadTokenUsage } from "./api";
import {
adjustHistoryIndex,
deduplicateHistoryMessages,
} from "./history-utils";
import { threadTokenUsageQueryKey } from "./token-usage"; import { threadTokenUsageQueryKey } from "./token-usage";
import type { import type {
AgentThread, AgentThread,
@@ -300,11 +296,7 @@ export function useThreadStream({
onError(error) { onError(error) {
setOptimisticMessages([]); setOptimisticMessages([]);
toast.error(getStreamErrorMessage(error)); toast.error(getStreamErrorMessage(error));
pendingUsageBaselineMessageIdsRef.current = new Set( pendingUsageBaselineMessageIdsRef.current = new Set();
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
if (threadIdRef.current && !isMock) { if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({ void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current), queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@@ -313,16 +305,9 @@ export function useThreadStream({
}, },
onFinish(state) { onFinish(state) {
listeners.current.onFinish?.(state.values); listeners.current.onFinish?.(state.values);
pendingUsageBaselineMessageIdsRef.current = new Set( pendingUsageBaselineMessageIdsRef.current = new Set();
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
if (threadIdRef.current && !isMock) { if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: ["thread", threadIdRef.current],
});
void queryClient.invalidateQueries({ void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current), queryKey: threadTokenUsageQueryKey(threadIdRef.current),
}); });
@@ -354,11 +339,7 @@ export function useThreadStream({
useEffect(() => { useEffect(() => {
startedRef.current = false; startedRef.current = false;
sendInFlightRef.current = false; sendInFlightRef.current = false;
pendingUsageBaselineMessageIdsRef.current = new Set( pendingUsageBaselineMessageIdsRef.current = new Set();
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
prevHumanMsgCountRef.current = prevHumanMsgCountRef.current =
latestMessageCountsRef.current.humanMessageCount; latestMessageCountsRef.current.humanMessageCount;
}, [threadId]); }, [threadId]);
@@ -636,7 +617,6 @@ export function useThreadHistory(threadId: string) {
const loadingRef = useRef(false); const loadingRef = useRef(false);
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [messages, setMessages] = useState<Message[]>([]); const [messages, setMessages] = useState<Message[]>([]);
const initialLoadDoneRef = useRef(false);
loadingRef.current = loading; loadingRef.current = loading;
const loadMessages = useCallback(async () => { const loadMessages = useCallback(async () => {
@@ -664,10 +644,7 @@ export function useThreadHistory(threadId: string) {
const _messages = result.data const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:")) .filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content); .map((m) => m.content);
setMessages((prev) => { setMessages((prev) => [..._messages, ...prev]);
const deduped = deduplicateHistoryMessages(prev, _messages);
return [...deduped, ...prev];
});
indexRef.current -= 1; indexRef.current -= 1;
} catch (err) { } catch (err) {
console.error(err); console.error(err);
@@ -675,39 +652,15 @@ export function useThreadHistory(threadId: string) {
setLoading(false); setLoading(false);
} }
}, []); }, []);
// Reset state when threadId changes
useEffect(() => { useEffect(() => {
threadIdRef.current = threadId; threadIdRef.current = threadId;
runsRef.current = [];
indexRef.current = -1;
initialLoadDoneRef.current = false;
setMessages([]);
}, [threadId]);
// Load/update history when runs data changes
useEffect(() => {
if (runs.data && runs.data.length > 0) { if (runs.data && runs.data.length > 0) {
const prevLength = runsRef.current.length; runsRef.current = runs.data ?? [];
runsRef.current = runs.data; indexRef.current = runs.data.length - 1;
if (!initialLoadDoneRef.current) {
// Initial load: start from the most recent run
initialLoadDoneRef.current = true;
indexRef.current = runs.data.length - 1;
loadMessages().catch(() => {
toast.error("Failed to load thread history.");
});
} else if (runs.data.length > prevLength) {
// New runs added (e.g., after query invalidation): adjust indexRef
// so the user can load older history by scrolling up
indexRef.current = adjustHistoryIndex(
indexRef.current,
prevLength,
runs.data.length,
);
}
} }
loadMessages().catch(() => {
toast.error("Failed to load thread history.");
});
}, [threadId, runs.data, loadMessages]); }, [threadId, runs.data, loadMessages]);
const appendMessages = useCallback((_messages: Message[]) => { const appendMessages = useCallback((_messages: Message[]) => {
@@ -1,136 +0,0 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import {
adjustHistoryIndex,
deduplicateHistoryMessages,
} from "@/core/threads/history-utils";
// ---------------------------------------------------------------------------
// deduplicateHistoryMessages
// ---------------------------------------------------------------------------
test("returns all incoming messages when existing history is empty", () => {
const existing: Message[] = [];
const incoming: Message[] = [
{ type: "human", id: "m1", content: "hello" },
{ type: "ai", id: "m2", content: "hi" },
];
const result = deduplicateHistoryMessages(existing, incoming);
expect(result).toHaveLength(2);
expect(result.map((m) => m.id)).toEqual(["m1", "m2"]);
});
test("filters out messages whose id already exists in history", () => {
const existing: Message[] = [
{ type: "human", id: "m1", content: "hello" },
{ type: "ai", id: "m2", content: "hi" },
];
const incoming: Message[] = [
{ type: "human", id: "m1", content: "hello" }, // duplicate
{ type: "ai", id: "m3", content: "new" },
];
const result = deduplicateHistoryMessages(existing, incoming);
expect(result).toHaveLength(1);
expect(result[0]!.id).toBe("m3");
});
test("filters out tool messages by tool_call_id", () => {
const existing: Message[] = [
{
type: "tool",
id: "t1",
tool_call_id: "tc-1",
content: "tool result",
name: "search",
} as unknown as Message,
];
const incoming: Message[] = [
{
type: "tool",
id: "t1-dup",
tool_call_id: "tc-1",
content: "tool result",
name: "search",
} as unknown as Message,
{
type: "tool",
id: "t2",
tool_call_id: "tc-2",
content: "other result",
name: "search",
} as unknown as Message,
];
const result = deduplicateHistoryMessages(existing, incoming);
expect(result).toHaveLength(1);
expect(result[0]!.id).toBe("t2");
});
test("keeps messages with no id or tool_call_id", () => {
const existing: Message[] = [
{ type: "human", id: "m1", content: "existing" },
];
const incoming: Message[] = [
// Message without id — should be kept (not considered a duplicate)
{ type: "ai", content: "no id" } as Message,
];
const result = deduplicateHistoryMessages(existing, incoming);
expect(result).toHaveLength(1);
});
test("deduplicates against tool_call_id from existing messages", () => {
// Existing message has tool_call_id stored in the id set
const existing: Message[] = [
{
type: "tool",
id: "t0",
tool_call_id: "tc-x",
content: "result",
name: "tool",
} as unknown as Message,
];
// Incoming AI message references the same id — should be filtered
const incoming: Message[] = [{ type: "ai", id: "tc-x", content: "response" }];
const result = deduplicateHistoryMessages(existing, incoming);
expect(result).toHaveLength(0);
});
// ---------------------------------------------------------------------------
// adjustHistoryIndex
// ---------------------------------------------------------------------------
test("returns unchanged index when no new runs were added", () => {
expect(adjustHistoryIndex(2, 5, 5)).toBe(2);
expect(adjustHistoryIndex(-1, 3, 3)).toBe(-1);
expect(adjustHistoryIndex(0, 1, 0)).toBe(0); // shouldn't happen, but safe
});
test("resets to last run when all previous runs were loaded", () => {
// 3 runs existed, all loaded (index = -1), now 5 runs
const result = adjustHistoryIndex(-1, 3, 5);
expect(result).toBe(4); // last index of new runs list
});
test("shifts index by number of added runs when some are unloaded", () => {
// 3 runs, currently at index 1 (run at index 2 loaded), now 6 runs
const result = adjustHistoryIndex(1, 3, 6);
// 3 new runs added, shift: 1 + (6 - 3) = 4
expect(result).toBe(4);
});
test("handles single new run when all previous were loaded", () => {
// 4 runs, all loaded (index = -1), now 5 runs
const result = adjustHistoryIndex(-1, 4, 5);
expect(result).toBe(4);
});
test("handles transition from empty runs to populated", () => {
// 0 runs → 3 runs, all loaded (index = -1)
const result = adjustHistoryIndex(-1, 0, 3);
expect(result).toBe(2);
});