Compare commits

..

19 Commits

Author SHA1 Message Date
Willem Jiang 7752e74e2b update the code with review comments 2026-05-16 17:07:53 +08:00
Willem Jiang ba99a23814 Merge branch 'main' into fix-2804 2026-05-16 09:27:40 +08:00
Willem Jiang 6d611c2bf6 fix(auth): persist auto-generated JWT secret to survive restarts (#2933)
* fix(auth): persist auto-generated JWT secret to survive restarts

  When AUTH_JWT_SECRET is not set, the auto-generated secret is now
  written to .deer-flow/.jwt_secret (mode 0600) and reused on subsequent
  starts. This prevents session invalidation on every restart while still
  allowing explicit AUTH_JWT_SECRET in .env to take precedence.

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix the lint errors of backend

---------

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-16 09:24:40 +08:00
Nan Gao 6d3cffb4f0 fix(frontend): deduplicate restored thread messages (#2958)
* fix(frontend): fix duplicate messages when reopening agent sessions (#2957)

* make format

* fix(frontend): retry pending thread history loads
2026-05-16 08:48:19 +08:00
Yi Tang 48e038f752 feat(channels): enhance Discord with mention-only mode, thread routing, and typing indicators (#2842)
* feat(channels): enhance Discord with mention-only mode, thread routing, and typing indicators

Add mention_only config to only respond when bot is mentioned, with
allowed_channels override. Add thread_mode for Hermes-style auto-thread
creation. Add periodic typing indicators while bot is processing.

* fix(discord): include allowed_channels in mention_only skip condition (line 274)

* docs: fix Discord config example to match boolean thread_mode implementation

* style: format with ruff

* fix(discord): apply Copilot review fixes and resolve lint errors

- Remove unused Optional import
- Fix thread_ts type hints to str | None
- Fix has_mention logic for None values
- Implement thread_mode fallback to channel replies on thread creation failure
- Fix thread_mode docstring alignment
- Fix allowed_channels comment formatting in config.example.yaml

* fix(discord): reset context for orphaned threads in mention_only mode

When a message arrives in a thread not tracked by _active_threads,
clear thread_id and typing_target so the message falls through to
the standard channel handling pipeline, which creates a fresh thread
instead of incorrectly routing to the stale thread.

* fix(discord): create new thread on @ when channel has existing tracked thread

When mention_only is enabled and a user @-s the bot in a channel
that already has a tracked thread, create a new thread instead of
incorrectly routing to the old one.

* fix(discord): allow no-@ thread replies while skipping no-@ channel messages

The skip block for no-@ messages was too aggressive — it blocked
continuation replies within tracked threads AND incorrectly routed
no-@ channel messages to the existing thread.

Now:
- Thread message, no @ → routed to existing tracked thread
- Channel message, no @ → skipped
- Channel message, with @ → creates new thread

* feat(discord): add checkmark reaction to acknowledge received messages

* Move discord.py to optional dependency and auto-detect from config.yaml

- Add discord extra to [project.optional-dependencies] in pyproject.toml
- Update detect_uv_extras.py to map channels.discord.enabled: true -> --extra discord
- Set UV_EXTRAS=discord in docker-compose-dev.yaml gateway env

* fix(discord): persist thread-channel mappings to store for recovery after restart

Discord's _active_threads dict was purely in-memory, so all channel-to-thread
mappings were lost on server restart. This fix bridges ChannelStore into
DiscordChannel:

- Save thread mappings to store.json after every thread creation
- Restore active threads from store on DiscordChannel startup
- Pass channel_store to all channels via service.py config injection

Store keys follow the pattern: discord:<channel_id>:<thread_id>

* fix(discord): address Copilot review — fix types, typing targets, cross-thread safety, and config comments

* fix(tests): add multitask_strategy param to mock for clarification follow-up test

* fix(tests): explicitly set model_name=None for title middleware test isolation

* fix(discord): use trigger_typing() instead of typing() for typing indicators

discord.py 2.x TextChannel.typing() and Thread.typing() are async context
managers, not one-shot coroutines. Use trigger_typing() for periodic
typing indicator pings.

* fix(discord): cancel typing tasks on channel shutdown

Prevents 'Task was destroyed but it is pending' warnings when the
Discord client stops while typing indicator loops are still running.

* fix(scripts): detect nested YAML config for discord extra

section_value() only matched top-level YAML sections. Added
nested_section_value() that handles two-level nesting (e.g.,
channels.discord.enabled), so auto-detection of the discord
extra works when config uses the standard nested format.

* fix(docker): remove hard-coded UV_EXTRAS=discord from dev compose

Relies on auto-detection via detect_uv_extras.py instead of forcing
discord.py install even when channels.discord.enabled is false.
Matches production docker-compose.yaml behavior (UV_EXTRAS:-).

* refactor(nginx): move proxy_buffering/proxy_cache to server level

DRY cleanup — these directives were repeated in 14 location blocks.
Set at server level once, reducing duplication and risk of drift.

* fix(discord): use dedicated JSON file for thread persistence

Replace ChannelStore usage for Discord thread-ID persistence with a
dedicated discord_threads.json file. ChannelStore is designed to map
IM conversations to DeerFlow thread IDs — using it to persist Discord
thread IDs was semantically wrong and confusing.

Changes:
- _save_thread() now reads/writes a simple {channel_id: thread_id} JSON dict
- _load_active_threads() reads directly from the JSON file
- File path derived from ChannelStore directory (when available) or
  defaults to ~/.deer-flow/channels/discord_threads.json
- Removed unused ChannelStore import

* fix(discord): address WillemJiang's code review comments on PR #2842

1. Remove semantically incorrect message_in_thread variable. At this code
   point (after the Thread case is handled above), we're guaranteed to be in
   a channel, not a thread. Always apply mention_only check here.

2. Add _active_thread_ids reverse-lookup set for O(1) thread ID membership
   checks instead of O(n) scan of _active_threads.values(). Keep the set
   in sync with _active_threads in _load_active_threads() and _save_thread().

3. Add _thread_store_lock (threading.Lock) to protect _active_threads and
   the JSON file from concurrent access between the Discord loop thread
   (_run_client) and the main thread (_load_active_threads, _save_thread).
2026-05-15 22:30:05 +08:00
Admire 7c42ab3e16 fix(frontend): wait for async chat submit before clearing (#2940)
* fix(frontend): wait for async chat submit before clearing

* test(frontend): cover pending attachment uploads

* fix(frontend): preserve sync submit semantics
2026-05-15 22:27:10 +08:00
Hinotobi 7a2670eaea fix(gateway): cap skill artifact preview size (#2963) 2026-05-15 22:15:58 +08:00
Nan Gao 0c37509b38 fix(middleware): Prevent todo completion reminder IMMessage leak (#2907)
* fix(middleware): Prevent todo completion reminder IMMessage leak (#2892)

* make format

* fix(middleware): Clear stale todo reminder counts (#2892)

* add size guard for _completion_reminder_counts and add a integration test
2026-05-15 22:12:37 +08:00
LawranceLiao 181d836541 fix(middleware): normalize tool result adjacency before model calls (#2939)
* normalizing tool-call transcripts before invocation

* test(middleware): cover tool result regrouping edge cases
2026-05-15 22:09:04 +08:00
Nan Gao 45060a9ffc fix(runtime): avoid postgres aggregate row lock (#2962) 2026-05-15 10:32:09 +08:00
LawranceLiao 722c690f4f fix(memory): isolate queued memory updates by agent (#2941)
* fix(memory): isolate queued memory updates by agent

* fix(memory): include user in queue identity

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-15 10:26:35 +08:00
dependabot[bot] ba864112a3 chore(deps): bump langsmith from 0.7.36 to 0.8.0 in /backend (#2943)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.7.36 to 0.8.0.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.7.36...v0.8.0)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.8.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-14 11:02:58 +08:00
AochenShen99 6e8e6a969b test: add blocking IO detector (#2924)
* test: add blocking IO detector

* test: add blocking IO probe option

* test: harden blocking IO probe lifecycle

* test: move blocking io detector to support
2026-05-13 23:56:06 +08:00
YuJitang eab7ae3d62 feat: stream subagent token usage to header via terminal task events (#2882)
* feat: real-time subagent token usage display in header and per-turn

Backend:
- Persist subagent token usage to AIMessage.usage_metadata via
  TokenUsageMiddleware, so accumulateUsage() naturally includes
  subagent tokens without frontend state management
- Cache subagent usage by tool_call_id in task_tool, write back
  to the dispatching AIMessage on next model response
- Emit subagent token usage on all terminal task events
  (task_completed, task_failed, task_cancelled, task_timed_out)
- Report subagent usage to parent RunJournal for API totals
- Search backward from ToolMessage to find dispatching AIMessage
  for correct multi-tool-call attribution

Frontend:
- Remove subagentUsage state, custom event handling, and prop
  threading — subagent tokens are now embedded in message metadata
- Simplify selectHeaderTokenUsage (no subagentUsage parameter)
- Per-turn inline badges show turn-specific usage via message
  accumulation
- Remove isLoading guard from MessageTokenUsageList for dynamic
  updates during streaming

* fix: prevent header token double counting from baseline reset race

onFinish, onError, and thread-switch useEffect all reset
pendingUsageBaselineMessageIdsRef to an empty Set. If
thread.isLoading is still true on the next render, all messages
pass the getMessagesAfterBaseline filter and their tokens are
added to backendUsage (which already includes them), causing
the header to display up to 2× the actual token count.

Capture current message IDs instead of using an empty Set so
that getMessagesAfterBaseline correctly returns no pending
messages even if thread.isLoading lags behind the stream end.

* fix: write back subagent tokens for all concurrent task tool calls

TokenUsageMiddleware only processed messages[-2], so when a
single model response dispatched multiple task tool calls only
the last ToolMessage had its cached subagent usage written back
to the dispatch AIMessage.usage_metadata. Earlier tasks' usage
stayed in _subagent_usage_cache indefinitely (leak) and never
appeared in the per-turn inline token display.

Walk backward through all consecutive ToolMessages before the
new AIMessage, and accumulate updates targeting the same
dispatch message into one state update so overlapping writes
don't clobber each other.

* fix: clean up subagent usage cache entry on task cancellation

When a task_tool invocation is cancelled via CancelledError, any
cached subagent usage entry leaked because the TokenUsageMiddleware
writeback path never fires after cancellation. Pop the cache entry
before re-raising to prevent unbounded growth of the module-level
_subagent_usage_cache dict.

* fix: address token usage review feedback

* fix: handle missing config for subagent usage cache

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-13 23:52:19 +08:00
Xinmin Zeng f1a0ab699a fix(tools): preserve tool_search promotions across re-entrant get_available_tools (#2885)
* fix(tools): preserve tool_search promotions across re-entrant get_available_tools

Closes #2884.

``get_available_tools`` used to unconditionally call
``reset_deferred_registry()`` and rebuild a fresh ``DeferredToolRegistry``
on every invocation. That works for the first call of a request (the
ContextVar starts at its default of ``None``), but any RE-ENTRANT call
during the same async context — e.g. ``task_tool`` building a subagent's
toolset, or a custom middleware that rebuilds tools mid-run — wiped any
``tool_search`` promotions the parent agent had already made. The
``DeferredToolFilterMiddleware`` would then re-hide those tools from the
next model call, leaving the agent able to see a tool's name (via the
prior ``tool_search`` result that's still in conversation history) but
unable to invoke it.

Fix: when the ContextVar already holds a registry, reuse it instead of
rebuilding. Fresh requests still get a fresh registry because each new
graph run starts in a new asyncio task with the ContextVar at ``None``.

## Verification

- Unit-level reproduction (``test_get_available_tools_resets_registry_wiping_promotion``):
  promote a tool in the registry, call ``get_available_tools`` again, assert
  the promotion is preserved. Fails on main, passes on this branch.

- Graph-execution reproduction (two tests): drive a real
  ``langchain.agents.create_agent`` graph with the real
  ``DeferredToolFilterMiddleware`` through two model turns, including one
  that issues a re-entrant ``get_available_tools`` call to simulate the
  task_tool subagent path.

- Real-LLM end-to-end (``test_deferred_tool_promotion_real_llm.py``,
  opt-in via ``ONEAPI_E2E=1``): drives the same flow against a real
  OpenAI-compatible model (verified on GPT-5.4-mini through the one-api
  gateway), watches the model call the promoted ``fake_calculator``
  through the deferred-filter middleware, and asserts the right arithmetic
  result. Passes against the fixed branch.

- Companion update to ``test_tool_deduplication.py``: dropped the
  ``@patch("deerflow.tools.tools.reset_deferred_registry")`` decorators
  because the symbol is no longer imported there.

- Test fixtures in the new files patch ``deerflow.tools.tools.get_app_config``
  with a minimal ``model_construct``-ed ``AppConfig`` instead of calling
  the real loader, so they never trigger ``_apply_singleton_configs`` and
  never leak ``_memory_config``/``_title_config``/… mutations into the
  rest of the suite.

Full backend suite: 3208 passed / 14 skipped / 0 failed. ruff check + format clean.

* fix(tools): address Copilot review on #2885

- tools.py: rewrite the reuse-path comment to spell out (a) why we don't
  reconcile the registry against the current ``mcp_tools`` snapshot — the
  MCP cache doesn't refresh mid-graph-run, the lead agent's ``ToolNode``
  is already bound to the previous tool set anyway, and ``promote()``
  drops the entry so a naive re-sync misclassifies promotions as new
  tools — and (b) why the log uses ``max(0, …)`` to avoid negative
  counts when the cache shrinks between snapshots.
- Replace direct ``ts_mod._registry_var.set(None)`` in test fixtures with
  the public ``reset_deferred_registry()`` helper so tests don't couple
  to module internals.
- Correct the docstring path in ``test_deferred_tool_registry_promotion.py``
  to match the actual monkeypatch target (``deerflow.mcp.cache.get_cached_mcp_tools``).
- Rename
  ``test_get_available_tools_resets_registry_wiping_promotion`` to
  ``test_get_available_tools_preserves_promotions_across_reentrant_calls``
  so the test name describes the contract being asserted, not the bug it
  originally reproduced.

Full backend suite: 3208 passed / 14 skipped. Real-LLM e2e: 1 passed.
2026-05-13 23:45:47 +08:00
Eilen Shin 2a1ac06bf4 fix(persistence): reuse token usage model grouping expression (#2910) 2026-05-13 15:49:34 +08:00
Willem Jiang 2b2742c034 Merge branch 'main' into fix-2804 2026-05-12 15:53:28 +08:00
copilot-swe-agent[bot] 6ffe267d20 fix: use AIMessage content for summarization response parsing
Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/0b34cccd-f69a-4f68-bfa6-9a41256b63b5

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>
2026-05-11 14:48:45 +00:00
Willem Jiang c995c3a394 fix: hide summarization LLM output from frontend during context compression
Fixes #2804

  The useStream hook tracks messages-tuple mode which captures ALL LLM
  events, including the internal summarization LLM call. This caused the
  English summary text to briefly appear as an AI message in the UI before
  the REMOVE_ALL_MESSAGES state update replaced it.

  Fix by overriding _create_summary/_acreate_summary to pass callbacks=[]
  when invoking the summary model, preventing LangGraph from forwarding
  the internal LLM events to the frontend stream. Also add
  hide_from_ui=True to the summary HumanMessage's additional_kwargs as a
  belt-and-suspenders safety net alongside the existing name=summary
  check.
2026-05-11 22:19:47 +08:00
63 changed files with 3698 additions and 292 deletions
+1 -1
View File
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
+1 -1
View File
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
+291 -11
View File
@@ -3,8 +3,10 @@
from __future__ import annotations
import asyncio
import json
import logging
import threading
from pathlib import Path
from typing import Any
from app.channels.base import Channel
@@ -21,6 +23,12 @@ class DiscordChannel(Channel):
Configuration keys (in ``config.yaml`` under ``channels.discord``):
- ``bot_token``: Discord Bot token.
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
(even when mention_only is true). Use for channels where you want the bot to respond
without mentions. Empty = mention_only applies everywhere.
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
Default: same as ``mention_only``.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
@@ -32,6 +40,29 @@ class DiscordChannel(Channel):
self._allowed_guilds.add(int(guild_id))
except (TypeError, ValueError):
continue
self._mention_only: bool = bool(config.get("mention_only", False))
self._thread_mode: bool = config.get("thread_mode", self._mention_only)
self._allowed_channels: set[str] = set()
for channel_id in config.get("allowed_channels", []):
self._allowed_channels.add(str(channel_id))
# Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
# Uses a dedicated JSON file separate from ChannelStore, which maps IM
# conversations to DeerFlow thread IDs — a different concern.
self._active_threads: dict[str, str] = {}
# Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
self._active_thread_ids: set[str] = set()
# Lock protecting _active_threads and the JSON file from concurrent access.
# _run_client (Discord loop thread) and the main thread both read/write.
self._thread_store_lock = threading.Lock()
store = config.get("channel_store")
if store is not None:
self._thread_store_path = store._path.parent / "discord_threads.json"
else:
self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"
# Typing indicator management
self._typing_tasks: dict[str, asyncio.Task] = {}
self._client = None
self._thread: threading.Thread | None = None
@@ -75,12 +106,56 @@ class DiscordChannel(Channel):
self._thread = threading.Thread(target=self._run_client, daemon=True)
self._thread.start()
self._load_active_threads()
logger.info("Discord channel started")
def _load_active_threads(self) -> None:
"""Restore Discord thread mappings from the dedicated JSON file on startup."""
with self._thread_store_lock:
try:
if not self._thread_store_path.exists():
logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
return
data = json.loads(self._thread_store_path.read_text())
self._active_threads.clear()
self._active_thread_ids.clear()
for channel_id, thread_id in data.items():
self._active_threads[channel_id] = thread_id
self._active_thread_ids.add(thread_id)
if self._active_threads:
logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
except Exception:
logger.exception("[Discord] failed to load thread mappings")
def _save_thread(self, channel_id: str, thread_id: str) -> None:
"""Persist a Discord thread mapping to the dedicated JSON file."""
with self._thread_store_lock:
try:
data: dict[str, str] = {}
if self._thread_store_path.exists():
data = json.loads(self._thread_store_path.read_text())
old_id = data.get(channel_id)
data[channel_id] = thread_id
# Update reverse-lookup set
if old_id:
self._active_thread_ids.discard(old_id)
self._active_thread_ids.add(thread_id)
self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
self._thread_store_path.write_text(json.dumps(data, indent=2))
except Exception:
logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
# Cancel all active typing indicator tasks
for target_id, task in list(self._typing_tasks.items()):
if not task.done():
task.cancel()
logger.debug("[Discord] cancelled typing task for target %s", target_id)
self._typing_tasks.clear()
if self._client and self._discord_loop and self._discord_loop.is_running():
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
try:
@@ -100,6 +175,10 @@ class DiscordChannel(Channel):
logger.info("Discord channel stopped")
async def send(self, msg: OutboundMessage) -> None:
# Stop typing indicator once we're sending the response
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
await asyncio.wrap_future(stop_future)
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
@@ -111,6 +190,9 @@ class DiscordChannel(Channel):
await asyncio.wrap_future(send_future)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
await asyncio.wrap_future(stop_future)
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
@@ -130,6 +212,41 @@ class DiscordChannel(Channel):
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
return False
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
"""Starts a loop to send periodic typing indicators."""
target_id = thread_ts or chat_id
if target_id in self._typing_tasks:
return # Already typing for this target
async def _typing_loop():
try:
while True:
try:
await channel.trigger_typing()
except Exception:
pass
await asyncio.sleep(10)
except asyncio.CancelledError:
pass
task = asyncio.create_task(_typing_loop())
self._typing_tasks[target_id] = task
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
"""Stops the typing loop for a specific target."""
target_id = thread_ts or chat_id
task = self._typing_tasks.pop(target_id, None)
if task and not task.done():
task.cancel()
logger.debug("[Discord] stopped typing indicator for target %s", target_id)
async def _add_reaction(self, message) -> None:
"""Add a checkmark reaction to acknowledge the message was received."""
try:
await message.add_reaction("")
except Exception:
logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)
async def _on_message(self, message) -> None:
if not self._running or not self._client:
return
@@ -152,15 +269,143 @@ class DiscordChannel(Channel):
if self._discord_module is None:
return
if isinstance(message.channel, self._discord_module.Thread):
chat_id = str(message.channel.parent_id or message.channel.id)
thread_id = str(message.channel.id)
# Determine whether the bot is mentioned in this message
user = self._client.user if self._client else None
if user:
bot_mention = user.mention # <@ID>
alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant)
standard_mention = f"<@{user.id}>"
else:
thread = await self._create_thread(message)
if thread is None:
bot_mention = None
alt_mention = None
standard_mention = ""
has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content)
# Strip mention from text for processing
if has_mention:
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
# --- Determine thread/channel routing and typing target ---
thread_id = None
chat_id = None
typing_target = None # The Discord object to type into
if isinstance(message.channel, self._discord_module.Thread):
# --- Message already inside a thread ---
thread_obj = message.channel
thread_id = str(thread_obj.id)
chat_id = str(thread_obj.parent_id or thread_obj.id)
typing_target = thread_obj
# If this is a known active thread, process normally
if thread_id in self._active_thread_ids:
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
text=text,
msg_type=msg_type,
thread_ts=thread_id,
metadata={
"guild_id": str(guild.id) if guild else None,
"channel_id": str(message.channel.id),
"message_id": str(message.id),
},
)
inbound.topic_id = thread_id
self._publish(inbound)
# Start typing indicator in the thread
if typing_target:
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
asyncio.create_task(self._add_reaction(message))
return
chat_id = str(message.channel.id)
thread_id = str(thread.id)
# Thread not tracked (orphaned) — create new thread and handle below
logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)
thread_id = None
typing_target = None
# At this point we're guaranteed to be in a channel, not a thread
# (the Thread case is handled above). Apply mention_only for all
# non-thread messages — no special case needed.
channel_id = str(message.channel.id)
# Check if there's an active thread for this channel
if channel_id in self._active_threads:
# respect mention_only: if enabled, only process messages that mention the bot
# (unless the channel is in allowed_channels)
# Messages within a thread are always allowed through (continuation).
# At this code point we know the message is in a channel, not a thread
# (Thread case handled above), so always apply the check.
if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
return
# mention_only + fresh @ → create new thread instead of routing to existing one
if self._mention_only and has_mention:
thread_obj = await self._create_thread(message)
if thread_obj is not None:
target_thread_id = str(thread_obj.id)
self._active_threads[channel_id] = target_thread_id
self._save_thread(channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = thread_obj
logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id)
else:
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel
else:
# Existing session → route to the existing thread
target_thread_id = self._active_threads[channel_id]
logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = await self._get_channel_or_thread(target_thread_id)
elif self._mention_only and not has_mention and channel_id not in self._allowed_channels:
# Not mentioned and not in an allowed channel → skip
logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
return
elif self._mention_only and has_mention:
# First mention in this channel → create thread
thread_obj = await self._create_thread(message)
if thread_obj is not None:
target_thread_id = str(thread_obj.id)
self._active_threads[channel_id] = target_thread_id
self._save_thread(channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = thread_obj # Type into the new thread
logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name)
else:
# Fallback: thread creation failed (disabled/permissions), reply in channel
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel # Type into the channel
elif self._thread_mode:
# thread_mode but mention_only is False → create thread anyway for conversation grouping
thread_obj = await self._create_thread(message)
if thread_obj is None:
# Thread creation failed (disabled/permissions), fall back to channel replies
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel # Type into the channel
else:
target_thread_id = str(thread_obj.id)
self._active_threads[channel_id] = target_thread_id
self._save_thread(channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = thread_obj # Type into the new thread
else:
# No threading — reply directly in channel
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel # Type into the channel
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
@@ -177,6 +422,15 @@ class DiscordChannel(Channel):
)
inbound.topic_id = thread_id
# Start typing indicator in the correct target (thread or channel)
if typing_target:
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
self._publish(inbound)
asyncio.create_task(self._add_reaction(message))
def _publish(self, inbound) -> None:
"""Publish an inbound message to the main event loop."""
if self._main_loop and self._main_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
@@ -198,14 +452,40 @@ class DiscordChannel(Channel):
async def _create_thread(self, message):
try:
if self._discord_module is None:
return None
# Only TextChannel (type 0) and NewsChannel (type 10) support threads
channel_type = message.channel.type
if channel_type not in (
self._discord_module.ChannelType.text,
self._discord_module.ChannelType.news,
):
logger.info(
"[Discord] channel type %s (%s) does not support threads",
channel_type.value,
channel_type.name,
)
return None
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
return await message.create_thread(name=thread_name)
except self._discord_module.errors.HTTPException as exc:
if exc.code == 50024:
logger.info(
"[Discord] cannot create thread in channel %s (error code 50024): %s",
message.channel.id,
channel_type.name if (channel_type := message.channel.type) else "unknown",
)
else:
logger.exception(
"[Discord] failed to create thread for message=%s (HTTPException %s)",
message.id,
exc.code,
)
return None
except Exception:
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
try:
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
except Exception:
pass
return None
async def _resolve_target(self, msg: OutboundMessage):
+16 -7
View File
@@ -787,13 +787,22 @@ class ChannelManager:
return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
)
try:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
multitask_strategy="reject",
)
except Exception as exc:
if _is_thread_busy_error(exc):
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
await self._send_error(msg, THREAD_BUSY_MESSAGE)
return
else:
raise
response_text = _extract_response_text(result)
artifacts = _extract_artifacts(result)
+2
View File
@@ -167,6 +167,8 @@ class ChannelService:
return False
try:
config = dict(config)
config["channel_store"] = self.store
channel = channel_cls(bus=self.bus, config=config)
self._channels[name] = channel
await channel.start()
+31 -3
View File
@@ -8,6 +8,8 @@ from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
_SECRET_FILE = ".jwt_secret"
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.
@@ -30,6 +32,32 @@ class AuthConfig(BaseModel):
_auth_config: AuthConfig | None = None
def _load_or_create_secret() -> str:
"""Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one."""
from deerflow.config.paths import get_paths
paths = get_paths()
secret_file = paths.base_dir / _SECRET_FILE
try:
if secret_file.exists():
secret = secret_file.read_text(encoding="utf-8").strip()
if secret:
return secret
except OSError as exc:
raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc
secret = secrets.token_urlsafe(32)
try:
secret_file.parent.mkdir(parents=True, exist_ok=True)
fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(secret)
except OSError as exc:
raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
return secret
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
@@ -39,11 +67,11 @@ def get_auth_config() -> AuthConfig:
load_dotenv()
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
jwt_secret = _load_or_create_secret()
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
"persisted to .jwt_secret. Sessions will survive restarts. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
+24 -5
View File
@@ -20,6 +20,9 @@ ACTIVE_CONTENT_MIME_TYPES = {
"image/svg+xml",
}
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
def _build_content_disposition(disposition_type: str, filename: str) -> str:
"""Build an RFC 5987 encoded Content-Disposition header value."""
@@ -44,6 +47,22 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
return False
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
"""Read a .skill archive member while enforcing an uncompressed size cap."""
if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
chunks: list[bytes] = []
total_read = 0
with zip_ref.open(info, "r") as src:
while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE):
total_read += len(chunk)
if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
chunks.append(chunk)
return b"".join(chunks)
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
"""Extract a file from a .skill ZIP archive.
@@ -60,16 +79,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# List all files in the archive
namelist = zip_ref.namelist()
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
# Try direct path first
if internal_path in namelist:
return zip_ref.read(internal_path)
if internal_path in infos_by_name:
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
for name in namelist:
for name, info in infos_by_name.items():
if name.endswith("/" + internal_path) or name == internal_path:
return zip_ref.read(name)
return _read_skill_archive_member(zip_ref, info)
# Not found
return None
+2 -2
View File
@@ -99,7 +99,7 @@ rm -f backend/.deer-flow/data/deerflow.db
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效 |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持 |
### 生产环境建议
@@ -137,4 +137,4 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 |
@@ -40,6 +40,15 @@ class MemoryUpdateQueue:
self._timer: threading.Timer | None = None
self._processing = False
@staticmethod
def _queue_key(
thread_id: str,
user_id: str | None,
agent_name: str | None,
) -> tuple[str, str | None, str | None]:
"""Return the debounce identity for a memory update target."""
return (thread_id, user_id, agent_name)
def add(
self,
thread_id: str,
@@ -115,8 +124,9 @@ class MemoryUpdateQueue:
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
queue_key = self._queue_key(thread_id, user_id, agent_name)
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
@@ -130,7 +140,7 @@ class MemoryUpdateQueue:
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
self._queue.append(context)
def _reset_timer(self) -> None:
@@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import resolve_runtime_user_id
def memory_flush_hook(event: SummarizationEvent) -> None:
@@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
user_id = resolve_runtime_user_id(event.runtime)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
return "[Tool call was interrupted and did not return a result.]"
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
"""Return messages with tool results grouped after their tool-call AIMessage.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged.
"""
# Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set()
tool_messages_by_id: dict[str, ToolMessage] = {}
for msg in messages:
if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id)
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
# Check if any patching is needed
needs_patch = False
tool_call_ids: set[str] = set()
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
break
if needs_patch:
break
if tc_id:
tool_call_ids.add(tc_id)
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
patched_ids: set[str] = set()
consumed_tool_msg_ids: set[str] = set()
patch_count = 0
for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
continue
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
if not tc_id or tc_id in consumed_tool_msg_ids:
continue
existing_tool_msg = tool_messages_by_id.get(tc_id)
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else:
patched.append(
ToolMessage(
content=self._synthetic_tool_message_content(tc),
@@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error",
)
)
patched_ids.add(tc_id)
consumed_tool_msg_ids.add(tc_id)
patch_count += 1
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
if patched == messages:
return None
if patch_count:
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched
@override
@@ -10,6 +10,7 @@ from typing import Any, Protocol, override, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.messages.utils import get_buffer_string
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
@@ -175,12 +176,84 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
]
}
@override
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary without emitting streaming events to the client.
Suppresses callbacks to prevent the internal summarization LLM call from
producing visible AI message chunks in the frontend's ``messages-tuple``
stream (issue #2804).
"""
if not messages_to_summarize:
return "No previous conversation history."
trimmed = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed:
return "Previous conversation was too long to summarize."
formatted = get_buffer_string(trimmed)
try:
response = self.model.with_config(callbacks=[]).invoke(
self.summary_prompt.format(messages=formatted).rstrip(),
config={
"metadata": {"lc_source": "summarization"},
"callbacks": [],
},
)
return self._extract_summary_text(response)
except Exception as e:
return f"Error generating summary: {e!s}"
@override
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary without emitting streaming events to the client.
Suppresses callbacks to prevent the internal summarization LLM call from
producing visible AI message chunks in the frontend's ``messages-tuple``
stream (issue #2804).
"""
if not messages_to_summarize:
return "No previous conversation history."
trimmed = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed:
return "Previous conversation was too long to summarize."
formatted = get_buffer_string(trimmed)
try:
response = await self.model.with_config(callbacks=[]).ainvoke(
self.summary_prompt.format(messages=formatted).rstrip(),
config={
"metadata": {"lc_source": "summarization"},
"callbacks": [],
},
)
return self._extract_summary_text(response)
except Exception as e:
return f"Error generating summary: {e!s}"
def _extract_summary_text(self, response: Any) -> str:
# Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
# Fall back to .content for non-LangChain responses.
summary_text = getattr(response, "text", None)
if summary_text is None:
summary_text = getattr(response, "content", "")
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
@override
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
"""Override the base implementation to let the human message with the special name 'summary'.
And this message will be ignored to display in the frontend, but still can be used as context for the model.
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
return [
HumanMessage(
content=f"Here is a summary of the conversation to date:\n\n{summary}",
name="summary",
additional_kwargs={"hide_from_ui": True},
)
]
def _preserve_dynamic_context_reminders(
self,
@@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware injects a reminder
and jumps back to the model node to force continued engagement.
(no tool calls) but todos are not yet complete, the middleware queues a reminder
for the next model request and jumps back to the model node to force continued
engagement. The completion reminder is injected via ``wrap_model_call`` instead
of being persisted into graph state as a normal user-visible message.
"""
from __future__ import annotations
import threading
from collections.abc import Awaitable, Callable
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import hook_config
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
@@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str:
return "\n".join(lines)
def _format_completion_reminder(todos: list[Todo]) -> str:
"""Format a completion reminder for incomplete todo items."""
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
return (
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
)
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
"""Return True when an AIMessage is not a clean final answer.
Todo completion reminders should only fire when the model has produced a
plain final response. Provider/tool parsing details have moved across
LangChain versions and integrations, so keep all tool-intent/error signals
behind this helper instead of checking one concrete field at the call site.
"""
if message.tool_calls:
return True
if getattr(message, "invalid_tool_calls", None):
return True
# Backward/provider compatibility: some integrations preserve raw or legacy
# tool-call intent in additional_kwargs even when structured tool_calls is
# empty. If this helper changes, update the matching sentinel test
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
# if that test fails after a LangChain upgrade, review this helper so new
# tool-call/error fields are not silently treated as clean final answers.
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
return True
response_metadata = getattr(message, "response_metadata", {}) or {}
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
class TodoMiddleware(TodoListMiddleware):
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
@@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware):
formatted = _format_todos(todos)
reminder = HumanMessage(
name="todo_reminder",
additional_kwargs={"hide_from_ui": True},
content=(
"<system_reminder>\n"
"Your todo list from earlier is no longer visible in the current context window, "
@@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware):
# Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
_MAX_COMPLETION_REMINDER_KEYS = 4096
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._lock = threading.Lock()
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
self._completion_reminder_next_order = 0
@staticmethod
def _get_thread_id(runtime: Runtime) -> str:
context = getattr(runtime, "context", None)
thread_id = context.get("thread_id") if context else None
return str(thread_id) if thread_id else "default"
@staticmethod
def _get_run_id(runtime: Runtime) -> str:
context = getattr(runtime, "context", None)
run_id = context.get("run_id") if context else None
return str(run_id) if run_id else "default"
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
return self._get_thread_id(runtime), self._get_run_id(runtime)
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
self._completion_reminder_next_order += 1
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
keys = set(self._pending_completion_reminders)
keys.update(self._completion_reminder_counts)
keys.update(self._completion_reminder_touch_order)
return keys
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
self._pending_completion_reminders.pop(key, None)
self._completion_reminder_counts.pop(key, None)
self._completion_reminder_touch_order.pop(key, None)
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
keys = self._completion_reminder_keys_locked()
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
if overflow <= 0:
return
candidates = [key for key in keys if key != protected_key]
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
for key in candidates[:overflow]:
self._drop_completion_reminder_key_locked(key)
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
key = self._pending_key(runtime)
with self._lock:
self._pending_completion_reminders.setdefault(key, []).append(reminder)
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
self._touch_completion_reminder_key_locked(key)
self._prune_completion_reminder_state_locked(protected_key=key)
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
key = self._pending_key(runtime)
with self._lock:
return self._completion_reminder_counts.get(key, 0)
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
key = self._pending_key(runtime)
with self._lock:
reminders = self._pending_completion_reminders.pop(key, [])
if reminders or key in self._completion_reminder_counts:
self._touch_completion_reminder_key_locked(key)
return reminders
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
thread_id, current_run_id = self._pending_key(runtime)
with self._lock:
for key in self._completion_reminder_keys_locked():
if key[0] == thread_id and key[1] != current_run_id:
self._drop_completion_reminder_key_locked(key)
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
key = self._pending_key(runtime)
with self._lock:
self._drop_completion_reminder_key_locked(key)
@override
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@override
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@hook_config(can_jump_to=["model"])
@override
@@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware):
if base_result is not None:
return base_result
# 2. Only intervene when the agent wants to exit (no tool calls).
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
# intent or tool-call parse errors should be handled by the tool path
# instead of being masked by todo reminders.
messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or last_ai.tool_calls:
if not last_ai or _has_tool_call_intent_or_error(last_ai):
return None
# 3. Allow exit when all todos are completed or there are no todos.
@@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware):
return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
return None
# 5. Inject a reminder and force the agent back to the model.
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
reminder = HumanMessage(
name="todo_completion_reminder",
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
# 5. Queue a reminder for the next model request and jump back. We must
# not persist this control prompt as a normal HumanMessage, otherwise it
# can leak into user-visible message streams and saved transcripts.
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
return {"jump_to": "model"}
@override
@hook_config(can_jump_to=["model"])
@@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None:
"""Async version of after_model."""
return self.after_model(state, runtime)
@staticmethod
def _format_pending_completion_reminders(reminders: list[str]) -> str:
return "\n\n".join(dict.fromkeys(reminders))
def _augment_request(self, request: ModelRequest) -> ModelRequest:
reminders = self._drain_completion_reminders(request.runtime)
if not reminders:
return request
new_messages = [
*request.messages,
HumanMessage(
content=self._format_pending_completion_reminders(reminders),
name="todo_completion_reminder",
additional_kwargs={"hide_from_ui": True},
),
]
return request.override(messages=new_messages)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(self._augment_request(request))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._augment_request(request))
@override
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@override
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@@ -9,7 +9,7 @@ from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
return "thinking"
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
"""Return True if the AIMessage contains a tool_call with the given id."""
for tc in message.tool_calls or []:
if isinstance(tc, dict):
if tc.get("id") == tool_call_id:
return True
elif hasattr(tc, "id") and tc.id == tool_call_id:
return True
return False
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
if not messages:
return None
# Annotate subagent token usage onto the AIMessage that dispatched it.
# When a task tool completes, its usage is cached by tool_call_id. Detect
# the ToolMessage → search backward for the corresponding AIMessage → merge.
# Walk backward through consecutive ToolMessages before the new AIMessage
# so that multiple concurrent task tool calls all get their subagent tokens
# written back to the same dispatch message (merging into one update).
state_updates: dict[int, AIMessage] = {}
if len(messages) >= 2:
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
idx = len(messages) - 2
while idx >= 0:
tool_msg = messages[idx]
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
break
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
if subagent_usage:
# Search backward from the ToolMessage to find the AIMessage
# that dispatched it. A single model response can dispatch
# multiple task tool calls, so we can't assume a fixed offset.
dispatch_idx = idx - 1
while dispatch_idx >= 0:
candidate = messages[dispatch_idx]
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
# Accumulate into an existing update for the same
# AIMessage (multiple task calls in one response),
# or merge fresh from the original message.
existing_update = state_updates.get(dispatch_idx)
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
merged = {
**prev,
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
}
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
break
dispatch_idx -= 1
idx -= 1
last = messages[-1]
if not isinstance(last, AIMessage):
if state_updates:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return None
usage = getattr(last, "usage_metadata", None)
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
state_updates[len(messages) - 1] = updated_msg
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
@@ -223,10 +223,11 @@ class RunRepository(RunStore):
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
stmt = (
select(
func.coalesce(RunRow.model_name, "unknown").label("model"),
model_name.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
@@ -236,7 +237,7 @@ class RunRepository(RunStore):
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
)
.where(_thread, _completed)
.group_by(func.coalesce(RunRow.model_name, "unknown"))
.group_by(model_name)
)
async with self._sf() as session:
@@ -11,7 +11,7 @@ import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy import delete, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
@@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore):
user = get_current_user()
return str(user.id) if user is not None else None
@staticmethod
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
"""Return the current max seq while serializing writers per thread.
PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
results are not lockable rows. As a release-safe workaround, take a
transaction-level advisory lock keyed by thread_id before reading the
aggregate. Other dialects keep the existing row-locking statement.
"""
stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
bind = session.get_bind()
dialect_name = bind.dialect.name if bind is not None else ""
if dialect_name == "postgresql":
await session.execute(
text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
{"thread_id": thread_id},
)
return await session.scalar(stmt)
return await session.scalar(stmt.with_for_update())
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
@@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore):
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
max_seq = await self._max_seq_for_thread(session, thread_id)
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
@@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore):
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
max_seq = await self._max_seq_for_thread(session, thread_id)
seq = max_seq or 0
rows = []
for e in events:
@@ -26,6 +26,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
# write it back to the triggering AIMessage's usage_metadata.
_subagent_usage_cache: dict[str, dict[str, int]] = {}
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
if app_config is None:
try:
app_config = get_app_config()
except FileNotFoundError:
return False
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
if enabled and usage:
_subagent_usage_cache[tool_call_id] = usage
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
return _subagent_usage_cache.pop(tool_call_id, None)
def _is_subagent_terminal(result: Any) -> bool:
"""Return whether a background subagent result is safe to clean up."""
@@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
return None
def _summarize_usage(records: list[dict] | None) -> dict | None:
"""Summarize token usage records into a compact dict for SSE events."""
if not records:
return None
return {
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
}
def _report_subagent_usage(runtime: Any, result: Any) -> None:
"""Report subagent token usage to the parent RunJournal, if available.
@@ -177,6 +210,7 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
runtime_app_config = _get_runtime_app_config(runtime)
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration
@@ -312,27 +346,32 @@ async def task_tool(
last_message_count = current_message_count
# Check if task completed, failed, or timed out
usage = _summarize_usage(getattr(result, "token_usage_records", None))
if result.status == SubagentStatus.COMPLETED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id)
return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}"
@@ -351,7 +390,9 @@ async def task_tool(
timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id})
usage = _summarize_usage(getattr(result, "token_usage_records", None))
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively.
@@ -374,4 +415,8 @@ async def task_tool(
cleanup_background_task(task_id)
else:
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
_subagent_usage_cache.pop(tool_call_id, None)
raise
except Exception:
_subagent_usage_cache.pop(tool_call_id, None)
raise
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.builtins.tool_search import reset_deferred_registry
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
@@ -116,8 +116,6 @@ def get_available_tools(
# made through the Gateway API (which runs in a separate process) are immediately
# reflected when loading MCP tools.
mcp_tools = []
# Reset deferred registry upfront to prevent stale state from previous calls
reset_deferred_registry()
if include_mcp:
try:
from deerflow.config.extensions_config import ExtensionsConfig
@@ -135,12 +133,51 @@ def get_available_tools(
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
# Reuse the existing registry if one is already set for
# this async context. ``get_available_tools`` is
# re-entered whenever a subagent is spawned
# (``task_tool`` calls it to build the child agent's
# toolset), and previously we used to unconditionally
# rebuild the registry — wiping out the parent agent's
# tool_search promotions. The
# ``DeferredToolFilterMiddleware`` then re-hid those
# tools from subsequent model calls, leaving the agent
# able to see a tool's name but unable to invoke it
# (issue #2884). ``contextvars`` already gives us the
# lifetime semantics we want: a fresh request / graph
# run starts in a new asyncio task with the
# ContextVar at its default of ``None``, so reuse is
# only triggered for re-entrant calls inside one run.
#
# Intentionally NOT reconciling against the current
# ``mcp_tools`` snapshot. The MCP cache only refreshes
# on ``extensions_config.json`` mtime changes, which
# in practice happens between graph runs — not inside
# one. And even if a refresh did happen mid-run, the
# already-built lead agent's ``ToolNode`` still holds
# the *previous* tool set (LangGraph binds tools at
# graph construction time), so a brand-new MCP tool
# couldn't actually be invoked anyway. The
# ``DeferredToolRegistry`` doesn't retain the names
# of previously-promoted tools (``promote()`` drops
# the entry entirely), so re-syncing the registry
# against a fresh ``mcp_tools`` list would
# mis-classify those promotions as new tools and
# re-register them as deferred — exactly the bug
# this fix exists to prevent.
existing_registry = get_deferred_registry()
if existing_registry is None:
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
else:
mcp_tool_names = {t.name for t in mcp_tools}
still_deferred = len(existing_registry)
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
builtin_tools.append(tool_search_tool)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e:
+1
View File
@@ -25,6 +25,7 @@ dependencies = [
[project.optional-dependencies]
postgres = ["deerflow-harness[postgres]"]
discord = ["discord.py>=2.7.0"]
[dependency-groups]
dev = [
+93
View File
@@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
issues when unit-testing lightweight config/registry code in isolation.
"""
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
@@ -11,11 +13,16 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
# Make 'app' and 'deerflow' importable from any working directory
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
# Break the circular import chain that exists in production code:
# deerflow.subagents.__init__
# -> .executor (SubagentExecutor, SubagentResult)
@@ -56,6 +63,92 @@ def provisioner_module():
return module
@pytest.fixture()
def blocking_io_detector():
"""Fail a focused test if blocking calls run on the event loop thread."""
with detect_blocking_io(fail_on_exit=True) as detector:
yield detector
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("blocking-io")
group.addoption(
"--detect-blocking-io",
action="store_true",
default=False,
help="Collect blocking calls made while an asyncio event loop is running and report a summary.",
)
group.addoption(
"--detect-blocking-io-fail",
action="store_true",
default=False,
help="Set a failing exit status when --detect-blocking-io records violations.",
)
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe")
def pytest_sessionstart(session: pytest.Session) -> None:
if _blocking_io_probe_enabled(session.config):
_blocking_io_probe.clear()
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item):
yield
return
detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
detector.__enter__()
setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
yield
detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
if detector is None:
return
try:
detector.__exit__(None, None, None)
_blocking_io_probe.record(item.nodeid, detector.violations)
finally:
delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
def pytest_sessionfinish(session: pytest.Session) -> None:
if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK:
session.exitstatus = pytest.ExitCode.TESTS_FAILED
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
if not _blocking_io_probe_enabled(terminalreporter.config):
return
header, *details = _blocking_io_probe.format_summary().splitlines()
terminalreporter.write_sep("=", header)
for line in details:
terminalreporter.write_line(line)
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io-fail"))
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
# ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user
# ---------------------------------------------------------------------------
+1
View File
@@ -0,0 +1 @@
"""Shared test support helpers."""
@@ -0,0 +1 @@
"""Runtime and static detectors used by tests."""
@@ -0,0 +1,287 @@
"""Test helper for detecting blocking calls on an asyncio event loop.
The detector is intentionally test-only. It monkeypatches a small set of
well-known blocking entry points and their already-loaded module-level aliases,
then records calls only when they happen on a thread that is currently running
an asyncio event loop. Aliases captured in closures or default arguments remain
out of scope.
"""
from __future__ import annotations
import asyncio
import importlib
import sys
import traceback
from collections import Counter
from collections.abc import Callable, Iterable, Iterator
from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from types import TracebackType
from typing import Any
BlockingCallable = Callable[..., Any]
@dataclass(frozen=True)
class BlockingCallSpec:
"""Describes one blocking callable to wrap during a detector run."""
name: str
target: str
record_on_iteration: bool = False
@dataclass(frozen=True)
class BlockingCall:
"""One blocking call observed on an asyncio event loop thread."""
name: str
target: str
stack: tuple[traceback.FrameSummary, ...]
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
BlockingCallSpec("time.sleep", "time:sleep"),
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
)
def _is_event_loop_thread() -> bool:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return False
return loop.is_running()
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
module_name, attr_path = target.split(":", maxsplit=1)
owner: object = importlib.import_module(module_name)
parts = attr_path.split(".")
for part in parts[:-1]:
owner = getattr(owner, part)
attr_name = parts[-1]
original = getattr(owner, attr_name)
return owner, attr_name, original
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
return tuple(frame for frame in stack if frame.filename != __file__)
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
"""Record blocking calls made from async runtime code.
By default the detector reports violations but does not fail on context
exit. Tests can set ``fail_on_exit=True`` or call
``assert_no_blocking_calls()`` explicitly.
"""
def __init__(
self,
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> None:
self._specs = tuple(specs)
self._fail_on_exit = fail_on_exit
self._patch_loaded_aliases_enabled = patch_loaded_aliases
self._stack_limit = stack_limit
self._patches: list[tuple[object, str, BlockingCallable]] = []
self._patch_keys: set[tuple[int, str]] = set()
self.violations: list[BlockingCall] = []
self._active = False
def __enter__(self) -> BlockingIODetector:
try:
self._active = True
alias_replacements: dict[int, BlockingCallable] = {}
for spec in self._specs:
owner, attr_name, original = _resolve_target(spec.target)
wrapper = self._wrap(spec, original)
self._patch_attribute(owner, attr_name, original, wrapper)
alias_replacements[id(original)] = wrapper
if self._patch_loaded_aliases_enabled:
self._patch_loaded_module_aliases(alias_replacements)
except Exception:
self._restore()
self._active = False
raise
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback_value: TracebackType | None,
) -> bool | None:
self._restore()
self._active = False
if exc_type is None and self._fail_on_exit:
self.assert_no_blocking_calls()
return None
def _restore(self) -> None:
for owner, attr_name, original in reversed(self._patches):
setattr(owner, attr_name, original)
self._patches.clear()
self._patch_keys.clear()
def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
key = (id(owner), attr_name)
if key in self._patch_keys:
return
setattr(owner, attr_name, replacement)
self._patches.append((owner, attr_name, original))
self._patch_keys.add(key)
def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
for module in tuple(sys.modules.values()):
namespace = getattr(module, "__dict__", None)
if not isinstance(namespace, dict):
continue
for attr_name, value in tuple(namespace.items()):
replacement = replacements_by_id.get(id(value))
if replacement is not None:
self._patch_attribute(module, attr_name, value, replacement)
def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
@wraps(original)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if spec.record_on_iteration:
result = original(*args, **kwargs)
return self._wrap_iteration(spec, result)
self._record_if_blocking(spec)
return original(*args, **kwargs)
return wrapper
def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
iterator = iter(iterable)
reported = False
while True:
if not reported:
reported = self._record_if_blocking(spec)
try:
yield next(iterator)
except StopIteration:
return
def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
if self._active and _is_event_loop_thread():
stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
self.violations.append(BlockingCall(spec.name, spec.target, stack))
return True
return False
def assert_no_blocking_calls(self) -> None:
if self.violations:
raise AssertionError(format_blocking_calls(self.violations))
class BlockingIOProbe:
"""Collect detector output across tests and format a compact summary."""
def __init__(self, project_root: Path) -> None:
self._project_root = project_root.resolve()
self._observed: list[tuple[str, BlockingCall]] = []
@property
def violation_count(self) -> int:
return len(self._observed)
@property
def test_count(self) -> int:
return len({nodeid for nodeid, _violation in self._observed})
def clear(self) -> None:
self._observed.clear()
def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
for violation in violations:
self._observed.append((nodeid, violation))
def format_summary(self, *, limit: int = 30) -> str:
if not self._observed:
return "blocking io probe: no violations"
call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
for _nodeid, violation in self._observed:
frame = self._local_call_site(violation.stack)
if frame is None:
call_sites[(violation.name, "<unknown>", 0, "<unknown>", "")] += 1
continue
call_sites[
(
violation.name,
self._relative(frame.filename),
frame.lineno,
frame.name,
(frame.line or "").strip(),
)
] += 1
lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
return "\n".join(lines)
def _relative(self, filename: str) -> str:
try:
return str(Path(filename).resolve().relative_to(self._project_root))
except ValueError:
return filename
def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")]
if local_frames:
return local_frames[-1]
test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename]
return test_frames[-1] if test_frames else None
def detect_blocking_io(
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> BlockingIODetector:
"""Create a detector context manager for a focused test scope."""
return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
"""Format detector output with enough stack context to locate call sites."""
lines = ["Blocking calls were executed on an asyncio event loop thread:"]
for index, violation in enumerate(violations, start=1):
lines.append(f"{index}. {violation.name} ({violation.target})")
lines.extend(_format_stack(violation.stack))
return "\n".join(lines)
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
for frame in stack:
location = f"{frame.filename}:{frame.lineno}"
lines = [f" at {frame.name} ({location})"]
if frame.line:
lines.append(f" {frame.line.strip()}")
yield from lines
+15
View File
@@ -4,6 +4,7 @@ from pathlib import Path
import pytest
from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException
from fastapi.testclient import TestClient
from starlette.requests import Request
from starlette.responses import FileResponse
@@ -102,3 +103,17 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
assert response.status_code == 200
assert response.text == "hello"
assert response.headers.get("content-disposition", "").startswith("attachment;")
def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None:
skill_path = tmp_path / "sample.skill"
payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1)
with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref:
zip_ref.writestr("SKILL.md", payload)
assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES
with pytest.raises(HTTPException) as exc_info:
artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md")
assert exc_info.value.status_code == 413
+47 -11
View File
@@ -5,28 +5,26 @@ from unittest.mock import patch
import pytest
from app.gateway.auth.config import AuthConfig
import app.gateway.auth.config as cfg
def test_auth_config_defaults():
config = AuthConfig(jwt_secret="test-secret-key-123")
config = cfg.AuthConfig(jwt_secret="test-secret-key-123")
assert config.token_expiry_days == 7
def test_auth_config_token_expiry_range():
AuthConfig(jwt_secret="s", token_expiry_days=1)
AuthConfig(jwt_secret="s", token_expiry_days=30)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=1)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=30)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=0)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=0)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=31)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=31)
def test_auth_config_from_env():
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
with patch.dict(os.environ, env, clear=False):
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
@@ -36,19 +34,57 @@ def test_auth_config_from_env():
cfg._auth_config = old
def test_auth_config_missing_secret_generates_ephemeral(caplog):
def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog):
import logging
import app.gateway.auth.config as cfg
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
secret_file = tmp_path / ".jwt_secret"
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
assert secret_file.exists()
assert secret_file.read_text().strip() == config.jwt_secret
finally:
cfg._auth_config = old
def test_auth_config_reuses_persisted_secret(tmp_path):
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
persisted = "persisted-secret-from-file-min-32-chars!!"
(tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8")
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
config = cfg.get_auth_config()
assert config.jwt_secret == persisted
finally:
cfg._auth_config = old
def test_auth_config_empty_secret_file_generates_new(tmp_path):
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
(tmp_path / ".jwt_secret").write_text("", encoding="utf-8")
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
config = cfg.get_auth_config()
assert config.jwt_secret
assert len(config.jwt_secret) > 20
assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret
finally:
cfg._auth_config = old
+190
View File
@@ -0,0 +1,190 @@
from __future__ import annotations
import asyncio
import os
import time
from os import walk as imported_walk
from pathlib import Path
from time import sleep as imported_sleep
import httpx
import pytest
import requests
from support.detectors.blocking_io import (
BlockingCallSpec,
BlockingIOProbe,
detect_blocking_io,
)
pytestmark = pytest.mark.asyncio
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
async def test_records_time_sleep_on_event_loop() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
original_alias = imported_sleep
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
imported_sleep(0)
assert imported_sleep is original_alias
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_can_disable_loaded_alias_patching() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
imported_sleep(0)
assert detector.violations == []
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
await asyncio.to_thread(time.sleep, 0)
assert detector.violations == []
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
await asyncio.to_thread(time.sleep, 0)
assert blocking_io_detector.violations == []
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
def call_sleep() -> list[str]:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
return [violation.name for violation in detector.violations]
assert await asyncio.to_thread(call_sleep) == []
async def test_fail_on_exit_includes_call_site() -> None:
with pytest.raises(AssertionError) as exc_info:
with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
time.sleep(0)
message = str(exc_info.value)
assert "time.sleep" in message
assert "test_fail_on_exit_includes_call_site" in message
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
return f"{method}:{url}"
monkeypatch.setattr(requests.sessions.Session, "request", fake_request)
with detect_blocking_io(REQUESTS_ONLY) as detector:
assert requests.get("https://example.invalid") == "get:https://example.invalid"
assert [violation.name for violation in detector.violations] == ["requests.Session.request"]
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
return httpx.Response(200, request=httpx.Request(method, url))
monkeypatch.setattr(httpx.Client, "request", fake_request)
with detect_blocking_io(HTTPX_ONLY) as detector:
with httpx.Client() as client:
response = client.get("https://example.invalid")
assert response.status_code == 200
assert [violation.name for violation in detector.violations] == ["httpx.Client.request"]
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(os.walk(tmp_path))
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
original_alias = imported_walk
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(imported_walk(tmp_path))
assert imported_walk is original_alias
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert list(walker)
assert detector.violations == []
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert await asyncio.to_thread(lambda: list(walker))
assert detector.violations == []
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
assert path.read_text(encoding="utf-8") == "content"
assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"]
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
summary = probe.format_summary()
assert "blocking io probe: 1 violations across 1 tests" in summary
assert "pathlib.Path.read_text" in summary
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
assert probe.format_summary() == "blocking io probe: no violations"
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
assert probe.violation_count == 1
probe.clear()
assert probe.violation_count == 0
assert probe.format_summary() == "blocking io probe: no violations"
@@ -0,0 +1,22 @@
from __future__ import annotations
import time
import pytest
ORIGINAL_SLEEP = time.sleep
def replacement_sleep(seconds: float) -> None:
return None
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(time, "sleep", replacement_sleep)
assert time.sleep is replacement_sleep
@pytest.mark.no_blocking_io_probe
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
assert time.sleep is ORIGINAL_SLEEP
assert getattr(time.sleep, "__wrapped__", None) is None
+1 -1
View File
@@ -761,7 +761,7 @@ class TestChannelManager:
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None):
del assistant_id, context # unused in this test, kept for signature parity
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
@@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching:
assert patched[1].name == "bash"
assert patched[1].status == "error"
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
middleware = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = middleware._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
HumanMessage(content="interruption"),
_tool_msg("call_2", "read"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert isinstance(patched[2], ToolMessage)
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
assert isinstance(patched[3], HumanMessage)
def test_valid_adjacent_tool_results_are_unchanged(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
_tool_msg("call_1", "bash"),
HumanMessage(content="next"),
]
assert mw._build_patched_messages(msgs) is None
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_ai_with_tool_calls([_tc("read", "call_2")]),
_tool_msg("call_1", "bash"),
_tool_msg("call_2", "read"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
assert isinstance(patched[3], AIMessage)
assert isinstance(patched[4], ToolMessage)
assert patched[4].tool_call_id == "call_2"
def test_orphan_tool_message_is_preserved_during_grouping(self):
mw = DanglingToolCallMiddleware()
orphan = _tool_msg("orphan_call", "orphan")
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
orphan,
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert orphan in patched
assert patched.count(orphan) == 1
def test_invalid_tool_call_is_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
@@ -0,0 +1,222 @@
"""Real-LLM end-to-end verification for issue #2884.
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
and the production ``get_available_tools`` pipeline. The only thing we mock is
the MCP tool source we hand-roll two ``@tool``s and inject them through
``deerflow.mcp.cache.get_cached_mcp_tools``.
The flow exercised:
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
that re-enters ``get_available_tools`` on the same task this is the
code path issue #2884 reports). It must call ``tool_search`` to
discover the deferred ``fake_calculator`` tool.
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
``fake_subagent_trigger`` re-enters ``get_available_tools``.
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
so it can actually call it. Without this PR's fix, the re-entry wipes
the promotion and the model can no longer invoke the tool.
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
test run. Run with::
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
PYTHONPATH=. uv run pytest \
tests/test_deferred_tool_promotion_real_llm.py -v -s
"""
from __future__ import annotations
import os
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool as as_tool
# ---------------------------------------------------------------------------
# Skip control: only run when explicitly opted in.
# ---------------------------------------------------------------------------
pytestmark = pytest.mark.skipif(
os.getenv("ONEAPI_E2E") != "1",
reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
# ---------------------------------------------------------------------------
# Fake "MCP" tools the agent should discover via tool_search.
# Keep them obviously synthetic so the model can pattern-match the search.
# ---------------------------------------------------------------------------
_calls: list[str] = []
@as_tool
def fake_calculator(expression: str) -> str:
"""Evaluate a tiny arithmetic expression like '2 + 2'.
Reserved for the user only call this if the user asks for arithmetic.
"""
_calls.append(f"fake_calculator:{expression}")
try:
# Trivially safe-eval just for the e2e check
allowed = set("0123456789+-*/() .")
if not set(expression) <= allowed:
return "expression contains disallowed characters"
return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307
except Exception as e:
return f"error: {e}"
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
"""Translate text into the given language code. Decorative — not used."""
_calls.append(f"fake_translator:{text}:{target_lang}")
return f"[{target_lang}] {text}"
# ---------------------------------------------------------------------------
# Pipeline wiring (same shape as the in-process tests).
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Build a minimal mock AppConfig and patch the symbol — never call the
real loader, which would trigger ``_apply_singleton_configs`` and
permanently mutate cross-test singletons (memory, title, )."""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Real-LLM e2e test
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
"""End-to-end against a real OpenAI-compatible LLM.
The model must:
Turn 1 see ``tool_search`` (deferred tools aren't bound yet) and
batch-call BOTH ``tool_search(select:fake_calculator)`` AND
``fake_subagent_trigger(...)``.
Turn 2 call ``fake_calculator`` and finish.
Pass criterion: ``fake_calculator`` actually gets invoked at the tool
layer recorded in ``_calls`` which proves the model received the
promoted schema after the re-entrant ``get_available_tools`` call.
"""
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
_force_tool_search_enabled(monkeypatch)
_calls.clear()
@as_tool
async def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset.
Use this whenever the user asks you to delegate work pass a short
description as ``prompt``.
"""
# ``task_tool`` does this internally. Whether the registry-reset that
# used to happen here actually leaks back to the parent task depends
# on asyncio's implicit context-copying semantics (gather creates
# child tasks with copied contexts, so reset_deferred_registry is
# task-local) — but the fix in this PR is what GUARANTEES the
# promotion sticks regardless of which integration path triggers a
# re-entrant ``get_available_tools`` call.
get_available_tools(subagent_enabled=False)
_calls.append(f"fake_subagent_trigger:{prompt}")
return "subagent completed"
tools = get_available_tools() + [fake_subagent_trigger]
model = ChatOpenAI(
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
temperature=0,
max_retries=1,
)
system_prompt = (
"You are a meticulous assistant. Available deferred tools include a "
"calculator and a translator — their schemas are hidden until you "
"search for them via tool_search.\n\n"
"Procedure for the user's request:\n"
" 1. Call tool_search with query 'select:fake_calculator' AND "
"in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
"to delegate the side work. Put both tool_calls in your first response.\n"
" 2. After both tool messages come back, call fake_calculator with "
"the user's expression.\n"
" 3. Reply with just the numeric result."
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt=system_prompt,
)
result = await graph.ainvoke(
{"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]},
config={"recursion_limit": 12},
)
print("\n=== tool calls recorded ===")
for c in _calls:
print(f" {c}")
print("\n=== final message ===")
final_text = result["messages"][-1].content if result["messages"] else "(none)"
print(f" {final_text!r}")
# The smoking-gun assertion: fake_calculator was actually invoked at the
# tool layer. This is only possible if the promoted schema reached the
# model in turn 2, despite the subagent-style re-entry in turn 1.
calc_calls = [c for c in _calls if c.startswith("fake_calculator:")]
assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"
# And the math should actually be done correctly (sanity that the LLM
# really used the result, not just hallucinated the answer).
assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
@@ -0,0 +1,390 @@
"""Reproduce + regression-guard issue #2884.
Hypothesis from the issue:
``tools.tools.get_available_tools`` unconditionally calls
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
every time it is invoked. If anything calls ``get_available_tools`` again
during the same async context (after the agent has promoted tools via
``tool_search``), the promotion is wiped and the next model call hides the
tool's schema again.
These tests pin two things:
A. **At the unit boundary** verify the failure mode directly. Promote a
tool in the registry, then call ``get_available_tools`` again and observe
that the ContextVar registry is reset and the promotion is lost.
B. **At the graph-execution boundary** drive a real ``create_agent`` graph
with the real ``DeferredToolFilterMiddleware`` through two model turns.
The first turn calls ``tool_search`` which promotes a tool. The second
turn must see that tool's schema in ``request.tools``. If
``get_available_tools`` were to run again between the two turns and reset
the registry, the second turn's filter would strip the tool.
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
unmodified; mock only the LLM and the MCP tool source. Patch
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
``get_available_tools`` resolves via lazy import) to return our fixture
tools so we don't need a real MCP server.
"""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
# ---------------------------------------------------------------------------
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
# ---------------------------------------------------------------------------
@as_tool
def fake_mcp_search(query: str) -> str:
"""Pretend to search a knowledge base for the given query."""
return f"results for {query}"
@as_tool
def fake_mcp_fetch(url: str) -> str:
"""Pretend to fetch a page at the given URL."""
return f"content of {url}"
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
"""Each test must start with a clean ContextVar.
The registry lives in a module-level ContextVar with no per-task isolation
in a synchronous test runner, so one test's promotion can leak into the
next and silently break filter assertions.
"""
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
"""Make get_available_tools believe an MCP server is registered.
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
that both ``AppConfig.from_file`` (which calls
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
see a valid instance. Then point the MCP tool cache at our fixture tools.
"""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Force config.tool_search.enabled=True without touching the yaml.
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
which permanently mutates module-level singletons (``_memory_config``,
``_title_config``, ) to match the developer's ``config.yaml`` — even
after pytest restores our patch. That leaks across tests later in the
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
require ``_memory_config.enabled = True``, which is the dataclass default
but FALSE in the actual yaml).
Build a minimal mock AppConfig instead and never call the real loader.
"""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Section A — direct unit-level reproduction
# ---------------------------------------------------------------------------
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
Step 1: call get_available_tools() registers MCP tools as deferred.
Step 2: simulate the agent calling tool_search by promoting one tool.
Step 3: call get_available_tools() again (the same code path
``task_tool`` exercises mid-run).
Assertion: after step 3, the promoted tool is STILL promoted (not
re-deferred). On ``main`` before the fix, step 3's
``reset_deferred_registry()`` wiped the promotion and re-registered
every MCP tool as deferred this assertion fired with
``REGRESSION (#2884)``.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# Step 1: first call — both MCP tools start deferred
get_available_tools()
reg1 = get_deferred_registry()
assert reg1 is not None
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
# Step 2: simulate tool_search promoting one of them
reg1.promote({"fake_mcp_search"})
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
# Step 3: second call — registry must NOT silently undo the promotion
get_available_tools()
reg2 = get_deferred_registry()
assert reg2 is not None
deferred_after = {e.name for e in reg2.entries}
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
# ---------------------------------------------------------------------------
# Section B — graph-execution reproduction
# ---------------------------------------------------------------------------
class _ToolSearchPromotingModel(FakeToolCallingModel):
"""Two-turn model that:
Turn 1 emit a tool_call for ``tool_search`` (the real one)
Turn 2 emit a tool_call for ``fake_mcp_search`` (the promoted tool)
Records the tools it received on each turn so the test can inspect what
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
"""
bound_tools_per_turn: list[list[str]] = []
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
# Record the tool names the model would see in this turn
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
self.bound_tools_per_turn.append(names)
return self
def _build_promoting_model() -> _ToolSearchPromotingModel:
return _ToolSearchPromotingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
}
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
"""End-to-end: drive a real create_agent graph through two turns.
Without the fix, the second-turn bind_tools call should NOT contain
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
registry and strips it). With the fix, the model sees the schema and can
invoke it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
tools = get_available_tools()
# Sanity: the assembled tool list includes the deferred tools (they're in
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
# they reach the model)
tool_names = {getattr(t, "name", "") for t in tools}
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
model = _build_promoting_model()
model.bound_tools_per_turn = [] # reset class-level recorder
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
turn1 = set(model.bound_tools_per_turn[0])
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
# This is the load-bearing assertion for issue #2884.
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
turn2 = set(model.bound_tools_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
# ---------------------------------------------------------------------------
# Section C — the actual issue #2884 trigger: a re-entrant
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
# wipe the parent's promotion.
# ---------------------------------------------------------------------------
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
(the same pattern that happens when ``task_tool`` builds a subagent's
toolset mid-run) must not wipe the parent agent's tool_search promotions.
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
``get_available_tools`` again exactly what ``task_tool`` does when it
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
promoted tool. Without the fix, the re-entry wipes the registry and
the filter re-hides it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# The trigger tool simulates what task_tool does internally: rebuild the
# toolset by calling get_available_tools while the registry is live.
@as_tool
def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
get_available_tools(subagent_enabled=False)
return f"spawned subagent for: {prompt}"
tools = get_available_tools() + [fake_subagent_trigger]
bound_per_turn: list[list[str]] = []
class _Model(FakeToolCallingModel):
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
return self
model = _Model(
responses=[
# Turn 1: do both in one batch — promote AND trigger the
# subagent-style rebuild. LangGraph executes them in order in the
# same agent step.
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
},
{
"name": "fake_subagent_trigger",
"args": {"prompt": "go"},
"id": "call_trigger_1",
"type": "tool_call",
},
],
),
# Turn 2: try to invoke the promoted tool. The model gets this
# turn only if turn 1's bind_tools recorded what the filter sent.
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-subagent-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1 sanity: deferred tool not visible yet
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
# re-entrant get_available_tools call that happened in turn 1's tool batch.
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
turn2 = set(bound_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
+83 -1
View File
@@ -1,6 +1,6 @@
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
assert elapsed < 0.1
assert finished.is_set() is False
assert finished.wait(1.0) is True
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
assert queue.pending_count == 2
assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(
thread_id="thread-1",
messages=["first"],
agent_name="agent-a",
correction_detected=True,
)
queue.add(
thread_id="thread-1",
messages=["second"],
agent_name="agent-a",
correction_detected=False,
)
assert queue.pending_count == 1
assert queue._queue[0].agent_name == "agent-a"
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].correction_detected is True
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with (
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
patch("deerflow.agents.memory.queue.time.sleep"),
):
queue.flush()
assert mock_updater.update_memory.call_count == 2
mock_updater.update_memory.assert_has_calls(
[
call(
messages=["agent-a"],
thread_id="thread-1",
agent_name="agent-a",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
call(
messages=["agent-b"],
thread_id="thread-1",
agent_name="agent-b",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
]
)
@@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id():
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
assert q._queue[0].user_id == "alice"
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock()
@@ -37,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
mock_updater.update_memory.assert_called_once()
call_kwargs = mock_updater.update_memory.call_args.kwargs
assert call_kwargs["user_id"] == "alice"
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
assert q.pending_count == 1
assert q._queue[0].messages == ["second"]
assert q._queue[0].user_id == "alice"
assert q._queue[0].agent_name == "researcher"
def test_add_nowait_keeps_different_users_separate():
q = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
patch.object(q, "_schedule_timer"),
):
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
-1
View File
@@ -454,7 +454,6 @@ class TestAStream:
@pytest.mark.asyncio
async def test_with_tools_emits_tool_call_chunk(self):
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
+33
View File
@@ -268,6 +268,39 @@ class TestEdgeCases:
class TestDbRunEventStore:
"""Tests for DbRunEventStore with temp SQLite."""
@pytest.mark.anyio
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
from sqlalchemy.dialects import postgresql
from deerflow.runtime.events.store.db import DbRunEventStore
class FakeSession:
def __init__(self):
self.dialect = postgresql.dialect()
self.execute_calls = []
self.scalar_stmt = None
def get_bind(self):
return self
async def execute(self, stmt, params=None):
self.execute_calls.append((stmt, params))
async def scalar(self, stmt):
self.scalar_stmt = stmt
return 41
session = FakeSession()
max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
assert max_seq == 41
assert session.execute_calls
assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
assert "FOR UPDATE" not in compiled
@pytest.mark.anyio
async def test_basic_crud(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+48
View File
@@ -3,7 +3,10 @@
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import re
import pytest
from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
@@ -278,3 +281,48 @@ class TestRunRepository:
assert row4["model_name"] is None
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
captured = []
class FakeResult:
def all(self):
return []
class FakeSession:
async def execute(self, stmt):
captured.append(stmt)
return FakeResult()
class FakeSessionContext:
async def __aenter__(self):
return FakeSession()
async def __aexit__(self, exc_type, exc, tb):
return None
repo = RunRepository(lambda: FakeSessionContext())
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg == {
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_runs": 0,
"by_model": {},
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
}
assert len(captured) == 1
stmt = captured[0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
assert select_match is not None
assert group_by_match is not None
assert select_match.group(1) == group_by_match.group(1)
+105 -2
View File
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
)
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
def _runtime(
thread_id: str | None = "thread-1",
agent_name: str | None = None,
user_id: str | None = None,
) -> SimpleNamespace:
context = {}
if thread_id is not None:
context["thread_id"] = thread_id
if agent_name is not None:
context["agent_name"] = agent_name
if user_id is not None:
context["user_id"] = user_id
return SimpleNamespace(context=context)
@@ -50,7 +56,8 @@ def _middleware(
preserve_recent_skill_tokens_per_skill: int = 0,
) -> DeerFlowSummarizationMiddleware:
model = MagicMock()
model.invoke.return_value = SimpleNamespace(text="compressed summary")
model.invoke.return_value = AIMessage(content="compressed summary")
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
return DeerFlowSummarizationMiddleware(
model=model,
trigger=trigger,
@@ -634,3 +641,99 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
# ---------------------------------------------------------------------------
# Issue #2804: summary text must not leak to the frontend via streaming
# ---------------------------------------------------------------------------
def test_build_new_messages_sets_hide_from_ui() -> None:
"""The summary HumanMessage must carry hide_from_ui so the frontend filters it."""
middleware = _middleware()
messages = middleware._build_new_messages("test summary")
assert len(messages) == 1
msg = messages[0]
assert msg.name == "summary"
assert msg.additional_kwargs.get("hide_from_ui") is True
assert "test summary" in msg.content
def test_create_summary_suppresses_callbacks() -> None:
"""_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
in the invoke config to suppress inherited LangGraph stream callbacks."""
middleware = _middleware()
middleware._create_summary(_messages())
middleware.model.with_config.assert_called_once_with(callbacks=[])
bound = middleware.model.with_config.return_value
bound.invoke.assert_called_once()
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
assert call_config is not None
assert call_config.get("callbacks") == []
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
@pytest.mark.anyio
async def test_acreate_summary_suppresses_callbacks() -> None:
"""_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
in the ainvoke config to suppress inherited LangGraph stream callbacks."""
middleware = _middleware()
middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
await middleware._acreate_summary(_messages())
middleware.model.with_config.assert_called_once_with(callbacks=[])
bound = middleware.model.with_config.return_value
bound.ainvoke.assert_called_once()
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
assert call_config is not None
assert call_config.get("callbacks") == []
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
def test_before_model_summary_message_has_hide_from_ui() -> None:
"""End-to-end: the emitted state update contains a summary message with hide_from_ui."""
middleware = _middleware()
result = middleware.before_model({"messages": _messages()}, _runtime())
emitted = result["messages"]
summary_msg = emitted[1]
assert summary_msg.name == "summary"
assert summary_msg.additional_kwargs.get("hide_from_ui") is True
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="main",
agent_name="researcher",
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
"""AIMessage.content can be a list of content blocks; _extract_summary_text
must normalize to plain text via the .text property instead of producing
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
middleware = _middleware()
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
assert middleware._extract_summary_text(response) == "A summary of the chat."
# Plain string content still works
response_str = AIMessage(content="Plain summary")
assert middleware._extract_summary_text(response_str) == "Plain summary"
+153
View File
@@ -59,12 +59,15 @@ def _make_result(
ai_messages: list[dict] | None = None,
result: str | None = None,
error: str | None = None,
token_usage_records: list[dict] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
status=status,
ai_messages=ai_messages or [],
result=result,
error=error,
token_usage_records=token_usage_records or [],
usage_reported=False,
)
@@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]
@pytest.mark.parametrize(
"status, expected_type",
[
(FakeSubagentStatus.COMPLETED, "task_completed"),
(FakeSubagentStatus.FAILED, "task_failed"),
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
"""Terminal task events include a usage summary from token_usage_records."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
records = [
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
]
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-usage",
)
terminal_events = [e for e in events if e["type"] == expected_type]
assert len(terminal_events) == 1
assert terminal_events[0]["usage"] == {
"input_tokens": 300,
"output_tokens": 130,
"total_tokens": 430,
}
def test_terminal_event_usage_none_when_no_records(monkeypatch):
"""Terminal event has usage=None when token_usage_records is empty."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-no-records",
)
completed = [e for e in events if e["type"] == "task_completed"]
assert len(completed) == 1
assert completed[0]["usage"] is None
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=FileNotFoundError("missing config")),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
runtime = _make_runtime(app_config=app_config)
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
task_tool_module._subagent_usage_cache.clear()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-disabled-cache",
)
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
runtime = _make_runtime(app_config=app_config)
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
with pytest.raises(RuntimeError, match="poll failed"):
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-error",
)
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
@@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic:
assert middleware._should_generate_title(state) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12)
_set_test_title_config(max_chars=12, model_name=None)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
+364 -20
View File
@@ -1,14 +1,19 @@
"""Tests for TodoMiddleware context-loss detection."""
import asyncio
from unittest.mock import MagicMock
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import PrivateAttr
from deerflow.agents.middlewares.todo_middleware import (
TodoMiddleware,
_completion_reminder_count,
_format_todos,
_has_tool_call_intent_or_error,
_reminder_in_messages,
_todos_in_messages,
)
@@ -22,9 +27,35 @@ def _reminder_msg():
return HumanMessage(name="todo_reminder", content="reminder")
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_runtime():
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
return runtime
def _make_runtime_for(thread_id: str, run_id: str):
runtime = _make_runtime()
runtime.context = {"thread_id": thread_id, "run_id": run_id}
return runtime
@@ -161,10 +192,62 @@ def _completion_reminder_msg():
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
def _todo_completion_reminders(messages):
reminders = []
for message in messages:
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
reminders.append(message)
return reminders
def _ai_no_tool_calls():
return AIMessage(content="I'm done!")
def _ai_with_invalid_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[
{
"type": "invalid_tool_call",
"id": "write_file:36",
"name": "write_file",
"args": "{invalid",
"error": "Failed to parse tool arguments",
}
],
)
def _ai_with_raw_provider_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[],
additional_kwargs={
"tool_calls": [
{
"id": "raw-tool-call",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
}
]
},
)
def _ai_with_legacy_function_call():
return AIMessage(
content="",
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
)
def _ai_with_tool_finish_reason():
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
def _incomplete_todos():
return [
{"status": "completed", "content": "Step 1"},
@@ -194,6 +277,36 @@ class TestCompletionReminderCount:
assert _completion_reminder_count(msgs) == 1
class TestToolCallIntentOrError:
def test_false_for_plain_final_answer(self):
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
def test_true_for_structured_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
def test_true_for_invalid_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
def test_true_for_raw_provider_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
def test_true_for_legacy_function_call(self):
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
def test_true_for_tool_finish_reason(self):
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
# Sentinel for LangChain compatibility: if future AIMessage versions add
# new top-level tool/function-call fields, this test should fail. When
# it does, update `_has_tool_call_intent_or_error()` so the completion
# reminder guard explicitly decides whether each new field means "not a
# clean final answer"; the helper has a matching comment pointing back
# to this sentinel.
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
class TestAfterModel:
def test_returns_none_when_agent_still_using_tools(self):
mw = TodoMiddleware()
@@ -235,68 +348,299 @@ class TestAfterModel:
}
assert mw.after_model(state, _make_runtime()) is None
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, _make_runtime())
result = mw.after_model(state, runtime)
assert result is not None
assert result["jump_to"] == "model"
assert len(result["messages"]) == 1
reminder = result["messages"][0]
assert "messages" not in result
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_called_once()
reminder = request.override.call_args.kwargs["messages"][-1]
assert isinstance(reminder, HumanMessage)
assert reminder.name == "todo_completion_reminder"
assert reminder.additional_kwargs["hide_from_ui"] is True
assert "Step 2" in reminder.content
assert "Step 3" in reminder.content
handler.assert_called_once_with("patched-request")
def test_reminder_lists_only_incomplete_items(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, _make_runtime())
content = result["messages"][0].content
result = mw.after_model(state, runtime)
assert result is not None
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
mw.wrap_model_call(request, MagicMock(return_value="response"))
content = request.override.call_args.kwargs["messages"][-1].content
assert "Step 1" not in content # completed — should not appear
assert "Step 2" in content
assert "Step 3" in content
def test_allows_exit_after_max_reminders(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_completion_reminder_msg(),
_completion_reminder_msg(),
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is None
def test_still_sends_reminder_before_cap(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
result = mw.after_model(state, runtime)
assert result is not None
assert result["jump_to"] == "model"
def test_does_not_trigger_for_invalid_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_invalid_tool_calls()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_still_sends_reminder_before_cap(self):
def test_does_not_trigger_for_raw_provider_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [
_completion_reminder_msg(), # 1 reminder so far
_ai_no_tool_calls(),
],
"messages": [_ai_with_raw_provider_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, _make_runtime())
assert result is not None
assert result["jump_to"] == "model"
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_legacy_function_call(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_legacy_function_call()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_tool_finish_reason(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_tool_finish_reason()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
class TestAafterModel:
def test_delegates_to_sync(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
result = asyncio.run(mw.aafter_model(state, runtime))
assert result is not None
assert result["jump_to"] == "model"
assert result["messages"][0].name == "todo_completion_reminder"
assert "messages" not in result
class TestWrapModelCall:
def test_no_pending_reminder_passthrough(self):
mw = TodoMiddleware()
request = MagicMock()
request.runtime = _make_runtime()
request.messages = [HumanMessage(content="hi")]
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
def test_pending_reminder_is_injected_once(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
request.override.reset_mock()
handler.reset_mock()
handler.return_value = "second-response"
assert mw.wrap_model_call(request, handler) == "second-response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
class TestTodoMiddlewareAgentGraphIntegration:
def test_completion_reminder_is_transient_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
},
}
],
),
AIMessage(content="premature final 1"),
AIMessage(content="premature final 2"),
AIMessage(content="premature final 3"),
],
)
graph = create_agent(model=model, tools=[], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "finish all todos")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
)
assert len(model.seen_messages) == 4
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
assert reminders_by_call[0] == []
assert reminders_by_call[1] == []
assert len(reminders_by_call[2]) == 1
assert len(reminders_by_call[3]) == 1
assert "Step 1" not in reminders_by_call[2][0].content
assert "Step 2" in reminders_by_call[2][0].content
persisted_reminders = _todo_completion_reminders(result["messages"])
assert persisted_reminders == []
assert result["messages"][-1].content == "premature final 3"
assert result["todos"] == [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
assert mw._pending_completion_reminders == {}
assert mw._completion_reminder_counts == {}
class TestRunScopedReminderCleanup:
def test_before_agent_clears_stale_count_without_pending_reminder(self):
mw = TodoMiddleware()
stale_runtime = _make_runtime()
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
current_runtime = _make_runtime()
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
other_thread_runtime = _make_runtime()
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, other_thread_runtime) is not None
# Simulate a model call that drained the pending message, followed by an
# abnormal run end where after_agent did not clear the reminder count.
assert mw._drain_completion_reminders(stale_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
mw.before_agent({}, current_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 2
first_runtime = _make_runtime_for("thread-a", "run-a")
second_runtime = _make_runtime_for("thread-b", "run-b")
third_runtime = _make_runtime_for("thread-c", "run-c")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, first_runtime) is not None
# Simulate the normal model request path: pending reminder is consumed,
# but the run count remains until after_agent() or stale cleanup.
assert mw._drain_completion_reminders(first_runtime)
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
assert mw.after_model(state, second_runtime) is not None
assert mw.after_model(state, third_runtime) is not None
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
def test_size_guard_prunes_pending_and_count_state_together(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 1
stale_runtime = _make_runtime_for("thread-a", "run-a")
current_runtime = _make_runtime_for("thread-b", "run-b")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, current_runtime) is not None
assert mw._drain_completion_reminders(stale_runtime) == []
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
class TestAwrapModelCall:
def test_async_pending_reminder_is_injected(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = AsyncMock(return_value="response")
result = asyncio.run(mw.awrap_model_call(request, handler))
assert result == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
handler.assert_awaited_once_with("patched-request")
+48 -1
View File
@@ -1,9 +1,10 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove",
}
]
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
middleware = TokenUsageMiddleware()
first_dispatch = AIMessage(
content="",
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
)
second_dispatch = AIMessage(
content="",
tool_calls=[
{"id": "task:second-a", "name": "task", "args": {}},
{"id": "task:second-b", "name": "task", "args": {}},
],
)
messages = [
first_dispatch,
ToolMessage(content="first", tool_call_id="task:first"),
second_dispatch,
ToolMessage(content="second-a", tool_call_id="task:second-a"),
ToolMessage(content="second-b", tool_call_id="task:second-b"),
AIMessage(content="done"),
]
cached_usage = {
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
}
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
monkeypatch.setattr(
task_tool_module,
"pop_cached_subagent_usage",
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
)
result = middleware.after_model({"messages": messages}, _make_runtime())
assert result is not None
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
assert len(usage_updates) == 1
updated = usage_updates[0]
assert updated.tool_calls == second_dispatch.tool_calls
assert updated.usage_metadata == {
"input_tokens": 30,
"output_tokens": 12,
"total_tokens": 42,
}
+4 -8
View File
@@ -65,8 +65,7 @@ def _make_minimal_config(tools):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Config-loaded async-only tools can still be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
@@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash,
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
def test_no_duplicates_returned(mock_bash, mock_cfg):
"""get_available_tools() never returns two tools with the same name."""
mock_cfg.return_value = _make_minimal_config([])
@@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
def test_first_occurrence_wins(mock_bash, mock_cfg):
"""When duplicates exist, the first occurrence is kept."""
mock_cfg.return_value = _make_minimal_config([])
@@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
"""A warning is logged for every skipped duplicate."""
import logging
+22 -5
View File
@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.12"
resolution-markers = [
"python_full_version >= '3.14' and sys_platform == 'win32'",
@@ -763,6 +763,9 @@ dependencies = [
]
[package.optional-dependencies]
discord = [
{ name = "discord-py" },
]
postgres = [
{ name = "deerflow-harness", extra = ["postgres"] },
]
@@ -781,6 +784,7 @@ requires-dist = [
{ name = "deerflow-harness", editable = "packages/harness" },
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
{ name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" },
{ name = "email-validator", specifier = ">=2.0.0" },
{ name = "fastapi", specifier = ">=0.115.0" },
{ name = "httpx", specifier = ">=0.28.0" },
@@ -795,7 +799,7 @@ requires-dist = [
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
]
provides-extras = ["postgres"]
provides-extras = ["postgres", "discord"]
[package.metadata.requires-dev]
dev = [
@@ -923,6 +927,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" },
]
[[package]]
name = "discord-py"
version = "2.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "audioop-lts", marker = "python_full_version >= '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" },
]
[[package]]
name = "distro"
version = "1.9.0"
@@ -2005,7 +2022,7 @@ wheels = [
[[package]]
name = "langsmith"
version = "0.7.36"
version = "0.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
@@ -2018,9 +2035,9 @@ dependencies = [
{ name = "xxhash" },
{ name = "zstandard" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" }
sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" },
{ url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" },
]
[package.optional-dependencies]
+8
View File
@@ -1029,6 +1029,14 @@ run_events:
# client_secret: $DINGTALK_CLIENT_SECRET
# allowed_users: [] # empty = allow all
# card_template_id: "" # Optional: AI Card template ID for streaming updates
#
# discord:
# enabled: false
# bot_token: $DISCORD_BOT_TOKEN
# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID
# mention_only: false # If true, only respond when the bot is mentioned
# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention)
# thread_mode: false # If true, group a channel conversation into a thread
# ============================================================================
# Guardrails Configuration
+21 -3
View File
@@ -28,6 +28,10 @@ http {
set $gateway_upstream gateway:8001;
set $frontend_upstream frontend:3000;
# Default proxy settings for all locations (streaming/SSE support)
proxy_buffering off;
proxy_cache off;
# Keep the unified nginx endpoint same-origin by default. When split
# frontend/backend or port-forwarded deployments need browser CORS,
# configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and
@@ -49,8 +53,6 @@ http {
proxy_set_header Connection '';
# SSE/Streaming support
proxy_buffering off;
proxy_cache off;
proxy_set_header X-Accel-Buffering no;
# Timeouts for long-running requests
@@ -70,6 +72,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Memory endpoint
@@ -80,6 +83,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: MCP configuration endpoint
@@ -90,6 +94,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Skills configuration endpoint
@@ -100,6 +105,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Agents endpoint
@@ -110,6 +116,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Uploads endpoint
@@ -124,6 +131,8 @@ http {
# Large file upload support
client_max_body_size 100M;
proxy_request_buffering off;
# Disable response buffering to avoid permission errors
}
# Custom API: Other endpoints under /api/threads
@@ -134,6 +143,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# API Documentation: Swagger UI
@@ -144,6 +154,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# API Documentation: ReDoc
@@ -154,6 +165,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# API Documentation: OpenAPI Schema
@@ -164,6 +176,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Health check endpoint (gateway)
@@ -174,6 +187,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# ── Provisioner API (sandbox management) ────────────────────────
@@ -187,6 +201,7 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*).
@@ -198,6 +213,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Disable buffering to avoid permission errors when nginx
# runs as a non-root user (e.g. local development).
}
# All other requests go to frontend
@@ -220,4 +238,4 @@ http {
proxy_read_timeout 600s;
}
}
}
}
+41
View File
@@ -70,6 +70,11 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Disable buffering to avoid permission errors when nginx
# runs as a non-root user (e.g. local development).
proxy_buffering off;
proxy_cache off;
}
# Custom API: Memory endpoint
@@ -80,6 +85,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: MCP configuration endpoint
@@ -90,6 +98,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: Skills configuration endpoint
@@ -100,6 +111,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: Agents endpoint
@@ -110,6 +124,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: Uploads endpoint
@@ -124,6 +141,10 @@ http {
# Large file upload support
client_max_body_size 100M;
proxy_request_buffering off;
# Disable response buffering to avoid permission errors
proxy_buffering off;
proxy_cache off;
}
# Custom API: Other endpoints under /api/threads
@@ -134,6 +155,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# API Documentation: Swagger UI
@@ -144,6 +168,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# API Documentation: ReDoc
@@ -154,6 +181,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# API Documentation: OpenAPI Schema
@@ -164,6 +194,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Health check endpoint (gateway)
@@ -174,6 +207,9 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Catch-all for any /api/* prefix not matched by a more specific block above.
@@ -193,6 +229,11 @@ http {
# Auth endpoints set HttpOnly cookies — make sure nginx doesn't
# strip the Set-Cookie header from upstream responses.
proxy_pass_header Set-Cookie;
# Disable buffering to avoid permission errors when nginx
# runs as a non-root user (e.g. local development).
proxy_buffering off;
proxy_cache off;
}
# All other requests go to frontend
+3 -4
View File
@@ -10,7 +10,6 @@ import { FlickeringGrid } from "@/components/ui/flickering-grid";
import { Input } from "@/components/ui/input";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
/**
* Validate next parameter
@@ -72,7 +71,7 @@ export default function LoginPage() {
useEffect(() => {
let cancelled = false;
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
void fetch("/api/v1/auth/setup-status")
.then((r) => r.json())
.then((data: { needs_setup?: boolean }) => {
if (!cancelled && data.needs_setup) {
@@ -95,8 +94,8 @@ export default function LoginPage() {
try {
const endpoint = isLogin
? `${getBackendBaseURL()}/api/v1/auth/login/local`
: `${getBackendBaseURL()}/api/v1/auth/register`;
? "/api/v1/auth/login/local"
: "/api/v1/auth/register";
const body = isLogin
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
: JSON.stringify({ email, password });
+14 -18
View File
@@ -10,7 +10,6 @@ import { Input } from "@/components/ui/input";
import { getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
type SetupMode = "loading" | "init_admin" | "change_password";
@@ -37,7 +36,7 @@ export default function SetupPage() {
setMode("change_password");
} else if (!isAuthenticated) {
// Check if the system has no users yet
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
void fetch("/api/v1/auth/setup-status")
.then((r) => r.json())
.then((data: { needs_setup?: boolean }) => {
if (cancelled) return;
@@ -73,7 +72,7 @@ export default function SetupPage() {
setLoading(true);
try {
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/initialize`, {
const res = await fetch("/api/v1/auth/initialize", {
method: "POST",
headers: { "Content-Type": "application/json" },
credentials: "include",
@@ -114,22 +113,19 @@ export default function SetupPage() {
setLoading(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/v1/auth/change-password`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
credentials: "include",
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
const res = await fetch("/api/v1/auth/change-password", {
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
);
credentials: "include",
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
});
if (!res.ok) {
const data = await res.json();
@@ -66,6 +66,7 @@ export default function AgentChatPage() {
thread,
pendingUsageMessages,
sendMessage,
isUploading,
isHistoryLoading,
hasMoreHistory,
loadMoreHistory,
@@ -106,7 +107,11 @@ export default function AgentChatPage() {
const handleSubmit = useCallback(
(message: PromptInputMessage) => {
void sendMessage(threadId, message, { agent_name });
const sendPromise = sendMessage(threadId, message, { agent_name });
if (message.files.length > 0) {
return sendPromise;
}
void sendPromise;
},
[sendMessage, threadId, agent_name],
);
@@ -243,7 +248,10 @@ export default function AgentChatPage() {
<AgentWelcome agent={agent} agentName={agent_name} />
)
}
disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"}
disabled={
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
isUploading
}
onContextChange={(context) => setSettings("context", context)}
onSubmit={handleSubmit}
onStop={handleStop}
@@ -109,7 +109,11 @@ export default function ChatPage() {
const handleSubmit = useCallback(
(message: PromptInputMessage) => {
void sendMessage(threadId, message);
const sendPromise = sendMessage(threadId, message);
if (message.files.length > 0) {
return sendPromise;
}
void sendPromise;
},
[sendMessage, threadId],
);
+1 -2
View File
@@ -4,7 +4,6 @@ import { redirect } from "next/navigation";
import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
import { WorkspaceContent } from "./workspace-content";
@@ -45,7 +44,7 @@ export default async function WorkspaceLayout({
Retry
</Link>
<Link
href={`${getBackendBaseURL()}/api/v1/auth/logout`}
href="/api/v1/auth/logout"
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
>
Logout &amp; Reset
@@ -499,6 +499,10 @@ export const PromptInput = ({
// Keep a ref to files for cleanup on unmount (avoids stale closure)
const filesRef = useRef(files);
filesRef.current = files;
const providerTextRef = useRef("");
if (usingProvider) {
providerTextRef.current = controller.textInput.value;
}
const openFileDialogLocal = useCallback(() => {
inputRef.current?.click();
@@ -768,6 +772,24 @@ export const PromptInput = ({
}
// Convert blob URLs to data URLs asynchronously
const submittedFileIds = files.map((file) => file.id);
const clearSubmittedState = () => {
const currentFileIds = new Set(filesRef.current.map((file) => file.id));
const submittedFileIdsStillPresent = submittedFileIds.filter((id) =>
currentFileIds.has(id),
);
if (submittedFileIdsStillPresent.length === filesRef.current.length) {
clear();
} else {
for (const id of submittedFileIdsStillPresent) {
remove(id);
}
}
if (usingProvider && providerTextRef.current === text) {
controller.textInput.clear();
}
};
Promise.all(
files.map(async ({ id, ...item }) => {
if (item.file instanceof File) {
@@ -793,20 +815,14 @@ export const PromptInput = ({
if (result instanceof Promise) {
result
.then(() => {
clear();
if (usingProvider) {
controller.textInput.clear();
}
clearSubmittedState();
})
.catch(() => {
// Don't clear on error - user may want to retry
});
} else {
// Sync function completed without throwing, clear attachments
clear();
if (usingProvider) {
controller.textInput.clear();
}
clearSubmittedState();
}
} catch {
// Don't clear on error - user may want to retry
@@ -110,6 +110,7 @@ export function InputBox({
threadId,
initialValue,
onContextChange,
onFollowupsVisibilityChange,
onSubmit,
onStop,
...props
@@ -142,7 +143,8 @@ export function InputBox({
reasoning_effort?: "minimal" | "low" | "medium" | "high";
},
) => void;
onSubmit?: (message: PromptInputMessage) => void;
onFollowupsVisibilityChange?: (visible: boolean) => void;
onSubmit?: (message: PromptInputMessage) => void | Promise<void>;
onStop?: () => void;
}) {
const { t } = useI18n();
@@ -251,12 +253,12 @@ export function InputBox({
);
const handleSubmit = useCallback(
async (message: PromptInputMessage) => {
(message: PromptInputMessage) => {
if (status === "streaming") {
onStop?.();
return;
}
if (!message.text) {
if (!message.text.trim() && message.files.length === 0) {
return;
}
setFollowups([]);
@@ -274,11 +276,14 @@ export function InputBox({
selectedModel?.supports_thinking ?? false,
),
});
setTimeout(() => onSubmit?.(message), 0);
return;
return new Promise<void>((resolve, reject) => {
setTimeout(() => {
Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject);
}, 0);
});
}
onSubmit?.(message);
return onSubmit?.(message);
},
[
context,
@@ -348,6 +353,14 @@ export function InputBox({
!followupsHidden &&
(followupsLoading || followups.length > 0);
useEffect(() => {
onFollowupsVisibilityChange?.(showFollowups);
}, [onFollowupsVisibilityChange, showFollowups]);
useEffect(() => {
return () => onFollowupsVisibilityChange?.(false);
}, [onFollowupsVisibilityChange]);
useEffect(() => {
messagesRef.current = thread.messages;
}, [thread.messages]);
@@ -12,13 +12,11 @@ function TokenUsageSummary({
inputTokens,
outputTokens,
totalTokens,
unavailable = false,
}: {
className?: string;
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
unavailable?: boolean;
}) {
const { t } = useI18n();
@@ -33,21 +31,15 @@ function TokenUsageSummary({
<CoinsIcon className="size-3" />
{t.tokenUsage.label}
</span>
{!unavailable ? (
<>
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</>
) : (
<span>{t.tokenUsage.unavailableShort}</span>
)}
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</div>
);
}
@@ -55,7 +47,7 @@ function TokenUsageSummary({
export function MessageTokenUsageList({
className,
enabled = false,
isLoading = false,
isLoading: _isLoading = false,
messages,
}: {
className?: string;
@@ -63,7 +55,7 @@ export function MessageTokenUsageList({
isLoading?: boolean;
messages: Message[];
}) {
if (!enabled || isLoading) {
if (!enabled) {
return null;
}
@@ -75,13 +67,16 @@ export function MessageTokenUsageList({
const usage = accumulateUsage(aiMessages);
if (!usage) {
return null;
}
return (
<TokenUsageSummary
className={className}
inputTokens={usage?.inputTokens}
outputTokens={usage?.outputTokens}
totalTokens={usage?.totalTokens}
unavailable={!usage}
inputTokens={usage.inputTokens}
outputTokens={usage.outputTokens}
totalTokens={usage.totalTokens}
/>
);
}
@@ -8,7 +8,6 @@ import { Input } from "@/components/ui/input";
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks";
import { SettingsSection } from "./settings-section";
@@ -39,20 +38,17 @@ export function AccountSettingsPage() {
setLoading(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/v1/auth/change-password`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
}),
const res = await fetch("/api/v1/auth/change-password", {
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
);
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
}),
});
if (!res.ok) {
const data = await res.json();
+2 -4
View File
@@ -10,8 +10,6 @@ import React, {
type ReactNode,
} from "react";
import { getBackendBaseURL } from "@/core/config";
import { type User, buildLoginUrl } from "./types";
// Re-export for consumers
@@ -58,7 +56,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
const refreshUser = useCallback(async () => {
try {
setIsLoading(true);
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/me`, {
const res = await fetch("/api/v1/auth/me", {
credentials: "include",
});
@@ -90,7 +88,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
setUser(null);
try {
await fetch(`${getBackendBaseURL()}/api/v1/auth/logout`, {
await fetch("/api/v1/auth/logout", {
method: "POST",
credentials: "include",
});
+2 -2
View File
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
return hasUsage ? cumulative : null;
}
function hasNonZeroUsage(
export function hasNonZeroUsage(
usage: TokenUsage | null | undefined,
): usage is TokenUsage {
return (
@@ -75,7 +75,7 @@ function hasNonZeroUsage(
);
}
function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
return {
inputTokens: base.inputTokens + delta.inputTokens,
outputTokens: base.outputTokens + delta.outputTokens,
+9 -6
View File
@@ -26,6 +26,13 @@ export type MessageGroup =
| AssistantClarificationGroup
| AssistantSubagentGroup;
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
"summary",
"loop_warning",
"todo_reminder",
"todo_completion_reminder",
]);
export function getMessageGroups(messages: Message[]): MessageGroup[] {
if (messages.length === 0) {
return [];
@@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
continue;
}
if (message.name === "todo_reminder") {
continue;
}
if (message.type === "human") {
groups.push({ id: message.id, type: "human", messages: [message] });
continue;
@@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
export function isHiddenFromUIMessage(message: Message) {
return (
message.additional_kwargs?.hide_from_ui === true ||
message.name === "summary" ||
message.name === "loop_warning"
(typeof message.name === "string" &&
HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
);
}
+149 -44
View File
@@ -45,15 +45,60 @@ type SendMessageOptions = {
additionalKwargs?: Record<string, unknown>;
};
function mergeMessages(
function isNonEmptyString(value: string | undefined): value is string {
return typeof value === "string" && value.length > 0;
}
function messageIdentity(message: Message): string | undefined {
if (
"tool_call_id" in message &&
typeof message.tool_call_id === "string" &&
message.tool_call_id.length > 0
) {
return `tool:${message.tool_call_id}`;
}
if (typeof message.id === "string" && message.id.length > 0) {
return `message:${message.id}`;
}
return undefined;
}
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
const lastIndexByIdentity = new Map<string, number>();
messages.forEach((message, index) => {
const identity = messageIdentity(message);
if (identity) {
lastIndexByIdentity.set(identity, index);
}
});
return messages.filter((message, index) => {
const identity = messageIdentity(message);
return !identity || lastIndexByIdentity.get(identity) === index;
});
}
function findLatestUnloadedRunIndex(
runs: Run[],
loadedRunIds: ReadonlySet<string>,
): number {
for (let i = runs.length - 1; i >= 0; i--) {
const run = runs[i];
if (run && !loadedRunIds.has(run.run_id)) {
return i;
}
}
return -1;
}
export function mergeMessages(
historyMessages: Message[],
threadMessages: Message[],
optimisticMessages: Message[],
): Message[] {
const threadMessageIds = new Set(
threadMessages
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
.filter(Boolean),
threadMessages.map(messageIdentity).filter(isNonEmptyString),
);
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
@@ -65,28 +110,19 @@ function mergeMessages(
if (!msg) {
continue;
}
if (
(msg?.id && threadMessageIds.has(msg.id)) ||
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
) {
const identity = messageIdentity(msg);
if (identity && threadMessageIds.has(identity)) {
cutoff = i;
} else {
break;
}
}
return [
return dedupeMessagesByIdentity([
...historyMessages.slice(0, cutoff),
...threadMessages,
...optimisticMessages,
];
}
function messageIdentity(message: Message): string | undefined {
if ("tool_call_id" in message) {
return message.tool_call_id;
}
return message.id;
]);
}
function getMessagesAfterBaseline(
@@ -296,7 +332,11 @@ export function useThreadStream({
onError(error) {
setOptimisticMessages([]);
toast.error(getStreamErrorMessage(error));
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@@ -305,7 +345,11 @@ export function useThreadStream({
},
onFinish(state) {
listeners.current.onFinish?.(state.values);
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
@@ -339,7 +383,11 @@ export function useThreadStream({
useEffect(() => {
startedRef.current = false;
sendInFlightRef.current = false;
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
prevHumanMsgCountRef.current =
latestMessageCountsRef.current.humanMessageCount;
}, [threadId]);
@@ -615,48 +663,105 @@ export function useThreadHistory(threadId: string) {
const runsRef = useRef(runs.data ?? []);
const indexRef = useRef(-1);
const loadingRef = useRef(false);
const pendingLoadRef = useRef(false);
const loadingRunIdRef = useRef<string | null>(null);
const loadedRunIdsRef = useRef<Set<string>>(new Set());
const [loading, setLoading] = useState(false);
const [messages, setMessages] = useState<Message[]>([]);
loadingRef.current = loading;
const loadMessages = useCallback(async () => {
if (loadingRef.current) {
const pendingRunIndex = findLatestUnloadedRunIndex(
runsRef.current,
loadedRunIdsRef.current,
);
const pendingRun = runsRef.current[pendingRunIndex];
if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) {
pendingLoadRef.current = true;
}
return;
}
if (runsRef.current.length === 0) {
return;
}
const run = runsRef.current[indexRef.current];
if (!run || loadingRef.current) {
return;
}
loadingRef.current = true;
setLoading(true);
try {
setLoading(true);
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
do {
pendingLoadRef.current = false;
const nextRunIndex = findLatestUnloadedRunIndex(
runsRef.current,
loadedRunIdsRef.current,
);
indexRef.current = nextRunIndex;
const run = runsRef.current[nextRunIndex];
if (!run) {
indexRef.current = -1;
return;
}
const requestThreadId = threadIdRef.current;
loadingRunIdRef.current = run.run_id;
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
},
credentials: "include",
},
credentials: "include",
},
).then((res) => {
return res.json();
});
const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content);
setMessages((prev) => [..._messages, ...prev]);
indexRef.current -= 1;
).then((res) => {
return res.json();
});
const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content);
if (threadIdRef.current !== requestThreadId) {
return;
}
setMessages((prev) =>
dedupeMessagesByIdentity([..._messages, ...prev]),
);
loadedRunIdsRef.current.add(run.run_id);
indexRef.current = findLatestUnloadedRunIndex(
runsRef.current,
loadedRunIdsRef.current,
);
} while (pendingLoadRef.current);
} catch (err) {
console.error(err);
} finally {
loadingRef.current = false;
loadingRunIdRef.current = null;
setLoading(false);
}
}, []);
useEffect(() => {
const threadChanged = threadIdRef.current !== threadId;
threadIdRef.current = threadId;
if (threadChanged) {
runsRef.current = [];
indexRef.current = -1;
pendingLoadRef.current = false;
loadingRunIdRef.current = null;
loadedRunIdsRef.current = new Set();
loadingRef.current = false;
setLoading(false);
setMessages([]);
}
if (runs.data && runs.data.length > 0) {
runsRef.current = runs.data ?? [];
indexRef.current = runs.data.length - 1;
indexRef.current = findLatestUnloadedRunIndex(
runs.data,
loadedRunIdsRef.current,
);
}
loadMessages().catch(() => {
toast.error("Failed to load thread history.");
@@ -665,7 +770,7 @@ export function useThreadHistory(threadId: string) {
const appendMessages = useCallback((_messages: Message[]) => {
setMessages((prev) => {
return [...prev, ..._messages];
return dedupeMessagesByIdentity([...prev, ..._messages]);
});
}, []);
const hasMore = indexRef.current >= 0 || !runs.data;
+62
View File
@@ -48,4 +48,66 @@ test.describe("Chat workspace", () => {
timeout: 10_000,
});
});
test("keeps attachments visible while upload submit is pending", async ({
page,
}) => {
let releaseUpload!: () => void;
const uploadCanFinish = new Promise<void>((resolve) => {
releaseUpload = resolve;
});
let uploadStarted!: () => void;
const uploadStartedPromise = new Promise<void>((resolve) => {
uploadStarted = resolve;
});
await page.route("**/api/threads/*/uploads", async (route) => {
uploadStarted();
await uploadCanFinish;
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
success: true,
message: "Uploaded",
files: [
{
filename: "report.docx",
size: 12,
path: "report.docx",
virtual_path: "/mnt/user-data/uploads/report.docx",
artifact_url: "/api/threads/test/uploads/report.docx",
extension: ".docx",
},
],
}),
});
});
await page.goto("/workspace/chats/new");
const textarea = page.getByPlaceholder(/how can i assist you/i);
await expect(textarea).toBeVisible({ timeout: 15_000 });
const promptForm = page.locator("form").filter({ has: textarea });
await page.getByLabel("Upload files").setInputFiles({
name: "report.docx",
mimeType:
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
buffer: Buffer.from("fake docx"),
});
await expect(promptForm.getByText("report.docx")).toBeVisible();
await textarea.fill("Summarize this document");
await textarea.press("Enter");
await uploadStartedPromise;
await expect(promptForm.getByText("report.docx")).toBeVisible();
releaseUpload();
await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({
timeout: 10_000,
});
await expect(promptForm.getByText("report.docx")).toBeHidden();
});
});
@@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => {
),
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
});
test("hides internal todo reminder messages from message groups", () => {
const messages = [
{
id: "human-1",
type: "human",
content: "Audit the middleware",
},
{
id: "todo-reminder-1",
type: "human",
name: "todo_completion_reminder",
content: "<system_reminder>finish todos</system_reminder>",
},
{
id: "todo-reminder-2",
type: "human",
name: "todo_reminder",
content: "<system_reminder>remember todos</system_reminder>",
},
{
id: "ai-1",
type: "ai",
content: "Done",
},
] as Message[];
const groups = getMessageGroups(messages);
expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]);
expect(
groups.flatMap((group) => group.messages).map((message) => message.id),
).toEqual(["human-1", "ai-1"]);
});
@@ -0,0 +1,64 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import { mergeMessages } from "@/core/threads/hooks";
test("mergeMessages removes duplicate messages already present in history", () => {
const human = {
id: "human-1",
type: "human",
content: "Design an agent",
} as Message;
const ai = {
id: "ai-1",
type: "ai",
content: "Let's design it.",
} as Message;
expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]);
});
test("mergeMessages lets live thread messages replace overlapping history", () => {
const oldHuman = {
id: "human-1",
type: "human",
content: "old",
} as Message;
const liveHuman = {
id: "human-1",
type: "human",
content: "live",
} as Message;
const oldAi = {
id: "ai-1",
type: "ai",
content: "old",
} as Message;
const liveAi = {
id: "ai-1",
type: "ai",
content: "live",
} as Message;
expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([
liveHuman,
liveAi,
]);
});
test("mergeMessages deduplicates tool messages by tool_call_id", () => {
const oldTool = {
id: "tool-message-old",
type: "tool",
tool_call_id: "call-1",
content: "old",
} as Message;
const liveTool = {
id: "tool-message-live",
type: "tool",
tool_call_id: "call-1",
content: "live",
} as Message;
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
});
+81
View File
@@ -72,6 +72,7 @@ def find_config_file() -> Path | None:
_SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$")
_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$")
_KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$")
@@ -141,6 +142,84 @@ def section_value(lines: list[str], section: str, key: str) -> str | None:
return None
def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None:
"""Return the value of a nested YAML key like ``channels.discord.enabled``.
Handles two levels of nesting:
channels:
discord:
enabled: true
"""
parts = section_path.split(".")
if len(parts) != 2:
return None
parent_section, child_section = parts
inside_parent = False
inside_child = False
parent_indent: int | None = None
child_indent: int | None = None
for raw in lines:
line = _strip_comment(raw)
if not line.strip():
continue
stripped = line.lstrip()
indent = len(line) - len(stripped)
# Top-level section match
sect_match = _SECTION_RE.match(line)
if sect_match:
if indent == 0:
inside_parent = sect_match.group(1) == parent_section
inside_child = False
parent_indent = None
child_indent = None
continue
if not inside_parent:
continue
# Track parent indent from first child
if parent_indent is None and indent > 0:
parent_indent = indent
# If indent goes back to 0, we left the parent section
if indent == 0:
inside_parent = False
inside_child = False
continue
# Check if we're at the parent's child level (subsection)
if parent_indent is not None and indent == parent_indent:
# This could be a subsection or a direct key of parent
sub_match = _INDENTED_SECTION_RE.match(line)
if sub_match and sub_match.group(1) == child_section:
inside_child = True
child_indent = None
continue
else:
inside_child = False
continue
if not inside_child:
continue
# We're inside the subsection — track child indent
if child_indent is None and indent > (parent_indent or 0):
child_indent = indent
if child_indent is not None and indent != child_indent:
continue
key_match = _KEY_RE.match(line)
if key_match and key_match.group(1) == key:
return _unquote(key_match.group(2).strip())
return None
def detect_from_config(path: Path) -> list[str]:
try:
text = path.read_text(encoding="utf-8", errors="replace")
@@ -152,6 +231,8 @@ def detect_from_config(path: Path) -> list[str]:
extras.add("postgres")
if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres":
extras.add("postgres")
if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true":
extras.add("discord")
return sorted(extras)