mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 15:11:09 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4dc328e460 |
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
|
|||||||
|
|
||||||
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
||||||
|
|
||||||
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
|
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
|
||||||
|
|
||||||
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
|
|||||||
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
||||||
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||||
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||||
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
|
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
|
||||||
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
||||||
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||||
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||||
|
|||||||
+11
-291
@@ -3,10 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
@@ -23,12 +21,6 @@ class DiscordChannel(Channel):
|
|||||||
Configuration keys (in ``config.yaml`` under ``channels.discord``):
|
Configuration keys (in ``config.yaml`` under ``channels.discord``):
|
||||||
- ``bot_token``: Discord Bot token.
|
- ``bot_token``: Discord Bot token.
|
||||||
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
|
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
|
||||||
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
|
|
||||||
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
|
|
||||||
(even when mention_only is true). Use for channels where you want the bot to respond
|
|
||||||
without mentions. Empty = mention_only applies everywhere.
|
|
||||||
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
|
|
||||||
Default: same as ``mention_only``.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
||||||
@@ -40,29 +32,6 @@ class DiscordChannel(Channel):
|
|||||||
self._allowed_guilds.add(int(guild_id))
|
self._allowed_guilds.add(int(guild_id))
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
continue
|
continue
|
||||||
self._mention_only: bool = bool(config.get("mention_only", False))
|
|
||||||
self._thread_mode: bool = config.get("thread_mode", self._mention_only)
|
|
||||||
self._allowed_channels: set[str] = set()
|
|
||||||
for channel_id in config.get("allowed_channels", []):
|
|
||||||
self._allowed_channels.add(str(channel_id))
|
|
||||||
|
|
||||||
# Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
|
|
||||||
# Uses a dedicated JSON file separate from ChannelStore, which maps IM
|
|
||||||
# conversations to DeerFlow thread IDs — a different concern.
|
|
||||||
self._active_threads: dict[str, str] = {}
|
|
||||||
# Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
|
|
||||||
self._active_thread_ids: set[str] = set()
|
|
||||||
# Lock protecting _active_threads and the JSON file from concurrent access.
|
|
||||||
# _run_client (Discord loop thread) and the main thread both read/write.
|
|
||||||
self._thread_store_lock = threading.Lock()
|
|
||||||
store = config.get("channel_store")
|
|
||||||
if store is not None:
|
|
||||||
self._thread_store_path = store._path.parent / "discord_threads.json"
|
|
||||||
else:
|
|
||||||
self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"
|
|
||||||
|
|
||||||
# Typing indicator management
|
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
|
||||||
|
|
||||||
self._client = None
|
self._client = None
|
||||||
self._thread: threading.Thread | None = None
|
self._thread: threading.Thread | None = None
|
||||||
@@ -106,56 +75,12 @@ class DiscordChannel(Channel):
|
|||||||
|
|
||||||
self._thread = threading.Thread(target=self._run_client, daemon=True)
|
self._thread = threading.Thread(target=self._run_client, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
self._load_active_threads()
|
|
||||||
logger.info("Discord channel started")
|
logger.info("Discord channel started")
|
||||||
|
|
||||||
def _load_active_threads(self) -> None:
|
|
||||||
"""Restore Discord thread mappings from the dedicated JSON file on startup."""
|
|
||||||
with self._thread_store_lock:
|
|
||||||
try:
|
|
||||||
if not self._thread_store_path.exists():
|
|
||||||
logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
|
|
||||||
return
|
|
||||||
data = json.loads(self._thread_store_path.read_text())
|
|
||||||
self._active_threads.clear()
|
|
||||||
self._active_thread_ids.clear()
|
|
||||||
for channel_id, thread_id in data.items():
|
|
||||||
self._active_threads[channel_id] = thread_id
|
|
||||||
self._active_thread_ids.add(thread_id)
|
|
||||||
if self._active_threads:
|
|
||||||
logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to load thread mappings")
|
|
||||||
|
|
||||||
def _save_thread(self, channel_id: str, thread_id: str) -> None:
|
|
||||||
"""Persist a Discord thread mapping to the dedicated JSON file."""
|
|
||||||
with self._thread_store_lock:
|
|
||||||
try:
|
|
||||||
data: dict[str, str] = {}
|
|
||||||
if self._thread_store_path.exists():
|
|
||||||
data = json.loads(self._thread_store_path.read_text())
|
|
||||||
old_id = data.get(channel_id)
|
|
||||||
data[channel_id] = thread_id
|
|
||||||
# Update reverse-lookup set
|
|
||||||
if old_id:
|
|
||||||
self._active_thread_ids.discard(old_id)
|
|
||||||
self._active_thread_ids.add(thread_id)
|
|
||||||
self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._thread_store_path.write_text(json.dumps(data, indent=2))
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
self.bus.unsubscribe_outbound(self._on_outbound)
|
||||||
|
|
||||||
# Cancel all active typing indicator tasks
|
|
||||||
for target_id, task in list(self._typing_tasks.items()):
|
|
||||||
if not task.done():
|
|
||||||
task.cancel()
|
|
||||||
logger.debug("[Discord] cancelled typing task for target %s", target_id)
|
|
||||||
self._typing_tasks.clear()
|
|
||||||
|
|
||||||
if self._client and self._discord_loop and self._discord_loop.is_running():
|
if self._client and self._discord_loop and self._discord_loop.is_running():
|
||||||
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
|
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
|
||||||
try:
|
try:
|
||||||
@@ -175,10 +100,6 @@ class DiscordChannel(Channel):
|
|||||||
logger.info("Discord channel stopped")
|
logger.info("Discord channel stopped")
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
# Stop typing indicator once we're sending the response
|
|
||||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
|
||||||
await asyncio.wrap_future(stop_future)
|
|
||||||
|
|
||||||
target = await self._resolve_target(msg)
|
target = await self._resolve_target(msg)
|
||||||
if target is None:
|
if target is None:
|
||||||
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
||||||
@@ -190,9 +111,6 @@ class DiscordChannel(Channel):
|
|||||||
await asyncio.wrap_future(send_future)
|
await asyncio.wrap_future(send_future)
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
|
||||||
await asyncio.wrap_future(stop_future)
|
|
||||||
|
|
||||||
target = await self._resolve_target(msg)
|
target = await self._resolve_target(msg)
|
||||||
if target is None:
|
if target is None:
|
||||||
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
||||||
@@ -212,41 +130,6 @@ class DiscordChannel(Channel):
|
|||||||
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
|
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
|
|
||||||
"""Starts a loop to send periodic typing indicators."""
|
|
||||||
target_id = thread_ts or chat_id
|
|
||||||
if target_id in self._typing_tasks:
|
|
||||||
return # Already typing for this target
|
|
||||||
|
|
||||||
async def _typing_loop():
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await channel.trigger_typing()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
task = asyncio.create_task(_typing_loop())
|
|
||||||
self._typing_tasks[target_id] = task
|
|
||||||
|
|
||||||
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
|
|
||||||
"""Stops the typing loop for a specific target."""
|
|
||||||
target_id = thread_ts or chat_id
|
|
||||||
task = self._typing_tasks.pop(target_id, None)
|
|
||||||
if task and not task.done():
|
|
||||||
task.cancel()
|
|
||||||
logger.debug("[Discord] stopped typing indicator for target %s", target_id)
|
|
||||||
|
|
||||||
async def _add_reaction(self, message) -> None:
|
|
||||||
"""Add a checkmark reaction to acknowledge the message was received."""
|
|
||||||
try:
|
|
||||||
await message.add_reaction("✅")
|
|
||||||
except Exception:
|
|
||||||
logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)
|
|
||||||
|
|
||||||
async def _on_message(self, message) -> None:
|
async def _on_message(self, message) -> None:
|
||||||
if not self._running or not self._client:
|
if not self._running or not self._client:
|
||||||
return
|
return
|
||||||
@@ -269,143 +152,15 @@ class DiscordChannel(Channel):
|
|||||||
if self._discord_module is None:
|
if self._discord_module is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Determine whether the bot is mentioned in this message
|
|
||||||
user = self._client.user if self._client else None
|
|
||||||
if user:
|
|
||||||
bot_mention = user.mention # <@ID>
|
|
||||||
alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant)
|
|
||||||
standard_mention = f"<@{user.id}>"
|
|
||||||
else:
|
|
||||||
bot_mention = None
|
|
||||||
alt_mention = None
|
|
||||||
standard_mention = ""
|
|
||||||
has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content)
|
|
||||||
|
|
||||||
# Strip mention from text for processing
|
|
||||||
if has_mention:
|
|
||||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
|
||||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
|
||||||
|
|
||||||
# --- Determine thread/channel routing and typing target ---
|
|
||||||
thread_id = None
|
|
||||||
chat_id = None
|
|
||||||
typing_target = None # The Discord object to type into
|
|
||||||
|
|
||||||
if isinstance(message.channel, self._discord_module.Thread):
|
if isinstance(message.channel, self._discord_module.Thread):
|
||||||
# --- Message already inside a thread ---
|
chat_id = str(message.channel.parent_id or message.channel.id)
|
||||||
thread_obj = message.channel
|
thread_id = str(message.channel.id)
|
||||||
thread_id = str(thread_obj.id)
|
|
||||||
chat_id = str(thread_obj.parent_id or thread_obj.id)
|
|
||||||
typing_target = thread_obj
|
|
||||||
|
|
||||||
# If this is a known active thread, process normally
|
|
||||||
if thread_id in self._active_thread_ids:
|
|
||||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
|
||||||
inbound = self._make_inbound(
|
|
||||||
chat_id=chat_id,
|
|
||||||
user_id=str(message.author.id),
|
|
||||||
text=text,
|
|
||||||
msg_type=msg_type,
|
|
||||||
thread_ts=thread_id,
|
|
||||||
metadata={
|
|
||||||
"guild_id": str(guild.id) if guild else None,
|
|
||||||
"channel_id": str(message.channel.id),
|
|
||||||
"message_id": str(message.id),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
inbound.topic_id = thread_id
|
|
||||||
self._publish(inbound)
|
|
||||||
# Start typing indicator in the thread
|
|
||||||
if typing_target:
|
|
||||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
|
||||||
asyncio.create_task(self._add_reaction(message))
|
|
||||||
return
|
|
||||||
|
|
||||||
# Thread not tracked (orphaned) — create new thread and handle below
|
|
||||||
logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)
|
|
||||||
thread_id = None
|
|
||||||
typing_target = None
|
|
||||||
|
|
||||||
# At this point we're guaranteed to be in a channel, not a thread
|
|
||||||
# (the Thread case is handled above). Apply mention_only for all
|
|
||||||
# non-thread messages — no special case needed.
|
|
||||||
channel_id = str(message.channel.id)
|
|
||||||
|
|
||||||
# Check if there's an active thread for this channel
|
|
||||||
if channel_id in self._active_threads:
|
|
||||||
# respect mention_only: if enabled, only process messages that mention the bot
|
|
||||||
# (unless the channel is in allowed_channels)
|
|
||||||
# Messages within a thread are always allowed through (continuation).
|
|
||||||
# At this code point we know the message is in a channel, not a thread
|
|
||||||
# (Thread case handled above), so always apply the check.
|
|
||||||
if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
|
||||||
logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
|
|
||||||
return
|
|
||||||
# mention_only + fresh @ → create new thread instead of routing to existing one
|
|
||||||
if self._mention_only and has_mention:
|
|
||||||
thread_obj = await self._create_thread(message)
|
|
||||||
if thread_obj is not None:
|
|
||||||
target_thread_id = str(thread_obj.id)
|
|
||||||
self._active_threads[channel_id] = target_thread_id
|
|
||||||
self._save_thread(channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = thread_obj
|
|
||||||
logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id)
|
|
||||||
else:
|
|
||||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel
|
|
||||||
else:
|
|
||||||
# Existing session → route to the existing thread
|
|
||||||
target_thread_id = self._active_threads[channel_id]
|
|
||||||
logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = await self._get_channel_or_thread(target_thread_id)
|
|
||||||
elif self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
|
||||||
# Not mentioned and not in an allowed channel → skip
|
|
||||||
logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
|
|
||||||
return
|
|
||||||
elif self._mention_only and has_mention:
|
|
||||||
# First mention in this channel → create thread
|
|
||||||
thread_obj = await self._create_thread(message)
|
|
||||||
if thread_obj is not None:
|
|
||||||
target_thread_id = str(thread_obj.id)
|
|
||||||
self._active_threads[channel_id] = target_thread_id
|
|
||||||
self._save_thread(channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = thread_obj # Type into the new thread
|
|
||||||
logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name)
|
|
||||||
else:
|
|
||||||
# Fallback: thread creation failed (disabled/permissions), reply in channel
|
|
||||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel # Type into the channel
|
|
||||||
elif self._thread_mode:
|
|
||||||
# thread_mode but mention_only is False → create thread anyway for conversation grouping
|
|
||||||
thread_obj = await self._create_thread(message)
|
|
||||||
if thread_obj is None:
|
|
||||||
# Thread creation failed (disabled/permissions), fall back to channel replies
|
|
||||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel # Type into the channel
|
|
||||||
else:
|
|
||||||
target_thread_id = str(thread_obj.id)
|
|
||||||
self._active_threads[channel_id] = target_thread_id
|
|
||||||
self._save_thread(channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = thread_obj # Type into the new thread
|
|
||||||
else:
|
else:
|
||||||
# No threading — reply directly in channel
|
thread = await self._create_thread(message)
|
||||||
thread_id = channel_id
|
if thread is None:
|
||||||
chat_id = channel_id
|
return
|
||||||
typing_target = message.channel # Type into the channel
|
chat_id = str(message.channel.id)
|
||||||
|
thread_id = str(thread.id)
|
||||||
|
|
||||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
@@ -422,15 +177,6 @@ class DiscordChannel(Channel):
|
|||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
|
|
||||||
# Start typing indicator in the correct target (thread or channel)
|
|
||||||
if typing_target:
|
|
||||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
|
||||||
|
|
||||||
self._publish(inbound)
|
|
||||||
asyncio.create_task(self._add_reaction(message))
|
|
||||||
|
|
||||||
def _publish(self, inbound) -> None:
|
|
||||||
"""Publish an inbound message to the main event loop."""
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||||
@@ -452,40 +198,14 @@ class DiscordChannel(Channel):
|
|||||||
|
|
||||||
async def _create_thread(self, message):
|
async def _create_thread(self, message):
|
||||||
try:
|
try:
|
||||||
if self._discord_module is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Only TextChannel (type 0) and NewsChannel (type 10) support threads
|
|
||||||
channel_type = message.channel.type
|
|
||||||
if channel_type not in (
|
|
||||||
self._discord_module.ChannelType.text,
|
|
||||||
self._discord_module.ChannelType.news,
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"[Discord] channel type %s (%s) does not support threads",
|
|
||||||
channel_type.value,
|
|
||||||
channel_type.name,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
|
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
|
||||||
return await message.create_thread(name=thread_name)
|
return await message.create_thread(name=thread_name)
|
||||||
except self._discord_module.errors.HTTPException as exc:
|
|
||||||
if exc.code == 50024:
|
|
||||||
logger.info(
|
|
||||||
"[Discord] cannot create thread in channel %s (error code 50024): %s",
|
|
||||||
message.channel.id,
|
|
||||||
channel_type.name if (channel_type := message.channel.type) else "unknown",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.exception(
|
|
||||||
"[Discord] failed to create thread for message=%s (HTTPException %s)",
|
|
||||||
message.id,
|
|
||||||
exc.code,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
|
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
|
||||||
|
try:
|
||||||
|
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _resolve_target(self, msg: OutboundMessage):
|
async def _resolve_target(self, msg: OutboundMessage):
|
||||||
|
|||||||
@@ -787,22 +787,13 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
try:
|
result = await client.runs.wait(
|
||||||
result = await client.runs.wait(
|
thread_id,
|
||||||
thread_id,
|
assistant_id,
|
||||||
assistant_id,
|
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
config=run_config,
|
||||||
config=run_config,
|
context=run_context,
|
||||||
context=run_context,
|
)
|
||||||
multitask_strategy="reject",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
if _is_thread_busy_error(exc):
|
|
||||||
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
|
|
||||||
await self._send_error(msg, THREAD_BUSY_MESSAGE)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
|
|||||||
@@ -167,8 +167,6 @@ class ChannelService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = dict(config)
|
|
||||||
config["channel_store"] = self.store
|
|
||||||
channel = channel_cls(bus=self.bus, config=config)
|
channel = channel_cls(bus=self.bus, config=config)
|
||||||
self._channels[name] = channel
|
self._channels[name] = channel
|
||||||
await channel.start()
|
await channel.start()
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_SECRET_FILE = ".jwt_secret"
|
|
||||||
|
|
||||||
|
|
||||||
class AuthConfig(BaseModel):
|
class AuthConfig(BaseModel):
|
||||||
"""JWT and auth-related configuration. Parsed once at startup.
|
"""JWT and auth-related configuration. Parsed once at startup.
|
||||||
@@ -32,32 +30,6 @@ class AuthConfig(BaseModel):
|
|||||||
_auth_config: AuthConfig | None = None
|
_auth_config: AuthConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
def _load_or_create_secret() -> str:
|
|
||||||
"""Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one."""
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
|
|
||||||
paths = get_paths()
|
|
||||||
secret_file = paths.base_dir / _SECRET_FILE
|
|
||||||
|
|
||||||
try:
|
|
||||||
if secret_file.exists():
|
|
||||||
secret = secret_file.read_text(encoding="utf-8").strip()
|
|
||||||
if secret:
|
|
||||||
return secret
|
|
||||||
except OSError as exc:
|
|
||||||
raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc
|
|
||||||
|
|
||||||
secret = secrets.token_urlsafe(32)
|
|
||||||
try:
|
|
||||||
secret_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
|
||||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
|
||||||
fh.write(secret)
|
|
||||||
except OSError as exc:
|
|
||||||
raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
|
|
||||||
return secret
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_config() -> AuthConfig:
|
def get_auth_config() -> AuthConfig:
|
||||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||||
global _auth_config
|
global _auth_config
|
||||||
@@ -67,11 +39,11 @@ def get_auth_config() -> AuthConfig:
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||||
if not jwt_secret:
|
if not jwt_secret:
|
||||||
jwt_secret = _load_or_create_secret()
|
jwt_secret = secrets.token_urlsafe(32)
|
||||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
|
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||||
"persisted to .jwt_secret. Sessions will survive restarts. "
|
"Sessions will be invalidated on restart. "
|
||||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,9 +20,6 @@ ACTIVE_CONTENT_MIME_TYPES = {
|
|||||||
"image/svg+xml",
|
"image/svg+xml",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
|
|
||||||
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
||||||
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
||||||
@@ -47,22 +44,6 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
|
|
||||||
"""Read a .skill archive member while enforcing an uncompressed size cap."""
|
|
||||||
if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
|
|
||||||
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
|
|
||||||
|
|
||||||
chunks: list[bytes] = []
|
|
||||||
total_read = 0
|
|
||||||
with zip_ref.open(info, "r") as src:
|
|
||||||
while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE):
|
|
||||||
total_read += len(chunk)
|
|
||||||
if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
|
|
||||||
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
|
|
||||||
chunks.append(chunk)
|
|
||||||
return b"".join(chunks)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
||||||
"""Extract a file from a .skill ZIP archive.
|
"""Extract a file from a .skill ZIP archive.
|
||||||
|
|
||||||
@@ -79,16 +60,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
|||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
# List all files in the archive
|
# List all files in the archive
|
||||||
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
|
namelist = zip_ref.namelist()
|
||||||
|
|
||||||
# Try direct path first
|
# Try direct path first
|
||||||
if internal_path in infos_by_name:
|
if internal_path in namelist:
|
||||||
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
|
return zip_ref.read(internal_path)
|
||||||
|
|
||||||
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
||||||
for name, info in infos_by_name.items():
|
for name in namelist:
|
||||||
if name.endswith("/" + internal_path) or name == internal_path:
|
if name.endswith("/" + internal_path) or name == internal_path:
|
||||||
return _read_skill_archive_member(zip_ref, info)
|
return zip_ref.read(name)
|
||||||
|
|
||||||
# Not found
|
# Not found
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ rm -f backend/.deer-flow/data/deerflow.db
|
|||||||
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
|
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
|
||||||
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
|
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
|
||||||
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
|
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
|
||||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持) |
|
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
|
||||||
|
|
||||||
### 生产环境建议
|
### 生产环境建议
|
||||||
|
|
||||||
@@ -137,4 +137,4 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
|
|||||||
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
|
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
|
||||||
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
|
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
|
||||||
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
||||||
| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 |
|
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
|
||||||
|
|||||||
@@ -40,15 +40,6 @@ class MemoryUpdateQueue:
|
|||||||
self._timer: threading.Timer | None = None
|
self._timer: threading.Timer | None = None
|
||||||
self._processing = False
|
self._processing = False
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _queue_key(
|
|
||||||
thread_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
agent_name: str | None,
|
|
||||||
) -> tuple[str, str | None, str | None]:
|
|
||||||
"""Return the debounce identity for a memory update target."""
|
|
||||||
return (thread_id, user_id, agent_name)
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -124,9 +115,8 @@ class MemoryUpdateQueue:
|
|||||||
correction_detected: bool,
|
correction_detected: bool,
|
||||||
reinforcement_detected: bool,
|
reinforcement_detected: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
queue_key = self._queue_key(thread_id, user_id, agent_name)
|
|
||||||
existing_context = next(
|
existing_context = next(
|
||||||
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
|
(context for context in self._queue if context.thread_id == thread_id),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||||
@@ -140,7 +130,7 @@ class MemoryUpdateQueue:
|
|||||||
reinforcement_detected=merged_reinforcement_detected,
|
reinforcement_detected=merged_reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
|
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||||
self._queue.append(context)
|
self._queue.append(context)
|
||||||
|
|
||||||
def _reset_timer(self) -> None:
|
def _reset_timer(self) -> None:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
|
|||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import resolve_runtime_user_id
|
|
||||||
|
|
||||||
|
|
||||||
def memory_flush_hook(event: SummarizationEvent) -> None:
|
def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||||
@@ -22,13 +21,11 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
|
|||||||
|
|
||||||
correction_detected = detect_correction(filtered_messages)
|
correction_detected = detect_correction(filtered_messages)
|
||||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||||
user_id = resolve_runtime_user_id(event.runtime)
|
|
||||||
queue = get_memory_queue()
|
queue = get_memory_queue()
|
||||||
queue.add_nowait(
|
queue.add_nowait(
|
||||||
thread_id=event.thread_id,
|
thread_id=event.thread_id,
|
||||||
messages=filtered_messages,
|
messages=filtered_messages,
|
||||||
agent_name=event.agent_name,
|
agent_name=event.agent_name,
|
||||||
user_id=user_id,
|
|
||||||
correction_detected=correction_detected,
|
correction_detected=correction_detected,
|
||||||
reinforcement_detected=reinforcement_detected,
|
reinforcement_detected=reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|||||||
+22
-27
@@ -104,46 +104,45 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
return "[Tool call was interrupted and did not return a result.]"
|
return "[Tool call was interrupted and did not return a result.]"
|
||||||
|
|
||||||
def _build_patched_messages(self, messages: list) -> list | None:
|
def _build_patched_messages(self, messages: list) -> list | None:
|
||||||
"""Return messages with tool results grouped after their tool-call AIMessage.
|
"""Return a new message list with patches inserted at the correct positions.
|
||||||
|
|
||||||
This normalizes model-bound causal order before provider serialization while
|
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||||
preserving already-valid transcripts unchanged.
|
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||||
|
Returns None if no patches are needed.
|
||||||
"""
|
"""
|
||||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
# Collect IDs of all existing ToolMessages
|
||||||
|
existing_tool_msg_ids: set[str] = set()
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||||
|
|
||||||
tool_call_ids: set[str] = set()
|
# Check if any patching is needed
|
||||||
|
needs_patch = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
for tc in self._message_tool_calls(msg):
|
for tc in self._message_tool_calls(msg):
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if tc_id:
|
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||||
tool_call_ids.add(tc_id)
|
needs_patch = True
|
||||||
|
break
|
||||||
|
if needs_patch:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not needs_patch:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build new list with patches inserted right after each dangling AIMessage
|
||||||
patched: list = []
|
patched: list = []
|
||||||
consumed_tool_msg_ids: set[str] = set()
|
patched_ids: set[str] = set()
|
||||||
patch_count = 0
|
patch_count = 0
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
patched.append(msg)
|
patched.append(msg)
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for tc in self._message_tool_calls(msg):
|
for tc in self._message_tool_calls(msg):
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||||
continue
|
|
||||||
|
|
||||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
|
||||||
if existing_tool_msg is not None:
|
|
||||||
patched.append(existing_tool_msg)
|
|
||||||
consumed_tool_msg_ids.add(tc_id)
|
|
||||||
else:
|
|
||||||
patched.append(
|
patched.append(
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=self._synthetic_tool_message_content(tc),
|
content=self._synthetic_tool_message_content(tc),
|
||||||
@@ -152,14 +151,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
consumed_tool_msg_ids.add(tc_id)
|
patched_ids.add(tc_id)
|
||||||
patch_count += 1
|
patch_count += 1
|
||||||
|
|
||||||
if patched == messages:
|
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||||
return None
|
|
||||||
|
|
||||||
if patch_count:
|
|
||||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
|
||||||
return patched
|
return patched
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from typing import Any, Protocol, override, runtime_checkable
|
|||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import SummarizationMiddleware
|
from langchain.agents.middleware import SummarizationMiddleware
|
||||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||||
from langchain_core.messages.utils import get_buffer_string
|
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
@@ -176,84 +175,12 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@override
|
|
||||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
||||||
"""Generate summary without emitting streaming events to the client.
|
|
||||||
|
|
||||||
Suppresses callbacks to prevent the internal summarization LLM call from
|
|
||||||
producing visible AI message chunks in the frontend's ``messages-tuple``
|
|
||||||
stream (issue #2804).
|
|
||||||
"""
|
|
||||||
if not messages_to_summarize:
|
|
||||||
return "No previous conversation history."
|
|
||||||
|
|
||||||
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
|
||||||
if not trimmed:
|
|
||||||
return "Previous conversation was too long to summarize."
|
|
||||||
|
|
||||||
formatted = get_buffer_string(trimmed)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = self.model.with_config(callbacks=[]).invoke(
|
|
||||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
|
||||||
config={
|
|
||||||
"metadata": {"lc_source": "summarization"},
|
|
||||||
"callbacks": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return self._extract_summary_text(response)
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error generating summary: {e!s}"
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
||||||
"""Generate summary without emitting streaming events to the client.
|
|
||||||
|
|
||||||
Suppresses callbacks to prevent the internal summarization LLM call from
|
|
||||||
producing visible AI message chunks in the frontend's ``messages-tuple``
|
|
||||||
stream (issue #2804).
|
|
||||||
"""
|
|
||||||
if not messages_to_summarize:
|
|
||||||
return "No previous conversation history."
|
|
||||||
|
|
||||||
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
|
||||||
if not trimmed:
|
|
||||||
return "Previous conversation was too long to summarize."
|
|
||||||
|
|
||||||
formatted = get_buffer_string(trimmed)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.model.with_config(callbacks=[]).ainvoke(
|
|
||||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
|
||||||
config={
|
|
||||||
"metadata": {"lc_source": "summarization"},
|
|
||||||
"callbacks": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return self._extract_summary_text(response)
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error generating summary: {e!s}"
|
|
||||||
|
|
||||||
def _extract_summary_text(self, response: Any) -> str:
|
|
||||||
# Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
|
|
||||||
# Fall back to .content for non-LangChain responses.
|
|
||||||
summary_text = getattr(response, "text", None)
|
|
||||||
if summary_text is None:
|
|
||||||
summary_text = getattr(response, "content", "")
|
|
||||||
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||||
"""Override the base implementation to let the human message with the special name 'summary'.
|
"""Override the base implementation to let the human message with the special name 'summary'.
|
||||||
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
||||||
"""
|
"""
|
||||||
return [
|
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||||
HumanMessage(
|
|
||||||
content=f"Here is a summary of the conversation to date:\n\n{summary}",
|
|
||||||
name="summary",
|
|
||||||
additional_kwargs={"hide_from_ui": True},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _preserve_dynamic_context_reminders(
|
def _preserve_dynamic_context_reminders(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -7,21 +7,17 @@ reminder message so the model still knows about the outstanding todo list.
|
|||||||
|
|
||||||
Additionally, this middleware prevents the agent from exiting the loop while
|
Additionally, this middleware prevents the agent from exiting the loop while
|
||||||
there are still incomplete todo items. When the model produces a final response
|
there are still incomplete todo items. When the model produces a final response
|
||||||
(no tool calls) but todos are not yet complete, the middleware queues a reminder
|
(no tool calls) but todos are not yet complete, the middleware injects a reminder
|
||||||
for the next model request and jumps back to the model node to force continued
|
and jumps back to the model node to force continued engagement.
|
||||||
engagement. The completion reminder is injected via ``wrap_model_call`` instead
|
|
||||||
of being persisted into graph state as a normal user-visible message.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import threading
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any, override
|
from typing import Any, override
|
||||||
|
|
||||||
from langchain.agents.middleware import TodoListMiddleware
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
from langchain.agents.middleware.types import hook_config
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
@@ -59,51 +55,6 @@ def _format_todos(todos: list[Todo]) -> str:
|
|||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _format_completion_reminder(todos: list[Todo]) -> str:
|
|
||||||
"""Format a completion reminder for incomplete todo items."""
|
|
||||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
|
||||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
|
||||||
return (
|
|
||||||
"<system_reminder>\n"
|
|
||||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
|
||||||
f"{incomplete_text}\n\n"
|
|
||||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
|
||||||
"as you finish them, and only respond when all items are done.\n"
|
|
||||||
"</system_reminder>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
|
|
||||||
|
|
||||||
|
|
||||||
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
|
|
||||||
"""Return True when an AIMessage is not a clean final answer.
|
|
||||||
|
|
||||||
Todo completion reminders should only fire when the model has produced a
|
|
||||||
plain final response. Provider/tool parsing details have moved across
|
|
||||||
LangChain versions and integrations, so keep all tool-intent/error signals
|
|
||||||
behind this helper instead of checking one concrete field at the call site.
|
|
||||||
"""
|
|
||||||
if message.tool_calls:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if getattr(message, "invalid_tool_calls", None):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Backward/provider compatibility: some integrations preserve raw or legacy
|
|
||||||
# tool-call intent in additional_kwargs even when structured tool_calls is
|
|
||||||
# empty. If this helper changes, update the matching sentinel test
|
|
||||||
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
|
|
||||||
# if that test fails after a LangChain upgrade, review this helper so new
|
|
||||||
# tool-call/error fields are not silently treated as clean final answers.
|
|
||||||
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
|
|
||||||
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
|
|
||||||
return True
|
|
||||||
|
|
||||||
response_metadata = getattr(message, "response_metadata", {}) or {}
|
|
||||||
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
|
|
||||||
|
|
||||||
|
|
||||||
class TodoMiddleware(TodoListMiddleware):
|
class TodoMiddleware(TodoListMiddleware):
|
||||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||||
|
|
||||||
@@ -138,7 +89,6 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
formatted = _format_todos(todos)
|
formatted = _format_todos(todos)
|
||||||
reminder = HumanMessage(
|
reminder = HumanMessage(
|
||||||
name="todo_reminder",
|
name="todo_reminder",
|
||||||
additional_kwargs={"hide_from_ui": True},
|
|
||||||
content=(
|
content=(
|
||||||
"<system_reminder>\n"
|
"<system_reminder>\n"
|
||||||
"Your todo list from earlier is no longer visible in the current context window, "
|
"Your todo list from earlier is no longer visible in the current context window, "
|
||||||
@@ -163,100 +113,6 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
# Maximum number of completion reminders before allowing the agent to exit.
|
# Maximum number of completion reminders before allowing the agent to exit.
|
||||||
# This prevents infinite loops when the agent cannot make further progress.
|
# This prevents infinite loops when the agent cannot make further progress.
|
||||||
_MAX_COMPLETION_REMINDERS = 2
|
_MAX_COMPLETION_REMINDERS = 2
|
||||||
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
|
|
||||||
_MAX_COMPLETION_REMINDER_KEYS = 4096
|
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
|
|
||||||
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
|
|
||||||
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
|
|
||||||
self._completion_reminder_next_order = 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_thread_id(runtime: Runtime) -> str:
|
|
||||||
context = getattr(runtime, "context", None)
|
|
||||||
thread_id = context.get("thread_id") if context else None
|
|
||||||
return str(thread_id) if thread_id else "default"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_run_id(runtime: Runtime) -> str:
|
|
||||||
context = getattr(runtime, "context", None)
|
|
||||||
run_id = context.get("run_id") if context else None
|
|
||||||
return str(run_id) if run_id else "default"
|
|
||||||
|
|
||||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
|
||||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
|
||||||
|
|
||||||
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
|
||||||
self._completion_reminder_next_order += 1
|
|
||||||
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
|
|
||||||
|
|
||||||
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
|
|
||||||
keys = set(self._pending_completion_reminders)
|
|
||||||
keys.update(self._completion_reminder_counts)
|
|
||||||
keys.update(self._completion_reminder_touch_order)
|
|
||||||
return keys
|
|
||||||
|
|
||||||
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
|
||||||
self._pending_completion_reminders.pop(key, None)
|
|
||||||
self._completion_reminder_counts.pop(key, None)
|
|
||||||
self._completion_reminder_touch_order.pop(key, None)
|
|
||||||
|
|
||||||
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
|
|
||||||
keys = self._completion_reminder_keys_locked()
|
|
||||||
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
|
|
||||||
if overflow <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
candidates = [key for key in keys if key != protected_key]
|
|
||||||
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
|
|
||||||
for key in candidates[:overflow]:
|
|
||||||
self._drop_completion_reminder_key_locked(key)
|
|
||||||
|
|
||||||
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
self._pending_completion_reminders.setdefault(key, []).append(reminder)
|
|
||||||
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
|
|
||||||
self._touch_completion_reminder_key_locked(key)
|
|
||||||
self._prune_completion_reminder_state_locked(protected_key=key)
|
|
||||||
|
|
||||||
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
return self._completion_reminder_counts.get(key, 0)
|
|
||||||
|
|
||||||
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
reminders = self._pending_completion_reminders.pop(key, [])
|
|
||||||
if reminders or key in self._completion_reminder_counts:
|
|
||||||
self._touch_completion_reminder_key_locked(key)
|
|
||||||
return reminders
|
|
||||||
|
|
||||||
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
|
|
||||||
thread_id, current_run_id = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
for key in self._completion_reminder_keys_locked():
|
|
||||||
if key[0] == thread_id and key[1] != current_run_id:
|
|
||||||
self._drop_completion_reminder_key_locked(key)
|
|
||||||
|
|
||||||
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
self._drop_completion_reminder_key_locked(key)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_other_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_other_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@hook_config(can_jump_to=["model"])
|
@hook_config(can_jump_to=["model"])
|
||||||
@override
|
@override
|
||||||
@@ -281,12 +137,10 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
if base_result is not None:
|
if base_result is not None:
|
||||||
return base_result
|
return base_result
|
||||||
|
|
||||||
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
|
# 2. Only intervene when the agent wants to exit (no tool calls).
|
||||||
# intent or tool-call parse errors should be handled by the tool path
|
|
||||||
# instead of being masked by todo reminders.
|
|
||||||
messages = state.get("messages") or []
|
messages = state.get("messages") or []
|
||||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
||||||
if not last_ai or _has_tool_call_intent_or_error(last_ai):
|
if not last_ai or last_ai.tool_calls:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 3. Allow exit when all todos are completed or there are no todos.
|
# 3. Allow exit when all todos are completed or there are no todos.
|
||||||
@@ -295,14 +149,24 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
||||||
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
|
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 5. Queue a reminder for the next model request and jump back. We must
|
# 5. Inject a reminder and force the agent back to the model.
|
||||||
# not persist this control prompt as a normal HumanMessage, otherwise it
|
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||||
# can leak into user-visible message streams and saved transcripts.
|
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||||
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
|
reminder = HumanMessage(
|
||||||
return {"jump_to": "model"}
|
name="todo_completion_reminder",
|
||||||
|
content=(
|
||||||
|
"<system_reminder>\n"
|
||||||
|
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||||
|
f"{incomplete_text}\n\n"
|
||||||
|
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||||
|
"as you finish them, and only respond when all items are done.\n"
|
||||||
|
"</system_reminder>"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return {"jump_to": "model", "messages": [reminder]}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@hook_config(can_jump_to=["model"])
|
@hook_config(can_jump_to=["model"])
|
||||||
@@ -313,47 +177,3 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Async version of after_model."""
|
"""Async version of after_model."""
|
||||||
return self.after_model(state, runtime)
|
return self.after_model(state, runtime)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_pending_completion_reminders(reminders: list[str]) -> str:
|
|
||||||
return "\n\n".join(dict.fromkeys(reminders))
|
|
||||||
|
|
||||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
|
||||||
reminders = self._drain_completion_reminders(request.runtime)
|
|
||||||
if not reminders:
|
|
||||||
return request
|
|
||||||
new_messages = [
|
|
||||||
*request.messages,
|
|
||||||
HumanMessage(
|
|
||||||
content=self._format_pending_completion_reminders(reminders),
|
|
||||||
name="todo_completion_reminder",
|
|
||||||
additional_kwargs={"hide_from_ui": True},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return request.override(messages=new_messages)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def wrap_model_call(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
|
||||||
) -> ModelCallResult:
|
|
||||||
return handler(self._augment_request(request))
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def awrap_model_call(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
||||||
) -> ModelCallResult:
|
|
||||||
return await handler(self._augment_request(request))
|
|
||||||
|
|
||||||
@override
|
|
||||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_current_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_current_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Any, override
|
|||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain.agents.middleware.todo import Todo
|
from langchain.agents.middleware.todo import Todo
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -217,17 +217,6 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
|||||||
return "thinking"
|
return "thinking"
|
||||||
|
|
||||||
|
|
||||||
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
|
|
||||||
"""Return True if the AIMessage contains a tool_call with the given id."""
|
|
||||||
for tc in message.tool_calls or []:
|
|
||||||
if isinstance(tc, dict):
|
|
||||||
if tc.get("id") == tool_call_id:
|
|
||||||
return True
|
|
||||||
elif hasattr(tc, "id") and tc.id == tool_call_id:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||||
tool_calls = getattr(message, "tool_calls", None) or []
|
tool_calls = getattr(message, "tool_calls", None) or []
|
||||||
actions: list[dict[str, Any]] = []
|
actions: list[dict[str, Any]] = []
|
||||||
@@ -272,51 +261,8 @@ class TokenUsageMiddleware(AgentMiddleware):
|
|||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Annotate subagent token usage onto the AIMessage that dispatched it.
|
|
||||||
# When a task tool completes, its usage is cached by tool_call_id. Detect
|
|
||||||
# the ToolMessage → search backward for the corresponding AIMessage → merge.
|
|
||||||
# Walk backward through consecutive ToolMessages before the new AIMessage
|
|
||||||
# so that multiple concurrent task tool calls all get their subagent tokens
|
|
||||||
# written back to the same dispatch message (merging into one update).
|
|
||||||
state_updates: dict[int, AIMessage] = {}
|
|
||||||
if len(messages) >= 2:
|
|
||||||
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
|
|
||||||
|
|
||||||
idx = len(messages) - 2
|
|
||||||
while idx >= 0:
|
|
||||||
tool_msg = messages[idx]
|
|
||||||
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
|
|
||||||
break
|
|
||||||
|
|
||||||
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
|
|
||||||
if subagent_usage:
|
|
||||||
# Search backward from the ToolMessage to find the AIMessage
|
|
||||||
# that dispatched it. A single model response can dispatch
|
|
||||||
# multiple task tool calls, so we can't assume a fixed offset.
|
|
||||||
dispatch_idx = idx - 1
|
|
||||||
while dispatch_idx >= 0:
|
|
||||||
candidate = messages[dispatch_idx]
|
|
||||||
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
|
|
||||||
# Accumulate into an existing update for the same
|
|
||||||
# AIMessage (multiple task calls in one response),
|
|
||||||
# or merge fresh from the original message.
|
|
||||||
existing_update = state_updates.get(dispatch_idx)
|
|
||||||
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
|
|
||||||
merged = {
|
|
||||||
**prev,
|
|
||||||
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
|
|
||||||
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
|
|
||||||
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
|
|
||||||
}
|
|
||||||
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
|
|
||||||
break
|
|
||||||
dispatch_idx -= 1
|
|
||||||
idx -= 1
|
|
||||||
|
|
||||||
last = messages[-1]
|
last = messages[-1]
|
||||||
if not isinstance(last, AIMessage):
|
if not isinstance(last, AIMessage):
|
||||||
if state_updates:
|
|
||||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
usage = getattr(last, "usage_metadata", None)
|
usage = getattr(last, "usage_metadata", None)
|
||||||
@@ -342,12 +288,11 @@ class TokenUsageMiddleware(AgentMiddleware):
|
|||||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||||
|
|
||||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
|
return None
|
||||||
|
|
||||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||||
state_updates[len(messages) - 1] = updated_msg
|
return {"messages": [updated_msg]}
|
||||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
|
|||||||
@@ -223,11 +223,10 @@ class RunRepository(RunStore):
|
|||||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||||
_completed = RunRow.status.in_(("success", "error"))
|
_completed = RunRow.status.in_(("success", "error"))
|
||||||
_thread = RunRow.thread_id == thread_id
|
_thread = RunRow.thread_id == thread_id
|
||||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
select(
|
select(
|
||||||
model_name.label("model"),
|
func.coalesce(RunRow.model_name, "unknown").label("model"),
|
||||||
func.count().label("runs"),
|
func.count().label("runs"),
|
||||||
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
||||||
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
||||||
@@ -237,7 +236,7 @@ class RunRepository(RunStore):
|
|||||||
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
||||||
)
|
)
|
||||||
.where(_thread, _completed)
|
.where(_thread, _completed)
|
||||||
.group_by(model_name)
|
.group_by(func.coalesce(RunRow.model_name, "unknown"))
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import logging
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import delete, func, select, text
|
from sqlalchemy import delete, func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from deerflow.persistence.models.run_event import RunEventRow
|
from deerflow.persistence.models.run_event import RunEventRow
|
||||||
@@ -86,28 +86,6 @@ class DbRunEventStore(RunEventStore):
|
|||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
return str(user.id) if user is not None else None
|
return str(user.id) if user is not None else None
|
||||||
|
|
||||||
@staticmethod
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
    """Return the current max seq while serializing writers per thread.

    PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
    results are not lockable rows. As a release-safe workaround, take a
    transaction-level advisory lock keyed by thread_id before reading the
    aggregate. Other dialects keep the existing row-locking statement.

    Args:
        session: Session to run the statements on. The advisory lock is
            transaction-scoped, so the caller is expected to already be
            inside a transaction.
        thread_id: Thread whose events are being sequenced.

    Returns:
        The highest ``seq`` stored for ``thread_id``, or ``None`` when the
        thread has no events yet.
    """
    stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
    bind = session.get_bind()
    dialect_name = bind.dialect.name if bind is not None else ""

    if dialect_name == "postgresql":
        # hashtext() maps the thread id onto the integer key space the
        # advisory-lock function accepts; hash collisions only cause extra
        # serialization between unrelated threads, never a missed lock.
        await session.execute(
            text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
            {"thread_id": thread_id},
        )
        return await session.scalar(stmt)

    return await session.scalar(stmt.with_for_update())
|
|
||||||
|
|
||||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||||
"""Write a single event — low-frequency path only.
|
"""Write a single event — low-frequency path only.
|
||||||
|
|
||||||
@@ -122,7 +100,10 @@ class DbRunEventStore(RunEventStore):
|
|||||||
user_id = self._user_id_from_context()
|
user_id = self._user_id_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
|
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||||
|
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||||
|
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||||
seq = (max_seq or 0) + 1
|
seq = (max_seq or 0) + 1
|
||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -145,8 +126,10 @@ class DbRunEventStore(RunEventStore):
|
|||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||||
|
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||||
|
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||||
thread_id = events[0]["thread_id"]
|
thread_id = events[0]["thread_id"]
|
||||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||||
seq = max_seq or 0
|
seq = max_seq or 0
|
||||||
rows = []
|
rows = []
|
||||||
for e in events:
|
for e in events:
|
||||||
|
|||||||
@@ -26,28 +26,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
|
|
||||||
# write it back to the triggering AIMessage's usage_metadata.
|
|
||||||
_subagent_usage_cache: dict[str, dict[str, int]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
|
|
||||||
if app_config is None:
|
|
||||||
try:
|
|
||||||
app_config = get_app_config()
|
|
||||||
except FileNotFoundError:
|
|
||||||
return False
|
|
||||||
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
|
|
||||||
|
|
||||||
|
|
||||||
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
|
|
||||||
if enabled and usage:
|
|
||||||
_subagent_usage_cache[tool_call_id] = usage
|
|
||||||
|
|
||||||
|
|
||||||
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
    """Remove and return the cached subagent usage for ``tool_call_id``.

    Returns:
        The cached usage dict, or ``None`` when nothing is cached under that
        id. The entry is removed, so a second call for the same id returns
        ``None``.
    """
    return _subagent_usage_cache.pop(tool_call_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_subagent_terminal(result: Any) -> bool:
|
def _is_subagent_terminal(result: Any) -> bool:
|
||||||
"""Return whether a background subagent result is safe to clean up."""
|
"""Return whether a background subagent result is safe to clean up."""
|
||||||
@@ -114,17 +92,6 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _summarize_usage(records: list[dict] | None) -> dict | None:
|
|
||||||
"""Summarize token usage records into a compact dict for SSE events."""
|
|
||||||
if not records:
|
|
||||||
return None
|
|
||||||
return {
|
|
||||||
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
|
|
||||||
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
|
|
||||||
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _report_subagent_usage(runtime: Any, result: Any) -> None:
|
def _report_subagent_usage(runtime: Any, result: Any) -> None:
|
||||||
"""Report subagent token usage to the parent RunJournal, if available.
|
"""Report subagent token usage to the parent RunJournal, if available.
|
||||||
|
|
||||||
@@ -210,7 +177,6 @@ async def task_tool(
|
|||||||
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||||
"""
|
"""
|
||||||
runtime_app_config = _get_runtime_app_config(runtime)
|
runtime_app_config = _get_runtime_app_config(runtime)
|
||||||
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
|
|
||||||
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
|
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
|
||||||
|
|
||||||
# Get subagent configuration
|
# Get subagent configuration
|
||||||
@@ -346,32 +312,27 @@ async def task_tool(
|
|||||||
last_message_count = current_message_count
|
last_message_count = current_message_count
|
||||||
|
|
||||||
# Check if task completed, failed, or timed out
|
# Check if task completed, failed, or timed out
|
||||||
usage = _summarize_usage(getattr(result, "token_usage_records", None))
|
|
||||||
if result.status == SubagentStatus.COMPLETED:
|
if result.status == SubagentStatus.COMPLETED:
|
||||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
|
||||||
_report_subagent_usage(runtime, result)
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
|
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task Succeeded. Result: {result.result}"
|
return f"Task Succeeded. Result: {result.result}"
|
||||||
elif result.status == SubagentStatus.FAILED:
|
elif result.status == SubagentStatus.FAILED:
|
||||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
|
||||||
_report_subagent_usage(runtime, result)
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
|
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task failed. Error: {result.error}"
|
return f"Task failed. Error: {result.error}"
|
||||||
elif result.status == SubagentStatus.CANCELLED:
|
elif result.status == SubagentStatus.CANCELLED:
|
||||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
|
||||||
_report_subagent_usage(runtime, result)
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
|
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return "Task cancelled by user."
|
return "Task cancelled by user."
|
||||||
elif result.status == SubagentStatus.TIMED_OUT:
|
elif result.status == SubagentStatus.TIMED_OUT:
|
||||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
|
||||||
_report_subagent_usage(runtime, result)
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
|
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task timed out. Error: {result.error}"
|
return f"Task timed out. Error: {result.error}"
|
||||||
@@ -390,9 +351,7 @@ async def task_tool(
|
|||||||
timeout_minutes = config.timeout_seconds // 60
|
timeout_minutes = config.timeout_seconds // 60
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||||
_report_subagent_usage(runtime, result)
|
_report_subagent_usage(runtime, result)
|
||||||
usage = _summarize_usage(getattr(result, "token_usage_records", None))
|
writer({"type": "task_timed_out", "task_id": task_id})
|
||||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
|
||||||
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
|
|
||||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Signal the background subagent thread to stop cooperatively.
|
# Signal the background subagent thread to stop cooperatively.
|
||||||
@@ -415,8 +374,4 @@ async def task_tool(
|
|||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
else:
|
else:
|
||||||
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
|
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
|
||||||
_subagent_usage_cache.pop(tool_call_id, None)
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
_subagent_usage_cache.pop(tool_call_id, None)
|
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
|
|||||||
from deerflow.reflection import resolve_variable
|
from deerflow.reflection import resolve_variable
|
||||||
from deerflow.sandbox.security import is_host_bash_allowed
|
from deerflow.sandbox.security import is_host_bash_allowed
|
||||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -116,6 +116,8 @@ def get_available_tools(
|
|||||||
# made through the Gateway API (which runs in a separate process) are immediately
|
# made through the Gateway API (which runs in a separate process) are immediately
|
||||||
# reflected when loading MCP tools.
|
# reflected when loading MCP tools.
|
||||||
mcp_tools = []
|
mcp_tools = []
|
||||||
|
# Reset deferred registry upfront to prevent stale state from previous calls
|
||||||
|
reset_deferred_registry()
|
||||||
if include_mcp:
|
if include_mcp:
|
||||||
try:
|
try:
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig
|
from deerflow.config.extensions_config import ExtensionsConfig
|
||||||
@@ -133,51 +135,12 @@ def get_available_tools(
|
|||||||
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
|
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
|
||||||
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
|
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
|
||||||
|
|
||||||
# Reuse the existing registry if one is already set for
|
registry = DeferredToolRegistry()
|
||||||
# this async context. ``get_available_tools`` is
|
for t in mcp_tools:
|
||||||
# re-entered whenever a subagent is spawned
|
registry.register(t)
|
||||||
# (``task_tool`` calls it to build the child agent's
|
set_deferred_registry(registry)
|
||||||
# toolset), and previously we used to unconditionally
|
|
||||||
# rebuild the registry — wiping out the parent agent's
|
|
||||||
# tool_search promotions. The
|
|
||||||
# ``DeferredToolFilterMiddleware`` then re-hid those
|
|
||||||
# tools from subsequent model calls, leaving the agent
|
|
||||||
# able to see a tool's name but unable to invoke it
|
|
||||||
# (issue #2884). ``contextvars`` already gives us the
|
|
||||||
# lifetime semantics we want: a fresh request / graph
|
|
||||||
# run starts in a new asyncio task with the
|
|
||||||
# ContextVar at its default of ``None``, so reuse is
|
|
||||||
# only triggered for re-entrant calls inside one run.
|
|
||||||
#
|
|
||||||
# Intentionally NOT reconciling against the current
|
|
||||||
# ``mcp_tools`` snapshot. The MCP cache only refreshes
|
|
||||||
# on ``extensions_config.json`` mtime changes, which
|
|
||||||
# in practice happens between graph runs — not inside
|
|
||||||
# one. And even if a refresh did happen mid-run, the
|
|
||||||
# already-built lead agent's ``ToolNode`` still holds
|
|
||||||
# the *previous* tool set (LangGraph binds tools at
|
|
||||||
# graph construction time), so a brand-new MCP tool
|
|
||||||
# couldn't actually be invoked anyway. The
|
|
||||||
# ``DeferredToolRegistry`` doesn't retain the names
|
|
||||||
# of previously-promoted tools (``promote()`` drops
|
|
||||||
# the entry entirely), so re-syncing the registry
|
|
||||||
# against a fresh ``mcp_tools`` list would
|
|
||||||
# mis-classify those promotions as new tools and
|
|
||||||
# re-register them as deferred — exactly the bug
|
|
||||||
# this fix exists to prevent.
|
|
||||||
existing_registry = get_deferred_registry()
|
|
||||||
if existing_registry is None:
|
|
||||||
registry = DeferredToolRegistry()
|
|
||||||
for t in mcp_tools:
|
|
||||||
registry.register(t)
|
|
||||||
set_deferred_registry(registry)
|
|
||||||
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
|
|
||||||
else:
|
|
||||||
mcp_tool_names = {t.name for t in mcp_tools}
|
|
||||||
still_deferred = len(existing_registry)
|
|
||||||
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
|
|
||||||
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
|
|
||||||
builtin_tools.append(tool_search_tool)
|
builtin_tools.append(tool_search_tool)
|
||||||
|
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
|
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
postgres = ["deerflow-harness[postgres]"]
|
postgres = ["deerflow-harness[postgres]"]
|
||||||
discord = ["discord.py>=2.7.0"]
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ Sets up sys.path and pre-mocks modules that would cause circular import
|
|||||||
issues when unit-testing lightweight config/registry code in isolation.
|
issues when unit-testing lightweight config/registry code in isolation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,16 +11,11 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
|
|
||||||
|
|
||||||
# Make 'app' and 'deerflow' importable from any working directory
|
# Make 'app' and 'deerflow' importable from any working directory
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
|
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
|
||||||
|
|
||||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
|
|
||||||
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
|
|
||||||
|
|
||||||
# Break the circular import chain that exists in production code:
|
# Break the circular import chain that exists in production code:
|
||||||
# deerflow.subagents.__init__
|
# deerflow.subagents.__init__
|
||||||
# -> .executor (SubagentExecutor, SubagentResult)
|
# -> .executor (SubagentExecutor, SubagentResult)
|
||||||
@@ -63,92 +56,6 @@ def provisioner_module():
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
def blocking_io_detector():
    """Fail a focused test if blocking calls run on the event loop thread.

    Yields the detector returned by ``detect_blocking_io(fail_on_exit=True)``
    for the duration of the test; the context manager is closed when the
    fixture is torn down.
    """
    with detect_blocking_io(fail_on_exit=True) as detector:
        yield detector
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the blocking-IO probe command-line flags.

    Both flags are boolean switches (``store_true``, default ``False``) and
    are placed in the ``blocking-io`` option group.
    """
    group = parser.getgroup("blocking-io")
    group.addoption(
        "--detect-blocking-io",
        action="store_true",
        default=False,
        help="Collect blocking calls made while an asyncio event loop is running and report a summary.",
    )
    group.addoption(
        "--detect-blocking-io-fail",
        action="store_true",
        default=False,
        help="Set a failing exit status when --detect-blocking-io records violations.",
    )
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config: pytest.Config) -> None:
    """Register the ``no_blocking_io_probe`` marker with pytest."""
    config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe")
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_sessionstart(session: pytest.Session) -> None:
    """Clear the shared blocking-IO probe at session start when it is enabled."""
    if _blocking_io_probe_enabled(session.config):
        _blocking_io_probe.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
    """Start a blocking-IO detector around the test call when the probe is on.

    When the probe is disabled, or the item opts out, the hook just yields.
    Otherwise a non-failing detector is entered and stored on ``item``; it is
    deliberately NOT exited here — ``pytest_runtest_teardown`` closes it and
    records its violations.
    """
    if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item):
        yield
        return

    detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
    # Entered manually because the matching __exit__ happens in another hook.
    detector.__enter__()
    setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
    yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
    """Close the per-item blocking-IO detector and record what it saw.

    Runs after the wrapped teardown. Items without a stored detector are
    left alone. The detector attribute is always removed from ``item``, even
    if closing the detector or recording its violations raises.
    """
    yield

    detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
    if detector is None:
        return

    try:
        detector.__exit__(None, None, None)
        _blocking_io_probe.record(item.nodeid, detector.violations)
    finally:
        delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_sessionfinish(session: pytest.Session) -> None:
    """Turn an otherwise green session red when fail mode saw violations.

    The exit status is only changed when ``--detect-blocking-io-fail`` is set,
    the probe holds at least one violation, and the session would otherwise
    exit with ``ExitCode.OK``; an already failing status is left untouched.
    """
    fail_requested = _blocking_io_fail_enabled(session.config)
    if fail_requested and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK:
        session.exitstatus = pytest.ExitCode.TESTS_FAILED
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
    """Print the blocking-IO probe summary at the end of the run.

    Nothing is written when the probe is disabled. The first line of the
    probe's formatted summary becomes a ``=`` separator heading; every
    following line is written as-is.
    """
    if not _blocking_io_probe_enabled(terminalreporter.config):
        return

    header, *details = _blocking_io_probe.format_summary().splitlines()
    terminalreporter.write_sep("=", header)
    for line in details:
        terminalreporter.write_line(line)
|
|
||||||
|
|
||||||
|
|
||||||
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
|
|
||||||
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
|
|
||||||
|
|
||||||
|
|
||||||
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
|
|
||||||
return bool(config.getoption("--detect-blocking-io-fail"))
|
|
||||||
|
|
||||||
|
|
||||||
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
|
|
||||||
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Auto-set user context for every test unless marked no_auto_user
|
# Auto-set user context for every test unless marked no_auto_user
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Shared test support helpers."""
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Runtime and static detectors used by tests."""
|
|
||||||
@@ -1,287 +0,0 @@
|
|||||||
"""Test helper for detecting blocking calls on an asyncio event loop.
|
|
||||||
|
|
||||||
The detector is intentionally test-only. It monkeypatches a small set of
|
|
||||||
well-known blocking entry points and their already-loaded module-level aliases,
|
|
||||||
then records calls only when they happen on a thread that is currently running
|
|
||||||
an asyncio event loop. Aliases captured in closures or default arguments remain
|
|
||||||
out of scope.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import importlib
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from collections import Counter
|
|
||||||
from collections.abc import Callable, Iterable, Iterator
|
|
||||||
from contextlib import AbstractContextManager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import wraps
|
|
||||||
from pathlib import Path
|
|
||||||
from types import TracebackType
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
BlockingCallable = Callable[..., Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class BlockingCallSpec:
|
|
||||||
"""Describes one blocking callable to wrap during a detector run."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
target: str
|
|
||||||
record_on_iteration: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class BlockingCall:
|
|
||||||
"""One blocking call observed on an asyncio event loop thread."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
target: str
|
|
||||||
stack: tuple[traceback.FrameSummary, ...]
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
|
|
||||||
BlockingCallSpec("time.sleep", "time:sleep"),
|
|
||||||
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
|
|
||||||
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
|
|
||||||
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
|
|
||||||
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
|
|
||||||
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
|
|
||||||
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_event_loop_thread() -> bool:
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return False
|
|
||||||
return loop.is_running()
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
|
|
||||||
module_name, attr_path = target.split(":", maxsplit=1)
|
|
||||||
owner: object = importlib.import_module(module_name)
|
|
||||||
parts = attr_path.split(".")
|
|
||||||
for part in parts[:-1]:
|
|
||||||
owner = getattr(owner, part)
|
|
||||||
|
|
||||||
attr_name = parts[-1]
|
|
||||||
original = getattr(owner, attr_name)
|
|
||||||
return owner, attr_name, original
|
|
||||||
|
|
||||||
|
|
||||||
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
|
|
||||||
return tuple(frame for frame in stack if frame.filename != __file__)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
|
|
||||||
"""Record blocking calls made from async runtime code.
|
|
||||||
|
|
||||||
By default the detector reports violations but does not fail on context
|
|
||||||
exit. Tests can set ``fail_on_exit=True`` or call
|
|
||||||
``assert_no_blocking_calls()`` explicitly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
|
|
||||||
*,
|
|
||||||
fail_on_exit: bool = False,
|
|
||||||
patch_loaded_aliases: bool = True,
|
|
||||||
stack_limit: int = 12,
|
|
||||||
) -> None:
|
|
||||||
self._specs = tuple(specs)
|
|
||||||
self._fail_on_exit = fail_on_exit
|
|
||||||
self._patch_loaded_aliases_enabled = patch_loaded_aliases
|
|
||||||
self._stack_limit = stack_limit
|
|
||||||
self._patches: list[tuple[object, str, BlockingCallable]] = []
|
|
||||||
self._patch_keys: set[tuple[int, str]] = set()
|
|
||||||
self.violations: list[BlockingCall] = []
|
|
||||||
self._active = False
|
|
||||||
|
|
||||||
def __enter__(self) -> BlockingIODetector:
|
|
||||||
try:
|
|
||||||
self._active = True
|
|
||||||
alias_replacements: dict[int, BlockingCallable] = {}
|
|
||||||
for spec in self._specs:
|
|
||||||
owner, attr_name, original = _resolve_target(spec.target)
|
|
||||||
wrapper = self._wrap(spec, original)
|
|
||||||
self._patch_attribute(owner, attr_name, original, wrapper)
|
|
||||||
alias_replacements[id(original)] = wrapper
|
|
||||||
|
|
||||||
if self._patch_loaded_aliases_enabled:
|
|
||||||
self._patch_loaded_module_aliases(alias_replacements)
|
|
||||||
except Exception:
|
|
||||||
self._restore()
|
|
||||||
self._active = False
|
|
||||||
raise
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(
|
|
||||||
self,
|
|
||||||
exc_type: type[BaseException] | None,
|
|
||||||
exc_value: BaseException | None,
|
|
||||||
traceback_value: TracebackType | None,
|
|
||||||
) -> bool | None:
|
|
||||||
self._restore()
|
|
||||||
self._active = False
|
|
||||||
if exc_type is None and self._fail_on_exit:
|
|
||||||
self.assert_no_blocking_calls()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _restore(self) -> None:
|
|
||||||
for owner, attr_name, original in reversed(self._patches):
|
|
||||||
setattr(owner, attr_name, original)
|
|
||||||
self._patches.clear()
|
|
||||||
self._patch_keys.clear()
|
|
||||||
|
|
||||||
def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
|
|
||||||
key = (id(owner), attr_name)
|
|
||||||
if key in self._patch_keys:
|
|
||||||
return
|
|
||||||
setattr(owner, attr_name, replacement)
|
|
||||||
self._patches.append((owner, attr_name, original))
|
|
||||||
self._patch_keys.add(key)
|
|
||||||
|
|
||||||
def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
|
|
||||||
for module in tuple(sys.modules.values()):
|
|
||||||
namespace = getattr(module, "__dict__", None)
|
|
||||||
if not isinstance(namespace, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
for attr_name, value in tuple(namespace.items()):
|
|
||||||
replacement = replacements_by_id.get(id(value))
|
|
||||||
if replacement is not None:
|
|
||||||
self._patch_attribute(module, attr_name, value, replacement)
|
|
||||||
|
|
||||||
def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
|
|
||||||
@wraps(original)
|
|
||||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
if spec.record_on_iteration:
|
|
||||||
result = original(*args, **kwargs)
|
|
||||||
return self._wrap_iteration(spec, result)
|
|
||||||
self._record_if_blocking(spec)
|
|
||||||
return original(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
|
|
||||||
iterator = iter(iterable)
|
|
||||||
reported = False
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if not reported:
|
|
||||||
reported = self._record_if_blocking(spec)
|
|
||||||
try:
|
|
||||||
yield next(iterator)
|
|
||||||
except StopIteration:
|
|
||||||
return
|
|
||||||
|
|
||||||
def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
|
|
||||||
if self._active and _is_event_loop_thread():
|
|
||||||
stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
|
|
||||||
self.violations.append(BlockingCall(spec.name, spec.target, stack))
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def assert_no_blocking_calls(self) -> None:
|
|
||||||
if self.violations:
|
|
||||||
raise AssertionError(format_blocking_calls(self.violations))
|
|
||||||
|
|
||||||
|
|
||||||
class BlockingIOProbe:
|
|
||||||
"""Collect detector output across tests and format a compact summary."""
|
|
||||||
|
|
||||||
def __init__(self, project_root: Path) -> None:
|
|
||||||
self._project_root = project_root.resolve()
|
|
||||||
self._observed: list[tuple[str, BlockingCall]] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def violation_count(self) -> int:
|
|
||||||
return len(self._observed)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def test_count(self) -> int:
|
|
||||||
return len({nodeid for nodeid, _violation in self._observed})
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self._observed.clear()
|
|
||||||
|
|
||||||
def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
|
|
||||||
for violation in violations:
|
|
||||||
self._observed.append((nodeid, violation))
|
|
||||||
|
|
||||||
def format_summary(self, *, limit: int = 30) -> str:
|
|
||||||
if not self._observed:
|
|
||||||
return "blocking io probe: no violations"
|
|
||||||
|
|
||||||
call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
|
|
||||||
for _nodeid, violation in self._observed:
|
|
||||||
frame = self._local_call_site(violation.stack)
|
|
||||||
if frame is None:
|
|
||||||
call_sites[(violation.name, "<unknown>", 0, "<unknown>", "")] += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
call_sites[
|
|
||||||
(
|
|
||||||
violation.name,
|
|
||||||
self._relative(frame.filename),
|
|
||||||
frame.lineno,
|
|
||||||
frame.name,
|
|
||||||
(frame.line or "").strip(),
|
|
||||||
)
|
|
||||||
] += 1
|
|
||||||
|
|
||||||
lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
|
|
||||||
for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
|
|
||||||
lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
def _relative(self, filename: str) -> str:
|
|
||||||
try:
|
|
||||||
return str(Path(filename).resolve().relative_to(self._project_root))
|
|
||||||
except ValueError:
|
|
||||||
return filename
|
|
||||||
|
|
||||||
def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
|
|
||||||
local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")]
|
|
||||||
if local_frames:
|
|
||||||
return local_frames[-1]
|
|
||||||
|
|
||||||
test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename]
|
|
||||||
return test_frames[-1] if test_frames else None
|
|
||||||
|
|
||||||
|
|
||||||
def detect_blocking_io(
|
|
||||||
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
|
|
||||||
*,
|
|
||||||
fail_on_exit: bool = False,
|
|
||||||
patch_loaded_aliases: bool = True,
|
|
||||||
stack_limit: int = 12,
|
|
||||||
) -> BlockingIODetector:
|
|
||||||
"""Create a detector context manager for a focused test scope."""
|
|
||||||
|
|
||||||
return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
|
|
||||||
|
|
||||||
|
|
||||||
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
|
|
||||||
"""Format detector output with enough stack context to locate call sites."""
|
|
||||||
|
|
||||||
lines = ["Blocking calls were executed on an asyncio event loop thread:"]
|
|
||||||
for index, violation in enumerate(violations, start=1):
|
|
||||||
lines.append(f"{index}. {violation.name} ({violation.target})")
|
|
||||||
lines.extend(_format_stack(violation.stack))
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
|
|
||||||
for frame in stack:
|
|
||||||
location = f"{frame.filename}:{frame.lineno}"
|
|
||||||
lines = [f" at {frame.name} ({location})"]
|
|
||||||
if frame.line:
|
|
||||||
lines.append(f" {frame.line.strip()}")
|
|
||||||
yield from lines
|
|
||||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||||
from fastapi import HTTPException
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import FileResponse
|
from starlette.responses import FileResponse
|
||||||
@@ -103,17 +102,3 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "hello"
|
assert response.text == "hello"
|
||||||
assert response.headers.get("content-disposition", "").startswith("attachment;")
|
assert response.headers.get("content-disposition", "").startswith("attachment;")
|
||||||
|
|
||||||
|
|
||||||
def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None:
|
|
||||||
skill_path = tmp_path / "sample.skill"
|
|
||||||
payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1)
|
|
||||||
with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref:
|
|
||||||
zip_ref.writestr("SKILL.md", payload)
|
|
||||||
|
|
||||||
assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES
|
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
|
||||||
artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md")
|
|
||||||
|
|
||||||
assert exc_info.value.status_code == 413
|
|
||||||
|
|||||||
@@ -5,26 +5,28 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import app.gateway.auth.config as cfg
|
from app.gateway.auth.config import AuthConfig
|
||||||
|
|
||||||
|
|
||||||
def test_auth_config_defaults():
|
def test_auth_config_defaults():
|
||||||
config = cfg.AuthConfig(jwt_secret="test-secret-key-123")
|
config = AuthConfig(jwt_secret="test-secret-key-123")
|
||||||
assert config.token_expiry_days == 7
|
assert config.token_expiry_days == 7
|
||||||
|
|
||||||
|
|
||||||
def test_auth_config_token_expiry_range():
|
def test_auth_config_token_expiry_range():
|
||||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=1)
|
AuthConfig(jwt_secret="s", token_expiry_days=1)
|
||||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=30)
|
AuthConfig(jwt_secret="s", token_expiry_days=30)
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=0)
|
AuthConfig(jwt_secret="s", token_expiry_days=0)
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=31)
|
AuthConfig(jwt_secret="s", token_expiry_days=31)
|
||||||
|
|
||||||
|
|
||||||
def test_auth_config_from_env():
|
def test_auth_config_from_env():
|
||||||
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
||||||
with patch.dict(os.environ, env, clear=False):
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
import app.gateway.auth.config as cfg
|
||||||
|
|
||||||
old = cfg._auth_config
|
old = cfg._auth_config
|
||||||
cfg._auth_config = None
|
cfg._auth_config = None
|
||||||
try:
|
try:
|
||||||
@@ -34,57 +36,19 @@ def test_auth_config_from_env():
|
|||||||
cfg._auth_config = old
|
cfg._auth_config = old
|
||||||
|
|
||||||
|
|
||||||
def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog):
|
def test_auth_config_missing_secret_generates_ephemeral(caplog):
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from deerflow.config.paths import Paths
|
import app.gateway.auth.config as cfg
|
||||||
|
|
||||||
old = cfg._auth_config
|
old = cfg._auth_config
|
||||||
cfg._auth_config = None
|
cfg._auth_config = None
|
||||||
secret_file = tmp_path / ".jwt_secret"
|
|
||||||
try:
|
try:
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||||
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
config = cfg.get_auth_config()
|
config = cfg.get_auth_config()
|
||||||
assert config.jwt_secret
|
assert config.jwt_secret
|
||||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||||
assert secret_file.exists()
|
|
||||||
assert secret_file.read_text().strip() == config.jwt_secret
|
|
||||||
finally:
|
|
||||||
cfg._auth_config = old
|
|
||||||
|
|
||||||
|
|
||||||
def test_auth_config_reuses_persisted_secret(tmp_path):
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
old = cfg._auth_config
|
|
||||||
cfg._auth_config = None
|
|
||||||
persisted = "persisted-secret-from-file-min-32-chars!!"
|
|
||||||
(tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8")
|
|
||||||
try:
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
|
||||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
|
||||||
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
|
|
||||||
config = cfg.get_auth_config()
|
|
||||||
assert config.jwt_secret == persisted
|
|
||||||
finally:
|
|
||||||
cfg._auth_config = old
|
|
||||||
|
|
||||||
|
|
||||||
def test_auth_config_empty_secret_file_generates_new(tmp_path):
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
old = cfg._auth_config
|
|
||||||
cfg._auth_config = None
|
|
||||||
(tmp_path / ".jwt_secret").write_text("", encoding="utf-8")
|
|
||||||
try:
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
|
||||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
|
||||||
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
|
|
||||||
config = cfg.get_auth_config()
|
|
||||||
assert config.jwt_secret
|
|
||||||
assert len(config.jwt_secret) > 20
|
|
||||||
assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret
|
|
||||||
finally:
|
finally:
|
||||||
cfg._auth_config = old
|
cfg._auth_config = old
|
||||||
|
|||||||
@@ -1,190 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from os import walk as imported_walk
|
|
||||||
from pathlib import Path
|
|
||||||
from time import sleep as imported_sleep
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
import requests
|
|
||||||
from support.detectors.blocking_io import (
|
|
||||||
BlockingCallSpec,
|
|
||||||
BlockingIOProbe,
|
|
||||||
detect_blocking_io,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.asyncio
|
|
||||||
|
|
||||||
|
|
||||||
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
|
|
||||||
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
|
|
||||||
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
|
|
||||||
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
|
|
||||||
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_time_sleep_on_event_loop() -> None:
|
|
||||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
|
||||||
time.sleep(0)
|
|
||||||
|
|
||||||
assert [violation.name for violation in detector.violations] == ["time.sleep"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
|
|
||||||
original_alias = imported_sleep
|
|
||||||
|
|
||||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
|
||||||
imported_sleep(0)
|
|
||||||
|
|
||||||
assert imported_sleep is original_alias
|
|
||||||
assert [violation.name for violation in detector.violations] == ["time.sleep"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_can_disable_loaded_alias_patching() -> None:
|
|
||||||
with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
|
|
||||||
imported_sleep(0)
|
|
||||||
|
|
||||||
assert detector.violations == []
|
|
||||||
|
|
||||||
|
|
||||||
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
|
|
||||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
|
||||||
await asyncio.to_thread(time.sleep, 0)
|
|
||||||
|
|
||||||
assert detector.violations == []
|
|
||||||
|
|
||||||
|
|
||||||
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
|
|
||||||
await asyncio.to_thread(time.sleep, 0)
|
|
||||||
|
|
||||||
assert blocking_io_detector.violations == []
|
|
||||||
|
|
||||||
|
|
||||||
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
|
|
||||||
def call_sleep() -> list[str]:
|
|
||||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
|
||||||
time.sleep(0)
|
|
||||||
return [violation.name for violation in detector.violations]
|
|
||||||
|
|
||||||
assert await asyncio.to_thread(call_sleep) == []
|
|
||||||
|
|
||||||
|
|
||||||
async def test_fail_on_exit_includes_call_site() -> None:
|
|
||||||
with pytest.raises(AssertionError) as exc_info:
|
|
||||||
with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
|
|
||||||
time.sleep(0)
|
|
||||||
|
|
||||||
message = str(exc_info.value)
|
|
||||||
assert "time.sleep" in message
|
|
||||||
assert "test_fail_on_exit_includes_call_site" in message
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
|
|
||||||
return f"{method}:{url}"
|
|
||||||
|
|
||||||
monkeypatch.setattr(requests.sessions.Session, "request", fake_request)
|
|
||||||
|
|
||||||
with detect_blocking_io(REQUESTS_ONLY) as detector:
|
|
||||||
assert requests.get("https://example.invalid") == "get:https://example.invalid"
|
|
||||||
|
|
||||||
assert [violation.name for violation in detector.violations] == ["requests.Session.request"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
|
|
||||||
return httpx.Response(200, request=httpx.Request(method, url))
|
|
||||||
|
|
||||||
monkeypatch.setattr(httpx.Client, "request", fake_request)
|
|
||||||
|
|
||||||
with detect_blocking_io(HTTPX_ONLY) as detector:
|
|
||||||
with httpx.Client() as client:
|
|
||||||
response = client.get("https://example.invalid")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert [violation.name for violation in detector.violations] == ["httpx.Client.request"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
|
|
||||||
(tmp_path / "nested").mkdir()
|
|
||||||
|
|
||||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
|
||||||
assert list(os.walk(tmp_path))
|
|
||||||
|
|
||||||
assert [violation.name for violation in detector.violations] == ["os.walk"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
|
|
||||||
(tmp_path / "nested").mkdir()
|
|
||||||
original_alias = imported_walk
|
|
||||||
|
|
||||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
|
||||||
assert list(imported_walk(tmp_path))
|
|
||||||
|
|
||||||
assert imported_walk is original_alias
|
|
||||||
assert [violation.name for violation in detector.violations] == ["os.walk"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
|
|
||||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
|
||||||
walker = os.walk(tmp_path)
|
|
||||||
|
|
||||||
assert list(walker)
|
|
||||||
assert detector.violations == []
|
|
||||||
|
|
||||||
|
|
||||||
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
|
|
||||||
(tmp_path / "nested").mkdir()
|
|
||||||
|
|
||||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
|
||||||
walker = os.walk(tmp_path)
|
|
||||||
assert await asyncio.to_thread(lambda: list(walker))
|
|
||||||
|
|
||||||
assert detector.violations == []
|
|
||||||
|
|
||||||
|
|
||||||
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
|
|
||||||
path = tmp_path / "data.txt"
|
|
||||||
path.write_text("content", encoding="utf-8")
|
|
||||||
|
|
||||||
with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
|
|
||||||
assert path.read_text(encoding="utf-8") == "content"
|
|
||||||
|
|
||||||
assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
|
|
||||||
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
|
|
||||||
path = tmp_path / "data.txt"
|
|
||||||
path.write_text("content", encoding="utf-8")
|
|
||||||
|
|
||||||
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
|
|
||||||
assert path.read_text(encoding="utf-8") == "content"
|
|
||||||
|
|
||||||
probe.record("tests/test_example.py::test_example", detector.violations)
|
|
||||||
summary = probe.format_summary()
|
|
||||||
|
|
||||||
assert "blocking io probe: 1 violations across 1 tests" in summary
|
|
||||||
assert "pathlib.Path.read_text" in summary
|
|
||||||
|
|
||||||
|
|
||||||
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
|
|
||||||
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
|
|
||||||
|
|
||||||
assert probe.format_summary() == "blocking io probe: no violations"
|
|
||||||
|
|
||||||
path = tmp_path / "data.txt"
|
|
||||||
path.write_text("content", encoding="utf-8")
|
|
||||||
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
|
|
||||||
assert path.read_text(encoding="utf-8") == "content"
|
|
||||||
|
|
||||||
probe.record("tests/test_example.py::test_example", detector.violations)
|
|
||||||
assert probe.violation_count == 1
|
|
||||||
|
|
||||||
probe.clear()
|
|
||||||
|
|
||||||
assert probe.violation_count == 0
|
|
||||||
assert probe.format_summary() == "blocking io probe: no violations"
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
ORIGINAL_SLEEP = time.sleep
|
|
||||||
|
|
||||||
|
|
||||||
def replacement_sleep(seconds: float) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(time, "sleep", replacement_sleep)
|
|
||||||
assert time.sleep is replacement_sleep
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.no_blocking_io_probe
|
|
||||||
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
|
|
||||||
assert time.sleep is ORIGINAL_SLEEP
|
|
||||||
assert getattr(time.sleep, "__wrapped__", None) is None
|
|
||||||
@@ -761,7 +761,7 @@ class TestChannelManager:
|
|||||||
|
|
||||||
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||||
|
|
||||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None):
|
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||||
del assistant_id, context # unused in this test, kept for signature parity
|
del assistant_id, context # unused in this test, kept for signature parity
|
||||||
|
|
||||||
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||||
|
|||||||
@@ -158,88 +158,6 @@ class TestBuildPatchedMessagesPatching:
|
|||||||
assert patched[1].name == "bash"
|
assert patched[1].name == "bash"
|
||||||
assert patched[1].status == "error"
|
assert patched[1].status == "error"
|
||||||
|
|
||||||
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
|
|
||||||
middleware = DanglingToolCallMiddleware()
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
|
||||||
HumanMessage(content="interruption"),
|
|
||||||
_tool_msg("call_1", "bash"),
|
|
||||||
]
|
|
||||||
patched = middleware._build_patched_messages(msgs)
|
|
||||||
assert patched is not None
|
|
||||||
assert isinstance(patched[0], AIMessage)
|
|
||||||
assert isinstance(patched[1], ToolMessage)
|
|
||||||
assert patched[1].tool_call_id == "call_1"
|
|
||||||
assert isinstance(patched[2], HumanMessage)
|
|
||||||
|
|
||||||
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
|
|
||||||
HumanMessage(content="interruption"),
|
|
||||||
_tool_msg("call_2", "read"),
|
|
||||||
_tool_msg("call_1", "bash"),
|
|
||||||
]
|
|
||||||
|
|
||||||
patched = mw._build_patched_messages(msgs)
|
|
||||||
|
|
||||||
assert patched is not None
|
|
||||||
assert isinstance(patched[0], AIMessage)
|
|
||||||
assert isinstance(patched[1], ToolMessage)
|
|
||||||
assert isinstance(patched[2], ToolMessage)
|
|
||||||
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
|
|
||||||
assert isinstance(patched[3], HumanMessage)
|
|
||||||
|
|
||||||
def test_valid_adjacent_tool_results_are_unchanged(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
|
||||||
_tool_msg("call_1", "bash"),
|
|
||||||
HumanMessage(content="next"),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert mw._build_patched_messages(msgs) is None
|
|
||||||
|
|
||||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
|
||||||
HumanMessage(content="interruption"),
|
|
||||||
_ai_with_tool_calls([_tc("read", "call_2")]),
|
|
||||||
_tool_msg("call_1", "bash"),
|
|
||||||
_tool_msg("call_2", "read"),
|
|
||||||
]
|
|
||||||
|
|
||||||
patched = mw._build_patched_messages(msgs)
|
|
||||||
|
|
||||||
assert patched is not None
|
|
||||||
assert isinstance(patched[0], AIMessage)
|
|
||||||
assert isinstance(patched[1], ToolMessage)
|
|
||||||
assert patched[1].tool_call_id == "call_1"
|
|
||||||
assert isinstance(patched[2], HumanMessage)
|
|
||||||
assert isinstance(patched[3], AIMessage)
|
|
||||||
assert isinstance(patched[4], ToolMessage)
|
|
||||||
assert patched[4].tool_call_id == "call_2"
|
|
||||||
|
|
||||||
def test_orphan_tool_message_is_preserved_during_grouping(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
orphan = _tool_msg("orphan_call", "orphan")
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
|
||||||
orphan,
|
|
||||||
HumanMessage(content="interruption"),
|
|
||||||
_tool_msg("call_1", "bash"),
|
|
||||||
]
|
|
||||||
|
|
||||||
patched = mw._build_patched_messages(msgs)
|
|
||||||
|
|
||||||
assert patched is not None
|
|
||||||
assert isinstance(patched[0], AIMessage)
|
|
||||||
assert isinstance(patched[1], ToolMessage)
|
|
||||||
assert patched[1].tool_call_id == "call_1"
|
|
||||||
assert orphan in patched
|
|
||||||
assert patched.count(orphan) == 1
|
|
||||||
|
|
||||||
def test_invalid_tool_call_is_patched(self):
|
def test_invalid_tool_call_is_patched(self):
|
||||||
mw = DanglingToolCallMiddleware()
|
mw = DanglingToolCallMiddleware()
|
||||||
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
|
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
|
||||||
|
|||||||
@@ -1,222 +0,0 @@
|
|||||||
"""Real-LLM end-to-end verification for issue #2884.
|
|
||||||
|
|
||||||
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
|
|
||||||
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
|
|
||||||
and the production ``get_available_tools`` pipeline. The only thing we mock is
|
|
||||||
the MCP tool source — we hand-roll two ``@tool``s and inject them through
|
|
||||||
``deerflow.mcp.cache.get_cached_mcp_tools``.
|
|
||||||
|
|
||||||
The flow exercised:
|
|
||||||
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
|
|
||||||
that re-enters ``get_available_tools`` on the same task — this is the
|
|
||||||
code path issue #2884 reports). It must call ``tool_search`` to
|
|
||||||
discover the deferred ``fake_calculator`` tool.
|
|
||||||
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
|
|
||||||
``fake_subagent_trigger`` re-enters ``get_available_tools``.
|
|
||||||
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
|
|
||||||
so it can actually call it. Without this PR's fix, the re-entry wipes
|
|
||||||
the promotion and the model can no longer invoke the tool.
|
|
||||||
|
|
||||||
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
|
|
||||||
test run. Run with::
|
|
||||||
|
|
||||||
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
|
|
||||||
PYTHONPATH=. uv run pytest \
|
|
||||||
tests/test_deferred_tool_promotion_real_llm.py -v -s
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langchain_core.tools import tool as as_tool
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Skip control: only run when explicitly opted in.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# Module-wide opt-in gate: every test in this file is skipped unless the
# ONEAPI_E2E environment variable is exactly "1".
pytestmark = pytest.mark.skipif(
    os.environ.get("ONEAPI_E2E") != "1",
    reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Fake "MCP" tools the agent should discover via tool_search.
|
|
||||||
# Keep them obviously synthetic so the model can pattern-match the search.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level log of tool invocations; the fake tools below append
# "<tool_name>:<arguments>" strings to it so the test can inspect what ran.
_calls: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
@as_tool
def fake_calculator(expression: str) -> str:
    """Evaluate a tiny arithmetic expression like '2 + 2'.

    Reserved for the user — only call this if the user asks for arithmetic.
    """
    # NOTE: the docstring above is the tool description handed to the model,
    # so it is runtime text and must not be reworded casually.

    # Record the invocation first so that rejected expressions are logged too.
    _calls.append(f"fake_calculator:{expression}")
    try:
        # Character whitelist: digits, the four basic operators, parentheses,
        # spaces and the decimal point.
        permitted_chars = frozenset("0123456789+-*/() .")
        if any(char not in permitted_chars for char in expression):
            return "expression contains disallowed characters"
        # NOTE(review): eval() runs on model-supplied text. Builtins are
        # removed and names cannot pass the whitelist, but "*" is allowed so
        # "**" exponent towers still get through and could be slow to
        # evaluate. Tolerable for an opt-in e2e check; do not reuse elsewhere.
        outcome = eval(expression, {"__builtins__": {}}, {})  # noqa: S307
        return str(outcome)
    except Exception as exc:
        # Any evaluation failure is reported back to the model as text.
        return f"error: {exc}"
|
|
||||||
|
|
||||||
|
|
||||||
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
    """Translate text into the given language code. Decorative — not used."""
    # NOTE: the docstring above is the tool description handed to the model.
    # No translation happens: the call is logged and the text is echoed back
    # with the language code as a bracketed prefix.
    invocation = f"fake_translator:{text}:{target_lang}"
    _calls.append(invocation)
    return f"[{target_lang}] {text}"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pipeline wiring (same shape as the in-process tests).
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
    """Reset the deferred-tool registry before and after every test in this module."""
    from deerflow.tools.builtins.tool_search import reset_deferred_registry as clear_registry

    clear_registry()  # start from a clean registry
    yield
    clear_registry()  # leave nothing behind for the next test
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
    """Route the MCP configuration and tool cache to in-test stand-ins.

    ``ExtensionsConfig.from_file`` is replaced by a classmethod returning a
    config with a single enabled stdio server named ``fake-server``, and
    ``deerflow.mcp.cache.get_cached_mcp_tools`` is replaced by a function that
    returns a fresh list copy of ``mcp_tools`` on every call.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    fake_server = McpServerConfig(type="stdio", command="echo", enabled=True)
    stub_extensions = ExtensionsConfig(mcpServers={"fake-server": fake_server})

    def _from_file(cls):
        return stub_extensions

    def _cached_tools():
        return list(mcp_tools)

    monkeypatch.setattr(
        "deerflow.config.extensions_config.ExtensionsConfig.from_file",
        classmethod(_from_file),
    )
    monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", _cached_tools)
|
|
||||||
|
|
||||||
|
|
||||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Make ``deerflow.tools.tools.get_app_config`` return a stub with tool search on.

    The stub is assembled with ``model_construct`` rather than loaded through
    the real config loader; the original author notes that the real loader
    triggers ``_apply_singleton_configs`` and permanently mutates singletons
    shared across tests (memory, title, …).
    """
    from deerflow.config.app_config import AppConfig
    from deerflow.config.tool_search_config import ToolSearchConfig

    # Build the sandbox sub-config from whatever type AppConfig declares for it.
    sandbox_type = AppConfig.model_fields["sandbox"].annotation
    stub_fields = {
        "log_level": "info",
        "models": [],
        "tools": [],
        "tool_groups": [],
        "sandbox": sandbox_type.model_construct(use="x"),
        "tool_search": ToolSearchConfig(enabled=True),
    }
    stub_config = AppConfig.model_construct(**stub_fields)

    def _get_app_config():
        return stub_config

    monkeypatch.setattr("deerflow.tools.tools.get_app_config", _get_app_config)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Real-LLM e2e test
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
    """Run the promotion flow against a real OpenAI-compatible model.

    Expected conversation:
      * Turn 1: only ``tool_search`` and ``fake_subagent_trigger`` are usable;
        the model is instructed to call both in one batch, promoting
        ``fake_calculator`` while the trigger re-enters ``get_available_tools``.
      * Turn 2: the model calls ``fake_calculator`` and answers.

    The test passes when ``fake_calculator`` shows up in ``_calls`` (it was
    really invoked at the tool layer) and the final answer contains ``391``.
    """
    from langchain.agents import create_agent
    from langchain_openai import ChatOpenAI

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
    _force_tool_search_enabled(monkeypatch)
    _calls.clear()

    @as_tool
    async def fake_subagent_trigger(prompt: str) -> str:
        """Pretend to spawn a subagent. Internally rebuilds the toolset.

        Use this whenever the user asks you to delegate work — pass a short
        description as ``prompt``.
        """
        # NOTE: the docstring above is the tool description handed to the model.
        # Rebuilding the toolset here stands in for what ``task_tool`` does
        # internally (per the original author's note). Whether a registry reset
        # at this point leaks back to the parent depends on asyncio's context
        # copying, so the test asserts that the promotion survives regardless
        # of which path triggers the re-entrant ``get_available_tools`` call.
        get_available_tools(subagent_enabled=False)
        _calls.append(f"fake_subagent_trigger:{prompt}")
        return "subagent completed"

    toolset = get_available_tools() + [fake_subagent_trigger]

    # Model name is overridable; key and base URL are mandatory environment variables.
    llm = ChatOpenAI(
        model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        temperature=0,
        max_retries=1,
    )

    system_prompt = (
        "You are a meticulous assistant. Available deferred tools include a "
        "calculator and a translator — their schemas are hidden until you "
        "search for them via tool_search.\n\n"
        "Procedure for the user's request:\n"
        " 1. Call tool_search with query 'select:fake_calculator' AND "
        "in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
        "to delegate the side work. Put both tool_calls in your first response.\n"
        " 2. After both tool messages come back, call fake_calculator with "
        "the user's expression.\n"
        " 3. Reply with just the numeric result."
    )

    agent = create_agent(
        model=llm,
        tools=toolset,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt=system_prompt,
    )

    user_turn = HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")
    result = await agent.ainvoke(
        {"messages": [user_turn]},
        config={"recursion_limit": 12},
    )

    # Diagnostic output (visible with ``pytest -s``).
    print("\n=== tool calls recorded ===")
    for c in _calls:
        print(f" {c}")
    print("\n=== final message ===")
    if result["messages"]:
        final_text = result["messages"][-1].content
    else:
        final_text = "(none)"
    print(f" {final_text!r}")

    # Key check: the calculator can only have run if its promoted schema
    # reached the model on turn 2 despite the re-entry during turn 1.
    calc_calls = [c for c in _calls if c.startswith("fake_calculator:")]
    assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"

    # Sanity check that the tool result made it into the final answer.
    assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
|
|
||||||
@@ -1,390 +0,0 @@
|
|||||||
"""Reproduce + regression-guard issue #2884.
|
|
||||||
|
|
||||||
Hypothesis from the issue:
|
|
||||||
``tools.tools.get_available_tools`` unconditionally calls
|
|
||||||
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
|
|
||||||
every time it is invoked. If anything calls ``get_available_tools`` again
|
|
||||||
during the same async context (after the agent has promoted tools via
|
|
||||||
``tool_search``), the promotion is wiped and the next model call hides the
|
|
||||||
tool's schema again.
|
|
||||||
|
|
||||||
These tests pin two things:
|
|
||||||
|
|
||||||
A. **At the unit boundary** — verify the failure mode directly. Promote a
|
|
||||||
tool in the registry, then call ``get_available_tools`` again and observe
|
|
||||||
that the ContextVar registry is reset and the promotion is lost.
|
|
||||||
|
|
||||||
B. **At the graph-execution boundary** — drive a real ``create_agent`` graph
|
|
||||||
with the real ``DeferredToolFilterMiddleware`` through two model turns.
|
|
||||||
The first turn calls ``tool_search`` which promotes a tool. The second
|
|
||||||
turn must see that tool's schema in ``request.tools``. If
|
|
||||||
``get_available_tools`` were to run again between the two turns and reset
|
|
||||||
the registry, the second turn's filter would strip the tool.
|
|
||||||
|
|
||||||
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
|
|
||||||
unmodified; mock only the LLM and the MCP tool source. Patch
|
|
||||||
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
|
|
||||||
``get_available_tools`` resolves via lazy import) to return our fixture
|
|
||||||
tools so we don't need a real MCP server.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
|
||||||
from langchain_core.runnables import Runnable
|
|
||||||
from langchain_core.tools import tool as as_tool
|
|
||||||
|
|
||||||
|
|
||||||
class FakeToolCallingModel(FakeMessagesListChatModel):
    """Scripted chat model whose ``bind_tools`` is a no-op.

    ``bind_tools`` ignores everything it is given and returns the model
    itself, which lets the model be passed to ``create_agent``.
    """

    def bind_tools(  # type: ignore[override]
        self,
        tools: Any,
        *,
        tool_choice: Any = None,
        **kwargs: Any,
    ) -> Runnable:
        # Nothing is actually bound: responses come from the scripted list.
        return self
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@as_tool
def fake_mcp_search(query: str) -> str:
    """Pretend to search a knowledge base for the given query."""
    # NOTE: the docstring above is the tool description exposed to the agent.
    canned_reply = f"results for {query}"
    return canned_reply
|
|
||||||
|
|
||||||
|
|
||||||
@as_tool
def fake_mcp_fetch(url: str) -> str:
    """Pretend to fetch a page at the given URL."""
    # NOTE: the docstring above is the tool description exposed to the agent.
    canned_reply = f"content of {url}"
    return canned_reply
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
    """Provide placeholder OpenAI settings for every test.

    The original author notes that config.yaml references $OPENAI_API_KEY at
    parse time; the values set here are dummies.
    """
    placeholders = (
        ("OPENAI_API_KEY", "sk-fake-not-used"),
        ("OPENAI_API_BASE", "https://example.invalid"),
    )
    for variable, value in placeholders:
        monkeypatch.setenv(variable, value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
    """Give every test a clean deferred-tool registry.

    Per the original author's note, the registry lives in a module-level
    ContextVar and a synchronous test runner gives it no per-task isolation,
    so a promotion made by one test could otherwise leak into the next and
    silently break its filter assertions. The registry is reset both before
    and after each test.
    """
    from deerflow.tools.builtins.tool_search import reset_deferred_registry as clear_registry

    clear_registry()
    yield
    clear_registry()
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
    """Convince ``get_available_tools`` that an MCP server is configured.

    A genuine ``ExtensionsConfig`` holding one enabled stdio server
    (``fake-server``) is returned from a patched ``ExtensionsConfig.from_file``.
    According to the original author's note, both ``AppConfig.from_file`` and
    ``tools.get_available_tools`` consume that classmethod, so a real instance
    is used rather than a mock. The MCP tool cache lookup is then redirected
    to return a fresh list copy of ``mcp_tools``.
    """
    from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

    fake_server = McpServerConfig(type="stdio", command="echo", enabled=True)
    stub_extensions = ExtensionsConfig(mcpServers={"fake-server": fake_server})

    def _from_file(cls):
        return stub_extensions

    def _cached_tools():
        return list(mcp_tools)

    monkeypatch.setattr(
        "deerflow.config.extensions_config.ExtensionsConfig.from_file",
        classmethod(_from_file),
    )
    monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", _cached_tools)
|
|
||||||
|
|
||||||
|
|
||||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Turn on ``config.tool_search.enabled`` without reading the yaml.

    ``deerflow.tools.tools.get_app_config`` is patched to return a minimal
    ``AppConfig`` assembled with ``model_construct``. The real loader is
    deliberately never called: per the original author's note it triggers
    ``_apply_singleton_configs``, which permanently rewrites module-level
    singletons (``_memory_config``, ``_title_config``, …) to match the
    developer's ``config.yaml``. That change outlives the monkeypatch and
    breaks later tests that rely on those singletons' defaults (for example
    memory queue tests that need ``_memory_config.enabled = True``).
    """
    from deerflow.config.app_config import AppConfig
    from deerflow.config.tool_search_config import ToolSearchConfig

    # Build the sandbox sub-config from whatever type AppConfig declares for it.
    sandbox_type = AppConfig.model_fields["sandbox"].annotation
    stub_fields = {
        "log_level": "info",
        "models": [],
        "tools": [],
        "tool_groups": [],
        "sandbox": sandbox_type.model_construct(use="x"),
        "tool_search": ToolSearchConfig(enabled=True),
    }
    stub_config = AppConfig.model_construct(**stub_fields)

    def _get_app_config():
        return stub_config

    monkeypatch.setattr("deerflow.tools.tools.get_app_config", _get_app_config)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Section A — direct unit-level reproduction
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
    """A second ``get_available_tools()`` call must keep earlier promotions.

    Sequence under test:
      1. ``get_available_tools()`` registers both MCP tools as deferred.
      2. One tool is promoted directly on the registry, standing in for what
         ``tool_search`` does.
      3. ``get_available_tools()`` runs again, as ``task_tool`` does mid-run.

    After step 3 the promoted tool must not be back among the deferred
    entries. The original author notes that before the fix the second call
    reset the registry and re-deferred every MCP tool.
    """
    from deerflow.tools.builtins.tool_search import get_deferred_registry
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    def _deferred_names(registry):
        return {entry.name for entry in registry.entries}

    # Step 1: the initial build defers both MCP tools.
    get_available_tools()
    initial_registry = get_deferred_registry()
    assert initial_registry is not None
    assert _deferred_names(initial_registry) == {"fake_mcp_search", "fake_mcp_fetch"}

    # Step 2: promote one tool the way tool_search would.
    initial_registry.promote({"fake_mcp_search"})
    assert _deferred_names(initial_registry) == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"

    # Step 3: the re-entrant build must leave the promotion in place.
    get_available_tools()
    later_registry = get_deferred_registry()
    assert later_registry is not None
    deferred_after = _deferred_names(later_registry)
    assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Section B — graph-execution reproduction
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _ToolSearchPromotingModel(FakeToolCallingModel):
    """Scripted model that logs the tool names passed to ``bind_tools``.

    It is meant to be loaded with a two-step script: first a ``tool_search``
    call, then a call to the promoted ``fake_mcp_search``. Every ``bind_tools``
    call appends the list of tool names it received, so a test can inspect
    what ``DeferredToolFilterMiddleware`` exposed on each turn.
    """

    # One list of tool names per bind_tools call, in call order.
    bound_tools_per_turn: list[list[str]] = []

    def bind_tools(  # type: ignore[override]
        self,
        tools: Any,
        *,
        tool_choice: Any = None,
        **kwargs: Any,
    ) -> Runnable:
        # Prefer ``.name``, then ``__name__``, then the repr as a last resort.
        seen_names = []
        for candidate in tools:
            fallback = getattr(candidate, "__name__", repr(candidate))
            seen_names.append(getattr(candidate, "name", fallback))
        self.bound_tools_per_turn.append(seen_names)
        return self
|
|
||||||
|
|
||||||
|
|
||||||
def _build_promoting_model() -> _ToolSearchPromotingModel:
    """Return a recording model scripted with three responses.

    The script is: call ``tool_search`` selecting ``fake_mcp_search``, then
    call ``fake_mcp_search``, then finish with a plain "all done" message.
    """
    search_turn = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "tool_search",
                "args": {"query": "select:fake_mcp_search"},
                "id": "call_search_1",
                "type": "tool_call",
            }
        ],
    )
    invoke_turn = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "fake_mcp_search",
                "args": {"query": "hello"},
                "id": "call_mcp_1",
                "type": "tool_call",
            }
        ],
    )
    closing_turn = AIMessage(content="all done")
    scripted_responses = [search_turn, invoke_turn, closing_turn]
    return _ToolSearchPromotingModel(responses=scripted_responses)
|
|
||||||
|
|
||||||
|
|
||||||
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
    """Drive a real ``create_agent`` graph through two model turns.

    On turn 1 the deferred ``fake_mcp_search`` must be hidden from the model
    while ``tool_search`` stays visible. After ``tool_search`` promotes it,
    the tool names bound on turn 2 must include ``fake_mcp_search``.
    """
    from langchain.agents import create_agent

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    toolset = get_available_tools()
    # Sanity: the assembled list still contains the deferred tools; hiding
    # them from the model is the middleware's job, not the list's.
    assembled_names = set()
    for assembled_tool in toolset:
        assembled_names.add(getattr(assembled_tool, "name", ""))
    assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"}.issubset(assembled_names)

    model = _build_promoting_model()
    model.bound_tools_per_turn = []  # start from an empty recorder

    agent = create_agent(
        model=model,
        tools=toolset,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt="bug-2884-repro",
    )
    user_turn = HumanMessage(content="use the search tool")
    agent.invoke({"messages": [user_turn]})

    # Turn 1: the deferred tool is hidden, the discovery tool is not.
    turn1 = set(model.bound_tools_per_turn[0])
    assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
    assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"

    # Turn 2: the promoted tool must now be exposed. This is the assertion
    # that guards issue #2884.
    assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
    turn2 = set(model.bound_tools_per_turn[1])
    assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Section C — the actual issue #2884 trigger: a re-entrant
|
|
||||||
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
|
|
||||||
# wipe the parent's promotion.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
    """A re-entrant ``get_available_tools`` call must not undo a promotion.

    The scripted turn 1 issues two tool calls in one batch: ``tool_search``
    (promoting ``fake_mcp_search``) and ``fake_subagent_trigger``, which calls
    ``get_available_tools`` again. The original author describes this as the
    pattern ``task_tool`` follows when it builds a subagent's toolset. The
    tool names bound on turn 2 must still include the promoted tool.
    """
    from langchain.agents import create_agent

    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.tools import get_available_tools

    _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
    _force_tool_search_enabled(monkeypatch)

    # Stand-in for the subagent path: rebuild the toolset while the parent's
    # registry is live.
    @as_tool
    def fake_subagent_trigger(prompt: str) -> str:
        """Pretend to spawn a subagent. Internally rebuilds the toolset."""
        get_available_tools(subagent_enabled=False)
        return f"spawned subagent for: {prompt}"

    toolset = get_available_tools() + [fake_subagent_trigger]

    # One list of tool names per bind_tools call, in call order.
    bound_per_turn: list[list[str]] = []

    class _Model(FakeToolCallingModel):
        def bind_tools(self, tools_arg, **kwargs):  # type: ignore[override]
            seen_names = []
            for candidate in tools_arg:
                seen_names.append(getattr(candidate, "name", repr(candidate)))
            bound_per_turn.append(seen_names)
            return self

    # Turn 1: promote and trigger the subagent-style rebuild in one batch.
    promote_and_reenter = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "tool_search",
                "args": {"query": "select:fake_mcp_search"},
                "id": "call_search_1",
                "type": "tool_call",
            },
            {
                "name": "fake_subagent_trigger",
                "args": {"prompt": "go"},
                "id": "call_trigger_1",
                "type": "tool_call",
            },
        ],
    )
    # Turn 2: attempt to use the promoted tool.
    use_promoted_tool = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "fake_mcp_search",
                "args": {"query": "hello"},
                "id": "call_mcp_1",
                "type": "tool_call",
            }
        ],
    )
    closing_turn = AIMessage(content="all done")
    model = _Model(responses=[promote_and_reenter, use_promoted_tool, closing_turn])

    agent = create_agent(
        model=model,
        tools=toolset,
        middleware=[DeferredToolFilterMiddleware()],
        system_prompt="bug-2884-subagent-repro",
    )
    user_turn = HumanMessage(content="use the search tool")
    agent.invoke({"messages": [user_turn]})

    # Turn 1 sanity: the deferred tool is not exposed yet.
    assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]

    # Key check: turn 2 sees the promoted tool even though turn 1's tool
    # batch re-entered get_available_tools.
    assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
    turn2 = set(bound_per_turn[1])
    assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
@@ -164,85 +164,3 @@ def test_flush_nowait_is_non_blocking() -> None:
|
|||||||
assert elapsed < 0.1
|
assert elapsed < 0.1
|
||||||
assert finished.is_set() is False
|
assert finished.is_set() is False
|
||||||
assert finished.wait(1.0) is True
|
assert finished.wait(1.0) is True
|
||||||
|
|
||||||
|
|
||||||
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
    """Updates from two agents on the same thread must stay as separate entries."""
    queue = MemoryUpdateQueue()

    # Enable memory and stub out the timer reset while enqueueing.
    config_patch = patch(
        "deerflow.agents.memory.queue.get_memory_config",
        return_value=_memory_config(enabled=True),
    )
    timer_patch = patch.object(queue, "_reset_timer")
    with config_patch, timer_patch:
        for agent in ("agent-a", "agent-b"):
            queue.add(thread_id="thread-1", messages=[agent], agent_name=agent)

    assert queue.pending_count == 2
    queued_agents = [context.agent_name for context in queue._queue]
    assert queued_agents == ["agent-a", "agent-b"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
    """Two updates from one agent on one thread must collapse into a single entry.

    The surviving entry is expected to carry the newest messages while
    keeping ``correction_detected`` set from the earlier update.
    """
    queue = MemoryUpdateQueue()

    # Enable memory and stub out the timer reset while enqueueing.
    config_patch = patch(
        "deerflow.agents.memory.queue.get_memory_config",
        return_value=_memory_config(enabled=True),
    )
    timer_patch = patch.object(queue, "_reset_timer")
    updates = (("first", True), ("second", False))
    with config_patch, timer_patch:
        for message, correction in updates:
            queue.add(
                thread_id="thread-1",
                messages=[message],
                agent_name="agent-a",
                correction_detected=correction,
            )

    assert queue.pending_count == 1
    assert queue._queue[0].agent_name == "agent-a"
    assert queue._queue[0].messages == ["second"]
    assert queue._queue[0].correction_detected is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
|
|
||||||
queue = MemoryUpdateQueue()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
|
||||||
patch.object(queue, "_reset_timer"),
|
|
||||||
):
|
|
||||||
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
|
|
||||||
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
|
|
||||||
|
|
||||||
mock_updater = MagicMock()
|
|
||||||
mock_updater.update_memory.return_value = True
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
|
|
||||||
patch("deerflow.agents.memory.queue.time.sleep"),
|
|
||||||
):
|
|
||||||
queue.flush()
|
|
||||||
|
|
||||||
assert mock_updater.update_memory.call_count == 2
|
|
||||||
mock_updater.update_memory.assert_has_calls(
|
|
||||||
[
|
|
||||||
call(
|
|
||||||
messages=["agent-a"],
|
|
||||||
thread_id="thread-1",
|
|
||||||
agent_name="agent-a",
|
|
||||||
correction_detected=False,
|
|
||||||
reinforcement_detected=False,
|
|
||||||
user_id=None,
|
|
||||||
),
|
|
||||||
call(
|
|
||||||
messages=["agent-b"],
|
|
||||||
thread_id="thread-1",
|
|
||||||
agent_name="agent-b",
|
|
||||||
correction_detected=False,
|
|
||||||
reinforcement_detected=False,
|
|
||||||
user_id=None,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
|
||||||
|
|
||||||
|
|
||||||
def test_conversation_context_has_user_id():
|
def test_conversation_context_has_user_id():
|
||||||
@@ -18,7 +17,7 @@ def test_conversation_context_user_id_default_none():
|
|||||||
|
|
||||||
def test_queue_add_stores_user_id():
|
def test_queue_add_stores_user_id():
|
||||||
q = MemoryUpdateQueue()
|
q = MemoryUpdateQueue()
|
||||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
with patch.object(q, "_reset_timer"):
|
||||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||||
assert len(q._queue) == 1
|
assert len(q._queue) == 1
|
||||||
assert q._queue[0].user_id == "alice"
|
assert q._queue[0].user_id == "alice"
|
||||||
@@ -27,7 +26,7 @@ def test_queue_add_stores_user_id():
|
|||||||
|
|
||||||
def test_queue_process_passes_user_id_to_updater():
|
def test_queue_process_passes_user_id_to_updater():
|
||||||
q = MemoryUpdateQueue()
|
q = MemoryUpdateQueue()
|
||||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
with patch.object(q, "_reset_timer"):
|
||||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||||
|
|
||||||
mock_updater = MagicMock()
|
mock_updater = MagicMock()
|
||||||
@@ -38,42 +37,3 @@ def test_queue_process_passes_user_id_to_updater():
|
|||||||
mock_updater.update_memory.assert_called_once()
|
mock_updater.update_memory.assert_called_once()
|
||||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||||
assert call_kwargs["user_id"] == "alice"
|
assert call_kwargs["user_id"] == "alice"
|
||||||
|
|
||||||
|
|
||||||
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
|
|
||||||
q = MemoryUpdateQueue()
|
|
||||||
|
|
||||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
|
||||||
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
|
|
||||||
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
|
|
||||||
|
|
||||||
assert q.pending_count == 2
|
|
||||||
assert [context.user_id for context in q._queue] == ["alice", "bob"]
|
|
||||||
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
|
|
||||||
|
|
||||||
|
|
||||||
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
|
|
||||||
q = MemoryUpdateQueue()
|
|
||||||
|
|
||||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
|
||||||
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
|
|
||||||
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
|
|
||||||
|
|
||||||
assert q.pending_count == 1
|
|
||||||
assert q._queue[0].messages == ["second"]
|
|
||||||
assert q._queue[0].user_id == "alice"
|
|
||||||
assert q._queue[0].agent_name == "researcher"
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_nowait_keeps_different_users_separate():
|
|
||||||
q = MemoryUpdateQueue()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
|
|
||||||
patch.object(q, "_schedule_timer"),
|
|
||||||
):
|
|
||||||
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
|
|
||||||
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
|
|
||||||
|
|
||||||
assert q.pending_count == 2
|
|
||||||
assert [context.user_id for context in q._queue] == ["alice", "bob"]
|
|
||||||
|
|||||||
@@ -454,6 +454,7 @@ class TestAStream:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_tools_emits_tool_call_chunk(self):
|
async def test_with_tools_emits_tool_call_chunk(self):
|
||||||
|
|
||||||
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
|
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
|
||||||
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
|
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
|
||||||
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
|
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
|
||||||
|
|||||||
@@ -268,39 +268,6 @@ class TestEdgeCases:
|
|||||||
class TestDbRunEventStore:
|
class TestDbRunEventStore:
|
||||||
"""Tests for DbRunEventStore with temp SQLite."""
|
"""Tests for DbRunEventStore with temp SQLite."""
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
|
||||||
|
|
||||||
class FakeSession:
|
|
||||||
def __init__(self):
|
|
||||||
self.dialect = postgresql.dialect()
|
|
||||||
self.execute_calls = []
|
|
||||||
self.scalar_stmt = None
|
|
||||||
|
|
||||||
def get_bind(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def execute(self, stmt, params=None):
|
|
||||||
self.execute_calls.append((stmt, params))
|
|
||||||
|
|
||||||
async def scalar(self, stmt):
|
|
||||||
self.scalar_stmt = stmt
|
|
||||||
return 41
|
|
||||||
|
|
||||||
session = FakeSession()
|
|
||||||
|
|
||||||
max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
|
|
||||||
|
|
||||||
assert max_seq == 41
|
|
||||||
assert session.execute_calls
|
|
||||||
assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
|
|
||||||
assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
|
|
||||||
compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
|
|
||||||
assert "FOR UPDATE" not in compiled
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_basic_crud(self, tmp_path):
|
async def test_basic_crud(self, tmp_path):
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
|||||||
@@ -3,10 +3,7 @@
|
|||||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
|
|
||||||
@@ -281,48 +278,3 @@ class TestRunRepository:
|
|||||||
assert row4["model_name"] is None
|
assert row4["model_name"] is None
|
||||||
|
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
|
|
||||||
captured = []
|
|
||||||
|
|
||||||
class FakeResult:
|
|
||||||
def all(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
class FakeSession:
|
|
||||||
async def execute(self, stmt):
|
|
||||||
captured.append(stmt)
|
|
||||||
return FakeResult()
|
|
||||||
|
|
||||||
class FakeSessionContext:
|
|
||||||
async def __aenter__(self):
|
|
||||||
return FakeSession()
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
|
||||||
return None
|
|
||||||
|
|
||||||
repo = RunRepository(lambda: FakeSessionContext())
|
|
||||||
|
|
||||||
agg = await repo.aggregate_tokens_by_thread("t1")
|
|
||||||
assert agg == {
|
|
||||||
"total_tokens": 0,
|
|
||||||
"total_input_tokens": 0,
|
|
||||||
"total_output_tokens": 0,
|
|
||||||
"total_runs": 0,
|
|
||||||
"by_model": {},
|
|
||||||
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
|
|
||||||
}
|
|
||||||
assert len(captured) == 1
|
|
||||||
|
|
||||||
stmt = captured[0]
|
|
||||||
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
|
|
||||||
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
|
|
||||||
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
|
|
||||||
|
|
||||||
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
|
|
||||||
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
|
|
||||||
|
|
||||||
assert select_match is not None
|
|
||||||
assert group_by_match is not None
|
|
||||||
assert select_match.group(1) == group_by_match.group(1)
|
|
||||||
|
|||||||
@@ -30,18 +30,12 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _runtime(
|
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||||
thread_id: str | None = "thread-1",
|
|
||||||
agent_name: str | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> SimpleNamespace:
|
|
||||||
context = {}
|
context = {}
|
||||||
if thread_id is not None:
|
if thread_id is not None:
|
||||||
context["thread_id"] = thread_id
|
context["thread_id"] = thread_id
|
||||||
if agent_name is not None:
|
if agent_name is not None:
|
||||||
context["agent_name"] = agent_name
|
context["agent_name"] = agent_name
|
||||||
if user_id is not None:
|
|
||||||
context["user_id"] = user_id
|
|
||||||
return SimpleNamespace(context=context)
|
return SimpleNamespace(context=context)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,8 +50,7 @@ def _middleware(
|
|||||||
preserve_recent_skill_tokens_per_skill: int = 0,
|
preserve_recent_skill_tokens_per_skill: int = 0,
|
||||||
) -> DeerFlowSummarizationMiddleware:
|
) -> DeerFlowSummarizationMiddleware:
|
||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
model.invoke.return_value = AIMessage(content="compressed summary")
|
model.invoke.return_value = SimpleNamespace(text="compressed summary")
|
||||||
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
|
|
||||||
return DeerFlowSummarizationMiddleware(
|
return DeerFlowSummarizationMiddleware(
|
||||||
model=model,
|
model=model,
|
||||||
trigger=trigger,
|
trigger=trigger,
|
||||||
@@ -641,99 +634,3 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
|
|||||||
|
|
||||||
queue.add_nowait.assert_called_once()
|
queue.add_nowait.assert_called_once()
|
||||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Issue #2804: summary text must not leak to the frontend via streaming
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_new_messages_sets_hide_from_ui() -> None:
|
|
||||||
"""The summary HumanMessage must carry hide_from_ui so the frontend filters it."""
|
|
||||||
middleware = _middleware()
|
|
||||||
messages = middleware._build_new_messages("test summary")
|
|
||||||
|
|
||||||
assert len(messages) == 1
|
|
||||||
msg = messages[0]
|
|
||||||
assert msg.name == "summary"
|
|
||||||
assert msg.additional_kwargs.get("hide_from_ui") is True
|
|
||||||
assert "test summary" in msg.content
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_summary_suppresses_callbacks() -> None:
|
|
||||||
"""_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
|
||||||
in the invoke config to suppress inherited LangGraph stream callbacks."""
|
|
||||||
middleware = _middleware()
|
|
||||||
|
|
||||||
middleware._create_summary(_messages())
|
|
||||||
|
|
||||||
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
|
||||||
bound = middleware.model.with_config.return_value
|
|
||||||
bound.invoke.assert_called_once()
|
|
||||||
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
|
|
||||||
assert call_config is not None
|
|
||||||
assert call_config.get("callbacks") == []
|
|
||||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_acreate_summary_suppresses_callbacks() -> None:
|
|
||||||
"""_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
|
||||||
in the ainvoke config to suppress inherited LangGraph stream callbacks."""
|
|
||||||
middleware = _middleware()
|
|
||||||
middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
|
|
||||||
|
|
||||||
await middleware._acreate_summary(_messages())
|
|
||||||
|
|
||||||
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
|
||||||
bound = middleware.model.with_config.return_value
|
|
||||||
bound.ainvoke.assert_called_once()
|
|
||||||
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
|
|
||||||
assert call_config is not None
|
|
||||||
assert call_config.get("callbacks") == []
|
|
||||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
|
||||||
|
|
||||||
|
|
||||||
def test_before_model_summary_message_has_hide_from_ui() -> None:
|
|
||||||
"""End-to-end: the emitted state update contains a summary message with hide_from_ui."""
|
|
||||||
middleware = _middleware()
|
|
||||||
|
|
||||||
result = middleware.before_model({"messages": _messages()}, _runtime())
|
|
||||||
|
|
||||||
emitted = result["messages"]
|
|
||||||
summary_msg = emitted[1]
|
|
||||||
assert summary_msg.name == "summary"
|
|
||||||
assert summary_msg.additional_kwargs.get("hide_from_ui") is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
queue = MagicMock()
|
|
||||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
|
||||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
|
||||||
|
|
||||||
memory_flush_hook(
|
|
||||||
SummarizationEvent(
|
|
||||||
messages_to_summarize=tuple(_messages()[:2]),
|
|
||||||
preserved_messages=(),
|
|
||||||
thread_id="main",
|
|
||||||
agent_name="researcher",
|
|
||||||
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
queue.add_nowait.assert_called_once()
|
|
||||||
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
|
|
||||||
"""AIMessage.content can be a list of content blocks; _extract_summary_text
|
|
||||||
must normalize to plain text via the .text property instead of producing
|
|
||||||
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
|
|
||||||
middleware = _middleware()
|
|
||||||
|
|
||||||
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
|
|
||||||
assert middleware._extract_summary_text(response) == "A summary of the chat."
|
|
||||||
|
|
||||||
# Plain string content still works
|
|
||||||
response_str = AIMessage(content="Plain summary")
|
|
||||||
assert middleware._extract_summary_text(response_str) == "Plain summary"
|
|
||||||
|
|||||||
@@ -59,15 +59,12 @@ def _make_result(
|
|||||||
ai_messages: list[dict] | None = None,
|
ai_messages: list[dict] | None = None,
|
||||||
result: str | None = None,
|
result: str | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
token_usage_records: list[dict] | None = None,
|
|
||||||
) -> SimpleNamespace:
|
) -> SimpleNamespace:
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
status=status,
|
status=status,
|
||||||
ai_messages=ai_messages or [],
|
ai_messages=ai_messages or [],
|
||||||
result=result,
|
result=result,
|
||||||
error=error,
|
error=error,
|
||||||
token_usage_records=token_usage_records or [],
|
|
||||||
usage_reported=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1135,153 +1132,3 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
|
|||||||
assert len(report_calls) == 1
|
assert len(report_calls) == 1
|
||||||
assert report_calls[0][1] is cancel_result
|
assert report_calls[0][1] is cancel_result
|
||||||
assert cleanup_calls == ["tc-cancel-report"]
|
assert cleanup_calls == ["tc-cancel-report"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"status, expected_type",
|
|
||||||
[
|
|
||||||
(FakeSubagentStatus.COMPLETED, "task_completed"),
|
|
||||||
(FakeSubagentStatus.FAILED, "task_failed"),
|
|
||||||
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
|
|
||||||
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
|
|
||||||
"""Terminal task events include a usage summary from token_usage_records."""
|
|
||||||
config = _make_subagent_config()
|
|
||||||
runtime = _make_runtime()
|
|
||||||
events = []
|
|
||||||
|
|
||||||
records = [
|
|
||||||
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
|
||||||
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
|
|
||||||
]
|
|
||||||
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
|
|
||||||
|
|
||||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
|
||||||
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
|
|
||||||
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
|
||||||
|
|
||||||
_run_task_tool(
|
|
||||||
runtime=runtime,
|
|
||||||
description="test",
|
|
||||||
prompt="do work",
|
|
||||||
subagent_type="general-purpose",
|
|
||||||
tool_call_id="tc-usage",
|
|
||||||
)
|
|
||||||
|
|
||||||
terminal_events = [e for e in events if e["type"] == expected_type]
|
|
||||||
assert len(terminal_events) == 1
|
|
||||||
assert terminal_events[0]["usage"] == {
|
|
||||||
"input_tokens": 300,
|
|
||||||
"output_tokens": 130,
|
|
||||||
"total_tokens": 430,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_terminal_event_usage_none_when_no_records(monkeypatch):
|
|
||||||
"""Terminal event has usage=None when token_usage_records is empty."""
|
|
||||||
config = _make_subagent_config()
|
|
||||||
runtime = _make_runtime()
|
|
||||||
events = []
|
|
||||||
|
|
||||||
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
|
|
||||||
|
|
||||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
|
||||||
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
|
|
||||||
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
|
||||||
|
|
||||||
_run_task_tool(
|
|
||||||
runtime=runtime,
|
|
||||||
description="test",
|
|
||||||
prompt="do work",
|
|
||||||
subagent_type="general-purpose",
|
|
||||||
tool_call_id="tc-no-records",
|
|
||||||
)
|
|
||||||
|
|
||||||
completed = [e for e in events if e["type"] == "task_completed"]
|
|
||||||
assert len(completed) == 1
|
|
||||||
assert completed[0]["usage"] is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module,
|
|
||||||
"get_app_config",
|
|
||||||
MagicMock(side_effect=FileNotFoundError("missing config")),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert task_tool_module._token_usage_cache_enabled(None) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
|
|
||||||
config = _make_subagent_config()
|
|
||||||
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
|
|
||||||
runtime = _make_runtime(app_config=app_config)
|
|
||||||
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
|
|
||||||
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
|
|
||||||
|
|
||||||
task_tool_module._subagent_usage_cache.clear()
|
|
||||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module,
|
|
||||||
"SubagentExecutor",
|
|
||||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
|
|
||||||
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
|
|
||||||
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
|
||||||
|
|
||||||
_run_task_tool(
|
|
||||||
runtime=runtime,
|
|
||||||
description="test",
|
|
||||||
prompt="do work",
|
|
||||||
subagent_type="general-purpose",
|
|
||||||
tool_call_id="tc-disabled-cache",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
|
|
||||||
config = _make_subagent_config()
|
|
||||||
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
|
|
||||||
runtime = _make_runtime(app_config=app_config)
|
|
||||||
|
|
||||||
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
|
|
||||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module,
|
|
||||||
"SubagentExecutor",
|
|
||||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="poll failed"):
|
|
||||||
_run_task_tool(
|
|
||||||
runtime=runtime,
|
|
||||||
description="test",
|
|
||||||
prompt="do work",
|
|
||||||
subagent_type="general-purpose",
|
|
||||||
tool_call_id="tc-error",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic:
|
|||||||
assert middleware._should_generate_title(state) is False
|
assert middleware._should_generate_title(state) is False
|
||||||
|
|
||||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||||
_set_test_title_config(max_chars=12, model_name=None)
|
_set_test_title_config(max_chars=12)
|
||||||
middleware = TitleMiddleware()
|
middleware = TitleMiddleware()
|
||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||||
|
|||||||
@@ -1,19 +1,14 @@
|
|||||||
"""Tests for TodoMiddleware context-loss detection."""
|
"""Tests for TodoMiddleware context-loss detection."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any
|
from unittest.mock import MagicMock
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
|
||||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from pydantic import PrivateAttr
|
|
||||||
|
|
||||||
from deerflow.agents.middlewares.todo_middleware import (
|
from deerflow.agents.middlewares.todo_middleware import (
|
||||||
TodoMiddleware,
|
TodoMiddleware,
|
||||||
_completion_reminder_count,
|
_completion_reminder_count,
|
||||||
_format_todos,
|
_format_todos,
|
||||||
_has_tool_call_intent_or_error,
|
|
||||||
_reminder_in_messages,
|
_reminder_in_messages,
|
||||||
_todos_in_messages,
|
_todos_in_messages,
|
||||||
)
|
)
|
||||||
@@ -27,35 +22,9 @@ def _reminder_msg():
|
|||||||
return HumanMessage(name="todo_reminder", content="reminder")
|
return HumanMessage(name="todo_reminder", content="reminder")
|
||||||
|
|
||||||
|
|
||||||
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
|
|
||||||
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def seen_messages(self) -> list[list[Any]]:
|
|
||||||
return self._seen_messages
|
|
||||||
|
|
||||||
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
|
||||||
self._seen_messages.append(list(messages))
|
|
||||||
return super()._generate(
|
|
||||||
messages,
|
|
||||||
stop=stop,
|
|
||||||
run_manager=run_manager,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_runtime():
|
def _make_runtime():
|
||||||
runtime = MagicMock()
|
runtime = MagicMock()
|
||||||
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
|
runtime.context = {"thread_id": "test-thread"}
|
||||||
return runtime
|
|
||||||
|
|
||||||
|
|
||||||
def _make_runtime_for(thread_id: str, run_id: str):
|
|
||||||
runtime = _make_runtime()
|
|
||||||
runtime.context = {"thread_id": thread_id, "run_id": run_id}
|
|
||||||
return runtime
|
return runtime
|
||||||
|
|
||||||
|
|
||||||
@@ -192,62 +161,10 @@ def _completion_reminder_msg():
|
|||||||
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
||||||
|
|
||||||
|
|
||||||
def _todo_completion_reminders(messages):
|
|
||||||
reminders = []
|
|
||||||
for message in messages:
|
|
||||||
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
|
|
||||||
reminders.append(message)
|
|
||||||
return reminders
|
|
||||||
|
|
||||||
|
|
||||||
def _ai_no_tool_calls():
|
def _ai_no_tool_calls():
|
||||||
return AIMessage(content="I'm done!")
|
return AIMessage(content="I'm done!")
|
||||||
|
|
||||||
|
|
||||||
def _ai_with_invalid_tool_calls():
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[],
|
|
||||||
invalid_tool_calls=[
|
|
||||||
{
|
|
||||||
"type": "invalid_tool_call",
|
|
||||||
"id": "write_file:36",
|
|
||||||
"name": "write_file",
|
|
||||||
"args": "{invalid",
|
|
||||||
"error": "Failed to parse tool arguments",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ai_with_raw_provider_tool_calls():
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[],
|
|
||||||
invalid_tool_calls=[],
|
|
||||||
additional_kwargs={
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"id": "raw-tool-call",
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ai_with_legacy_function_call():
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ai_with_tool_finish_reason():
|
|
||||||
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
|
|
||||||
|
|
||||||
|
|
||||||
def _incomplete_todos():
|
def _incomplete_todos():
|
||||||
return [
|
return [
|
||||||
{"status": "completed", "content": "Step 1"},
|
{"status": "completed", "content": "Step 1"},
|
||||||
@@ -277,36 +194,6 @@ class TestCompletionReminderCount:
|
|||||||
assert _completion_reminder_count(msgs) == 1
|
assert _completion_reminder_count(msgs) == 1
|
||||||
|
|
||||||
|
|
||||||
class TestToolCallIntentOrError:
|
|
||||||
def test_false_for_plain_final_answer(self):
|
|
||||||
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
|
|
||||||
|
|
||||||
def test_true_for_structured_tool_calls(self):
|
|
||||||
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
|
|
||||||
|
|
||||||
def test_true_for_invalid_tool_calls(self):
|
|
||||||
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
|
|
||||||
|
|
||||||
def test_true_for_raw_provider_tool_calls(self):
|
|
||||||
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
|
|
||||||
|
|
||||||
def test_true_for_legacy_function_call(self):
|
|
||||||
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
|
|
||||||
|
|
||||||
def test_true_for_tool_finish_reason(self):
|
|
||||||
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
|
|
||||||
|
|
||||||
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
|
|
||||||
# Sentinel for LangChain compatibility: if future AIMessage versions add
|
|
||||||
# new top-level tool/function-call fields, this test should fail. When
|
|
||||||
# it does, update `_has_tool_call_intent_or_error()` so the completion
|
|
||||||
# reminder guard explicitly decides whether each new field means "not a
|
|
||||||
# clean final answer"; the helper has a matching comment pointing back
|
|
||||||
# to this sentinel.
|
|
||||||
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
|
|
||||||
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestAfterModel:
|
class TestAfterModel:
|
||||||
def test_returns_none_when_agent_still_using_tools(self):
|
def test_returns_none_when_agent_still_using_tools(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
@@ -348,299 +235,68 @@ class TestAfterModel:
|
|||||||
}
|
}
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
|
|
||||||
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
|
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = mw.after_model(state, runtime)
|
result = mw.after_model(state, _make_runtime())
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "model"
|
assert result["jump_to"] == "model"
|
||||||
assert "messages" not in result
|
assert len(result["messages"]) == 1
|
||||||
|
reminder = result["messages"][0]
|
||||||
request = MagicMock()
|
|
||||||
request.runtime = runtime
|
|
||||||
request.messages = state["messages"]
|
|
||||||
request.override.return_value = "patched-request"
|
|
||||||
handler = MagicMock(return_value="response")
|
|
||||||
|
|
||||||
assert mw.wrap_model_call(request, handler) == "response"
|
|
||||||
request.override.assert_called_once()
|
|
||||||
reminder = request.override.call_args.kwargs["messages"][-1]
|
|
||||||
assert isinstance(reminder, HumanMessage)
|
assert isinstance(reminder, HumanMessage)
|
||||||
assert reminder.name == "todo_completion_reminder"
|
assert reminder.name == "todo_completion_reminder"
|
||||||
assert reminder.additional_kwargs["hide_from_ui"] is True
|
|
||||||
assert "Step 2" in reminder.content
|
assert "Step 2" in reminder.content
|
||||||
assert "Step 3" in reminder.content
|
assert "Step 3" in reminder.content
|
||||||
handler.assert_called_once_with("patched-request")
|
|
||||||
|
|
||||||
def test_reminder_lists_only_incomplete_items(self):
|
def test_reminder_lists_only_incomplete_items(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [_ai_no_tool_calls()],
|
"messages": [_ai_no_tool_calls()],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = mw.after_model(state, runtime)
|
result = mw.after_model(state, _make_runtime())
|
||||||
assert result is not None
|
content = result["messages"][0].content
|
||||||
|
|
||||||
request = MagicMock()
|
|
||||||
request.runtime = runtime
|
|
||||||
request.messages = state["messages"]
|
|
||||||
request.override.return_value = "patched-request"
|
|
||||||
mw.wrap_model_call(request, MagicMock(return_value="response"))
|
|
||||||
content = request.override.call_args.kwargs["messages"][-1].content
|
|
||||||
assert "Step 1" not in content # completed — should not appear
|
assert "Step 1" not in content # completed — should not appear
|
||||||
assert "Step 2" in content
|
assert "Step 2" in content
|
||||||
assert "Step 3" in content
|
assert "Step 3" in content
|
||||||
|
|
||||||
def test_allows_exit_after_max_reminders(self):
|
def test_allows_exit_after_max_reminders(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
|
_completion_reminder_msg(),
|
||||||
|
_completion_reminder_msg(),
|
||||||
_ai_no_tool_calls(),
|
_ai_no_tool_calls(),
|
||||||
],
|
],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
assert mw.after_model(state, runtime) is not None
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
assert mw.after_model(state, runtime) is not None
|
|
||||||
assert mw.after_model(state, runtime) is None
|
|
||||||
|
|
||||||
def test_still_sends_reminder_before_cap(self):
|
def test_still_sends_reminder_before_cap(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
|
_completion_reminder_msg(), # 1 reminder so far
|
||||||
_ai_no_tool_calls(),
|
_ai_no_tool_calls(),
|
||||||
],
|
],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
assert mw.after_model(state, runtime) is not None
|
result = mw.after_model(state, _make_runtime())
|
||||||
result = mw.after_model(state, runtime)
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "model"
|
assert result["jump_to"] == "model"
|
||||||
|
|
||||||
def test_does_not_trigger_for_invalid_tool_calls(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
state = {
|
|
||||||
"messages": [_ai_with_invalid_tool_calls()],
|
|
||||||
"todos": _incomplete_todos(),
|
|
||||||
}
|
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
|
||||||
|
|
||||||
def test_does_not_trigger_for_raw_provider_tool_calls(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
state = {
|
|
||||||
"messages": [_ai_with_raw_provider_tool_calls()],
|
|
||||||
"todos": _incomplete_todos(),
|
|
||||||
}
|
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
|
||||||
|
|
||||||
def test_does_not_trigger_for_legacy_function_call(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
state = {
|
|
||||||
"messages": [_ai_with_legacy_function_call()],
|
|
||||||
"todos": _incomplete_todos(),
|
|
||||||
}
|
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
|
||||||
|
|
||||||
def test_does_not_trigger_for_tool_finish_reason(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
state = {
|
|
||||||
"messages": [_ai_with_tool_finish_reason()],
|
|
||||||
"todos": _incomplete_todos(),
|
|
||||||
}
|
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestAafterModel:
|
class TestAafterModel:
|
||||||
def test_delegates_to_sync(self):
|
def test_delegates_to_sync(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [_ai_no_tool_calls()],
|
"messages": [_ai_no_tool_calls()],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = asyncio.run(mw.aafter_model(state, runtime))
|
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "model"
|
assert result["jump_to"] == "model"
|
||||||
assert "messages" not in result
|
assert result["messages"][0].name == "todo_completion_reminder"
|
||||||
|
|
||||||
|
|
||||||
class TestWrapModelCall:
|
|
||||||
def test_no_pending_reminder_passthrough(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
request = MagicMock()
|
|
||||||
request.runtime = _make_runtime()
|
|
||||||
request.messages = [HumanMessage(content="hi")]
|
|
||||||
handler = MagicMock(return_value="response")
|
|
||||||
|
|
||||||
assert mw.wrap_model_call(request, handler) == "response"
|
|
||||||
request.override.assert_not_called()
|
|
||||||
handler.assert_called_once_with(request)
|
|
||||||
|
|
||||||
def test_pending_reminder_is_injected_once(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
|
||||||
"messages": [_ai_no_tool_calls()],
|
|
||||||
"todos": _incomplete_todos(),
|
|
||||||
}
|
|
||||||
mw.after_model(state, runtime)
|
|
||||||
|
|
||||||
request = MagicMock()
|
|
||||||
request.runtime = runtime
|
|
||||||
request.messages = state["messages"]
|
|
||||||
request.override.return_value = "patched-request"
|
|
||||||
handler = MagicMock(return_value="response")
|
|
||||||
|
|
||||||
assert mw.wrap_model_call(request, handler) == "response"
|
|
||||||
injected_messages = request.override.call_args.kwargs["messages"]
|
|
||||||
assert injected_messages[-1].name == "todo_completion_reminder"
|
|
||||||
|
|
||||||
request.override.reset_mock()
|
|
||||||
handler.reset_mock()
|
|
||||||
handler.return_value = "second-response"
|
|
||||||
assert mw.wrap_model_call(request, handler) == "second-response"
|
|
||||||
request.override.assert_not_called()
|
|
||||||
handler.assert_called_once_with(request)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTodoMiddlewareAgentGraphIntegration:
|
|
||||||
def test_completion_reminder_is_transient_in_real_agent_graph(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
model = _CapturingFakeMessagesListChatModel(
|
|
||||||
responses=[
|
|
||||||
AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"name": "write_todos",
|
|
||||||
"id": "todos-1",
|
|
||||||
"args": {
|
|
||||||
"todos": [
|
|
||||||
{"content": "Step 1", "status": "completed"},
|
|
||||||
{"content": "Step 2", "status": "pending"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
),
|
|
||||||
AIMessage(content="premature final 1"),
|
|
||||||
AIMessage(content="premature final 2"),
|
|
||||||
AIMessage(content="premature final 3"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
graph = create_agent(model=model, tools=[], middleware=[mw])
|
|
||||||
|
|
||||||
result = graph.invoke(
|
|
||||||
{"messages": [("user", "finish all todos")]},
|
|
||||||
context={"thread_id": "integration-thread", "run_id": "integration-run"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(model.seen_messages) == 4
|
|
||||||
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
|
|
||||||
assert reminders_by_call[0] == []
|
|
||||||
assert reminders_by_call[1] == []
|
|
||||||
assert len(reminders_by_call[2]) == 1
|
|
||||||
assert len(reminders_by_call[3]) == 1
|
|
||||||
assert "Step 1" not in reminders_by_call[2][0].content
|
|
||||||
assert "Step 2" in reminders_by_call[2][0].content
|
|
||||||
|
|
||||||
persisted_reminders = _todo_completion_reminders(result["messages"])
|
|
||||||
assert persisted_reminders == []
|
|
||||||
assert result["messages"][-1].content == "premature final 3"
|
|
||||||
assert result["todos"] == [
|
|
||||||
{"content": "Step 1", "status": "completed"},
|
|
||||||
{"content": "Step 2", "status": "pending"},
|
|
||||||
]
|
|
||||||
assert mw._pending_completion_reminders == {}
|
|
||||||
assert mw._completion_reminder_counts == {}
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunScopedReminderCleanup:
|
|
||||||
def test_before_agent_clears_stale_count_without_pending_reminder(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
stale_runtime = _make_runtime()
|
|
||||||
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
|
|
||||||
current_runtime = _make_runtime()
|
|
||||||
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
|
|
||||||
other_thread_runtime = _make_runtime()
|
|
||||||
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
|
|
||||||
|
|
||||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
|
||||||
assert mw.after_model(state, stale_runtime) is not None
|
|
||||||
assert mw.after_model(state, other_thread_runtime) is not None
|
|
||||||
|
|
||||||
# Simulate a model call that drained the pending message, followed by an
|
|
||||||
# abnormal run end where after_agent did not clear the reminder count.
|
|
||||||
assert mw._drain_completion_reminders(stale_runtime)
|
|
||||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
|
|
||||||
|
|
||||||
mw.before_agent({}, current_runtime)
|
|
||||||
|
|
||||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
|
||||||
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
|
|
||||||
|
|
||||||
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
mw._MAX_COMPLETION_REMINDER_KEYS = 2
|
|
||||||
first_runtime = _make_runtime_for("thread-a", "run-a")
|
|
||||||
second_runtime = _make_runtime_for("thread-b", "run-b")
|
|
||||||
third_runtime = _make_runtime_for("thread-c", "run-c")
|
|
||||||
|
|
||||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
|
||||||
assert mw.after_model(state, first_runtime) is not None
|
|
||||||
|
|
||||||
# Simulate the normal model request path: pending reminder is consumed,
|
|
||||||
# but the run count remains until after_agent() or stale cleanup.
|
|
||||||
assert mw._drain_completion_reminders(first_runtime)
|
|
||||||
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
|
|
||||||
|
|
||||||
assert mw.after_model(state, second_runtime) is not None
|
|
||||||
assert mw.after_model(state, third_runtime) is not None
|
|
||||||
|
|
||||||
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
|
|
||||||
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
|
|
||||||
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
|
|
||||||
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
|
|
||||||
|
|
||||||
def test_size_guard_prunes_pending_and_count_state_together(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
mw._MAX_COMPLETION_REMINDER_KEYS = 1
|
|
||||||
stale_runtime = _make_runtime_for("thread-a", "run-a")
|
|
||||||
current_runtime = _make_runtime_for("thread-b", "run-b")
|
|
||||||
|
|
||||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
|
||||||
assert mw.after_model(state, stale_runtime) is not None
|
|
||||||
assert mw.after_model(state, current_runtime) is not None
|
|
||||||
|
|
||||||
assert mw._drain_completion_reminders(stale_runtime) == []
|
|
||||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
|
||||||
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestAwrapModelCall:
|
|
||||||
def test_async_pending_reminder_is_injected(self):
|
|
||||||
mw = TodoMiddleware()
|
|
||||||
runtime = _make_runtime()
|
|
||||||
state = {
|
|
||||||
"messages": [_ai_no_tool_calls()],
|
|
||||||
"todos": _incomplete_todos(),
|
|
||||||
}
|
|
||||||
mw.after_model(state, runtime)
|
|
||||||
|
|
||||||
request = MagicMock()
|
|
||||||
request.runtime = runtime
|
|
||||||
request.messages = state["messages"]
|
|
||||||
request.override.return_value = "patched-request"
|
|
||||||
handler = AsyncMock(return_value="response")
|
|
||||||
|
|
||||||
result = asyncio.run(mw.awrap_model_call(request, handler))
|
|
||||||
assert result == "response"
|
|
||||||
injected_messages = request.override.call_args.kwargs["messages"]
|
|
||||||
assert injected_messages[-1].name == "todo_completion_reminder"
|
|
||||||
handler.assert_awaited_once_with("patched-request")
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||||
|
|
||||||
import importlib
|
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
from deerflow.agents.middlewares.token_usage_middleware import (
|
from deerflow.agents.middlewares.token_usage_middleware import (
|
||||||
TOKEN_USAGE_ATTRIBUTION_KEY,
|
TOKEN_USAGE_ATTRIBUTION_KEY,
|
||||||
@@ -233,49 +232,3 @@ class TestTokenUsageMiddleware:
|
|||||||
"tool_call_id": "write_todos:remove",
|
"tool_call_id": "write_todos:remove",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
|
|
||||||
middleware = TokenUsageMiddleware()
|
|
||||||
first_dispatch = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
|
|
||||||
)
|
|
||||||
second_dispatch = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{"id": "task:second-a", "name": "task", "args": {}},
|
|
||||||
{"id": "task:second-b", "name": "task", "args": {}},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
messages = [
|
|
||||||
first_dispatch,
|
|
||||||
ToolMessage(content="first", tool_call_id="task:first"),
|
|
||||||
second_dispatch,
|
|
||||||
ToolMessage(content="second-a", tool_call_id="task:second-a"),
|
|
||||||
ToolMessage(content="second-b", tool_call_id="task:second-b"),
|
|
||||||
AIMessage(content="done"),
|
|
||||||
]
|
|
||||||
cached_usage = {
|
|
||||||
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
|
||||||
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
|
|
||||||
}
|
|
||||||
|
|
||||||
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module,
|
|
||||||
"pop_cached_subagent_usage",
|
|
||||||
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = middleware.after_model({"messages": messages}, _make_runtime())
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
|
|
||||||
assert len(usage_updates) == 1
|
|
||||||
updated = usage_updates[0]
|
|
||||||
assert updated.tool_calls == second_dispatch.tool_calls
|
|
||||||
assert updated.usage_metadata == {
|
|
||||||
"input_tokens": 30,
|
|
||||||
"output_tokens": 12,
|
|
||||||
"total_tokens": 42,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -65,7 +65,8 @@ def _make_minimal_config(tools):
|
|||||||
|
|
||||||
@patch("deerflow.tools.tools.get_app_config")
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||||
|
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
|
||||||
"""Config-loaded async-only tools can still be invoked by sync clients."""
|
"""Config-loaded async-only tools can still be invoked by sync clients."""
|
||||||
|
|
||||||
async def async_tool_impl(x: int) -> str:
|
async def async_tool_impl(x: int) -> str:
|
||||||
@@ -97,7 +98,8 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
|||||||
|
|
||||||
@patch("deerflow.tools.tools.get_app_config")
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
def test_no_duplicates_returned(mock_bash, mock_cfg):
|
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||||
|
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
|
||||||
"""get_available_tools() never returns two tools with the same name."""
|
"""get_available_tools() never returns two tools with the same name."""
|
||||||
mock_cfg.return_value = _make_minimal_config([])
|
mock_cfg.return_value = _make_minimal_config([])
|
||||||
|
|
||||||
@@ -111,7 +113,8 @@ def test_no_duplicates_returned(mock_bash, mock_cfg):
|
|||||||
|
|
||||||
@patch("deerflow.tools.tools.get_app_config")
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
def test_first_occurrence_wins(mock_bash, mock_cfg):
|
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||||
|
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
|
||||||
"""When duplicates exist, the first occurrence is kept."""
|
"""When duplicates exist, the first occurrence is kept."""
|
||||||
mock_cfg.return_value = _make_minimal_config([])
|
mock_cfg.return_value = _make_minimal_config([])
|
||||||
|
|
||||||
@@ -129,7 +132,8 @@ def test_first_occurrence_wins(mock_bash, mock_cfg):
|
|||||||
|
|
||||||
@patch("deerflow.tools.tools.get_app_config")
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
|
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||||
|
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
|
||||||
"""A warning is logged for every skipped duplicate."""
|
"""A warning is logged for every skipped duplicate."""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|||||||
Generated
+5
-22
@@ -1,5 +1,5 @@
|
|||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.14' and sys_platform == 'win32'",
|
"python_full_version >= '3.14' and sys_platform == 'win32'",
|
||||||
@@ -763,9 +763,6 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
discord = [
|
|
||||||
{ name = "discord-py" },
|
|
||||||
]
|
|
||||||
postgres = [
|
postgres = [
|
||||||
{ name = "deerflow-harness", extra = ["postgres"] },
|
{ name = "deerflow-harness", extra = ["postgres"] },
|
||||||
]
|
]
|
||||||
@@ -784,7 +781,6 @@ requires-dist = [
|
|||||||
{ name = "deerflow-harness", editable = "packages/harness" },
|
{ name = "deerflow-harness", editable = "packages/harness" },
|
||||||
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
||||||
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
||||||
{ name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" },
|
|
||||||
{ name = "email-validator", specifier = ">=2.0.0" },
|
{ name = "email-validator", specifier = ">=2.0.0" },
|
||||||
{ name = "fastapi", specifier = ">=0.115.0" },
|
{ name = "fastapi", specifier = ">=0.115.0" },
|
||||||
{ name = "httpx", specifier = ">=0.28.0" },
|
{ name = "httpx", specifier = ">=0.28.0" },
|
||||||
@@ -799,7 +795,7 @@ requires-dist = [
|
|||||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
||||||
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
||||||
]
|
]
|
||||||
provides-extras = ["postgres", "discord"]
|
provides-extras = ["postgres"]
|
||||||
|
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
dev = [
|
dev = [
|
||||||
@@ -927,19 +923,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" },
|
{ url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "discord-py"
|
|
||||||
version = "2.7.1"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "aiohttp" },
|
|
||||||
{ name = "audioop-lts", marker = "python_full_version >= '3.13'" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "distro"
|
name = "distro"
|
||||||
version = "1.9.0"
|
version = "1.9.0"
|
||||||
@@ -2022,7 +2005,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langsmith"
|
name = "langsmith"
|
||||||
version = "0.8.0"
|
version = "0.7.36"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
@@ -2035,9 +2018,9 @@ dependencies = [
|
|||||||
{ name = "xxhash" },
|
{ name = "xxhash" },
|
||||||
{ name = "zstandard" },
|
{ name = "zstandard" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" },
|
{ url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
|
|||||||
@@ -1029,14 +1029,6 @@ run_events:
|
|||||||
# client_secret: $DINGTALK_CLIENT_SECRET
|
# client_secret: $DINGTALK_CLIENT_SECRET
|
||||||
# allowed_users: [] # empty = allow all
|
# allowed_users: [] # empty = allow all
|
||||||
# card_template_id: "" # Optional: AI Card template ID for streaming updates
|
# card_template_id: "" # Optional: AI Card template ID for streaming updates
|
||||||
#
|
|
||||||
# discord:
|
|
||||||
# enabled: false
|
|
||||||
# bot_token: $DISCORD_BOT_TOKEN
|
|
||||||
# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID
|
|
||||||
# mention_only: false # If true, only respond when the bot is mentioned
|
|
||||||
# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention)
|
|
||||||
# thread_mode: false # If true, group a channel conversation into a thread
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Guardrails Configuration
|
# Guardrails Configuration
|
||||||
|
|||||||
+3
-21
@@ -28,10 +28,6 @@ http {
|
|||||||
set $gateway_upstream gateway:8001;
|
set $gateway_upstream gateway:8001;
|
||||||
set $frontend_upstream frontend:3000;
|
set $frontend_upstream frontend:3000;
|
||||||
|
|
||||||
# Default proxy settings for all locations (streaming/SSE support)
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
|
|
||||||
# Keep the unified nginx endpoint same-origin by default. When split
|
# Keep the unified nginx endpoint same-origin by default. When split
|
||||||
# frontend/backend or port-forwarded deployments need browser CORS,
|
# frontend/backend or port-forwarded deployments need browser CORS,
|
||||||
# configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and
|
# configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and
|
||||||
@@ -53,6 +49,8 @@ http {
|
|||||||
proxy_set_header Connection '';
|
proxy_set_header Connection '';
|
||||||
|
|
||||||
# SSE/Streaming support
|
# SSE/Streaming support
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_cache off;
|
||||||
proxy_set_header X-Accel-Buffering no;
|
proxy_set_header X-Accel-Buffering no;
|
||||||
|
|
||||||
# Timeouts for long-running requests
|
# Timeouts for long-running requests
|
||||||
@@ -72,7 +70,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Memory endpoint
|
# Custom API: Memory endpoint
|
||||||
@@ -83,7 +80,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: MCP configuration endpoint
|
# Custom API: MCP configuration endpoint
|
||||||
@@ -94,7 +90,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Skills configuration endpoint
|
# Custom API: Skills configuration endpoint
|
||||||
@@ -105,7 +100,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Agents endpoint
|
# Custom API: Agents endpoint
|
||||||
@@ -116,7 +110,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Uploads endpoint
|
# Custom API: Uploads endpoint
|
||||||
@@ -131,8 +124,6 @@ http {
|
|||||||
# Large file upload support
|
# Large file upload support
|
||||||
client_max_body_size 100M;
|
client_max_body_size 100M;
|
||||||
proxy_request_buffering off;
|
proxy_request_buffering off;
|
||||||
|
|
||||||
# Disable response buffering to avoid permission errors
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Other endpoints under /api/threads
|
# Custom API: Other endpoints under /api/threads
|
||||||
@@ -143,7 +134,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Documentation: Swagger UI
|
# API Documentation: Swagger UI
|
||||||
@@ -154,7 +144,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Documentation: ReDoc
|
# API Documentation: ReDoc
|
||||||
@@ -165,7 +154,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Documentation: OpenAPI Schema
|
# API Documentation: OpenAPI Schema
|
||||||
@@ -176,7 +164,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Health check endpoint (gateway)
|
# Health check endpoint (gateway)
|
||||||
@@ -187,7 +174,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Provisioner API (sandbox management) ────────────────────────
|
# ── Provisioner API (sandbox management) ────────────────────────
|
||||||
@@ -201,7 +187,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*).
|
# Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*).
|
||||||
@@ -213,9 +198,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
# Disable buffering to avoid permission errors when nginx
|
|
||||||
# runs as a non-root user (e.g. local development).
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# All other requests go to frontend
|
# All other requests go to frontend
|
||||||
@@ -238,4 +220,4 @@ http {
|
|||||||
proxy_read_timeout 600s;
|
proxy_read_timeout 600s;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,11 +70,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
# Disable buffering to avoid permission errors when nginx
|
|
||||||
# runs as a non-root user (e.g. local development).
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Memory endpoint
|
# Custom API: Memory endpoint
|
||||||
@@ -85,9 +80,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: MCP configuration endpoint
|
# Custom API: MCP configuration endpoint
|
||||||
@@ -98,9 +90,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Skills configuration endpoint
|
# Custom API: Skills configuration endpoint
|
||||||
@@ -111,9 +100,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Agents endpoint
|
# Custom API: Agents endpoint
|
||||||
@@ -124,9 +110,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Uploads endpoint
|
# Custom API: Uploads endpoint
|
||||||
@@ -141,10 +124,6 @@ http {
|
|||||||
# Large file upload support
|
# Large file upload support
|
||||||
client_max_body_size 100M;
|
client_max_body_size 100M;
|
||||||
proxy_request_buffering off;
|
proxy_request_buffering off;
|
||||||
|
|
||||||
# Disable response buffering to avoid permission errors
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Custom API: Other endpoints under /api/threads
|
# Custom API: Other endpoints under /api/threads
|
||||||
@@ -155,9 +134,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Documentation: Swagger UI
|
# API Documentation: Swagger UI
|
||||||
@@ -168,9 +144,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Documentation: ReDoc
|
# API Documentation: ReDoc
|
||||||
@@ -181,9 +154,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# API Documentation: OpenAPI Schema
|
# API Documentation: OpenAPI Schema
|
||||||
@@ -194,9 +164,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Health check endpoint (gateway)
|
# Health check endpoint (gateway)
|
||||||
@@ -207,9 +174,6 @@ http {
|
|||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Catch-all for any /api/* prefix not matched by a more specific block above.
|
# Catch-all for any /api/* prefix not matched by a more specific block above.
|
||||||
@@ -229,11 +193,6 @@ http {
|
|||||||
# Auth endpoints set HttpOnly cookies — make sure nginx doesn't
|
# Auth endpoints set HttpOnly cookies — make sure nginx doesn't
|
||||||
# strip the Set-Cookie header from upstream responses.
|
# strip the Set-Cookie header from upstream responses.
|
||||||
proxy_pass_header Set-Cookie;
|
proxy_pass_header Set-Cookie;
|
||||||
|
|
||||||
# Disable buffering to avoid permission errors when nginx
|
|
||||||
# runs as a non-root user (e.g. local development).
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_cache off;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# All other requests go to frontend
|
# All other requests go to frontend
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import { FlickeringGrid } from "@/components/ui/flickering-grid";
|
|||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
import { useAuth } from "@/core/auth/AuthProvider";
|
import { useAuth } from "@/core/auth/AuthProvider";
|
||||||
import { parseAuthError } from "@/core/auth/types";
|
import { parseAuthError } from "@/core/auth/types";
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validate next parameter
|
* Validate next parameter
|
||||||
@@ -71,7 +72,7 @@ export default function LoginPage() {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let cancelled = false;
|
let cancelled = false;
|
||||||
|
|
||||||
void fetch("/api/v1/auth/setup-status")
|
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
|
||||||
.then((r) => r.json())
|
.then((r) => r.json())
|
||||||
.then((data: { needs_setup?: boolean }) => {
|
.then((data: { needs_setup?: boolean }) => {
|
||||||
if (!cancelled && data.needs_setup) {
|
if (!cancelled && data.needs_setup) {
|
||||||
@@ -94,8 +95,8 @@ export default function LoginPage() {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const endpoint = isLogin
|
const endpoint = isLogin
|
||||||
? "/api/v1/auth/login/local"
|
? `${getBackendBaseURL()}/api/v1/auth/login/local`
|
||||||
: "/api/v1/auth/register";
|
: `${getBackendBaseURL()}/api/v1/auth/register`;
|
||||||
const body = isLogin
|
const body = isLogin
|
||||||
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
|
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
|
||||||
: JSON.stringify({ email, password });
|
: JSON.stringify({ email, password });
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import { Input } from "@/components/ui/input";
|
|||||||
import { getCsrfHeaders } from "@/core/api/fetcher";
|
import { getCsrfHeaders } from "@/core/api/fetcher";
|
||||||
import { useAuth } from "@/core/auth/AuthProvider";
|
import { useAuth } from "@/core/auth/AuthProvider";
|
||||||
import { parseAuthError } from "@/core/auth/types";
|
import { parseAuthError } from "@/core/auth/types";
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
type SetupMode = "loading" | "init_admin" | "change_password";
|
type SetupMode = "loading" | "init_admin" | "change_password";
|
||||||
|
|
||||||
@@ -36,7 +37,7 @@ export default function SetupPage() {
|
|||||||
setMode("change_password");
|
setMode("change_password");
|
||||||
} else if (!isAuthenticated) {
|
} else if (!isAuthenticated) {
|
||||||
// Check if the system has no users yet
|
// Check if the system has no users yet
|
||||||
void fetch("/api/v1/auth/setup-status")
|
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
|
||||||
.then((r) => r.json())
|
.then((r) => r.json())
|
||||||
.then((data: { needs_setup?: boolean }) => {
|
.then((data: { needs_setup?: boolean }) => {
|
||||||
if (cancelled) return;
|
if (cancelled) return;
|
||||||
@@ -72,7 +73,7 @@ export default function SetupPage() {
|
|||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const res = await fetch("/api/v1/auth/initialize", {
|
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/initialize`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
@@ -113,19 +114,22 @@ export default function SetupPage() {
|
|||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const res = await fetch("/api/v1/auth/change-password", {
|
const res = await fetch(
|
||||||
method: "POST",
|
`${getBackendBaseURL()}/api/v1/auth/change-password`,
|
||||||
headers: {
|
{
|
||||||
"Content-Type": "application/json",
|
method: "POST",
|
||||||
...getCsrfHeaders(),
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
...getCsrfHeaders(),
|
||||||
|
},
|
||||||
|
credentials: "include",
|
||||||
|
body: JSON.stringify({
|
||||||
|
current_password: currentPassword,
|
||||||
|
new_password: newPassword,
|
||||||
|
new_email: email || undefined,
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
credentials: "include",
|
);
|
||||||
body: JSON.stringify({
|
|
||||||
current_password: currentPassword,
|
|
||||||
new_password: newPassword,
|
|
||||||
new_email: email || undefined,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ export default function AgentChatPage() {
|
|||||||
thread,
|
thread,
|
||||||
pendingUsageMessages,
|
pendingUsageMessages,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
isUploading,
|
|
||||||
isHistoryLoading,
|
isHistoryLoading,
|
||||||
hasMoreHistory,
|
hasMoreHistory,
|
||||||
loadMoreHistory,
|
loadMoreHistory,
|
||||||
@@ -107,11 +106,7 @@ export default function AgentChatPage() {
|
|||||||
|
|
||||||
const handleSubmit = useCallback(
|
const handleSubmit = useCallback(
|
||||||
(message: PromptInputMessage) => {
|
(message: PromptInputMessage) => {
|
||||||
const sendPromise = sendMessage(threadId, message, { agent_name });
|
void sendMessage(threadId, message, { agent_name });
|
||||||
if (message.files.length > 0) {
|
|
||||||
return sendPromise;
|
|
||||||
}
|
|
||||||
void sendPromise;
|
|
||||||
},
|
},
|
||||||
[sendMessage, threadId, agent_name],
|
[sendMessage, threadId, agent_name],
|
||||||
);
|
);
|
||||||
@@ -248,10 +243,7 @@ export default function AgentChatPage() {
|
|||||||
<AgentWelcome agent={agent} agentName={agent_name} />
|
<AgentWelcome agent={agent} agentName={agent_name} />
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
disabled={
|
disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"}
|
||||||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
|
|
||||||
isUploading
|
|
||||||
}
|
|
||||||
onContextChange={(context) => setSettings("context", context)}
|
onContextChange={(context) => setSettings("context", context)}
|
||||||
onSubmit={handleSubmit}
|
onSubmit={handleSubmit}
|
||||||
onStop={handleStop}
|
onStop={handleStop}
|
||||||
|
|||||||
@@ -109,11 +109,7 @@ export default function ChatPage() {
|
|||||||
|
|
||||||
const handleSubmit = useCallback(
|
const handleSubmit = useCallback(
|
||||||
(message: PromptInputMessage) => {
|
(message: PromptInputMessage) => {
|
||||||
const sendPromise = sendMessage(threadId, message);
|
void sendMessage(threadId, message);
|
||||||
if (message.files.length > 0) {
|
|
||||||
return sendPromise;
|
|
||||||
}
|
|
||||||
void sendPromise;
|
|
||||||
},
|
},
|
||||||
[sendMessage, threadId],
|
[sendMessage, threadId],
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { redirect } from "next/navigation";
|
|||||||
import { AuthProvider } from "@/core/auth/AuthProvider";
|
import { AuthProvider } from "@/core/auth/AuthProvider";
|
||||||
import { getServerSideUser } from "@/core/auth/server";
|
import { getServerSideUser } from "@/core/auth/server";
|
||||||
import { assertNever } from "@/core/auth/types";
|
import { assertNever } from "@/core/auth/types";
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
import { WorkspaceContent } from "./workspace-content";
|
import { WorkspaceContent } from "./workspace-content";
|
||||||
|
|
||||||
@@ -44,7 +45,7 @@ export default async function WorkspaceLayout({
|
|||||||
Retry
|
Retry
|
||||||
</Link>
|
</Link>
|
||||||
<Link
|
<Link
|
||||||
href="/api/v1/auth/logout"
|
href={`${getBackendBaseURL()}/api/v1/auth/logout`}
|
||||||
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
|
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
|
||||||
>
|
>
|
||||||
Logout & Reset
|
Logout & Reset
|
||||||
|
|||||||
@@ -499,10 +499,6 @@ export const PromptInput = ({
|
|||||||
// Keep a ref to files for cleanup on unmount (avoids stale closure)
|
// Keep a ref to files for cleanup on unmount (avoids stale closure)
|
||||||
const filesRef = useRef(files);
|
const filesRef = useRef(files);
|
||||||
filesRef.current = files;
|
filesRef.current = files;
|
||||||
const providerTextRef = useRef("");
|
|
||||||
if (usingProvider) {
|
|
||||||
providerTextRef.current = controller.textInput.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
const openFileDialogLocal = useCallback(() => {
|
const openFileDialogLocal = useCallback(() => {
|
||||||
inputRef.current?.click();
|
inputRef.current?.click();
|
||||||
@@ -772,24 +768,6 @@ export const PromptInput = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Convert blob URLs to data URLs asynchronously
|
// Convert blob URLs to data URLs asynchronously
|
||||||
const submittedFileIds = files.map((file) => file.id);
|
|
||||||
const clearSubmittedState = () => {
|
|
||||||
const currentFileIds = new Set(filesRef.current.map((file) => file.id));
|
|
||||||
const submittedFileIdsStillPresent = submittedFileIds.filter((id) =>
|
|
||||||
currentFileIds.has(id),
|
|
||||||
);
|
|
||||||
if (submittedFileIdsStillPresent.length === filesRef.current.length) {
|
|
||||||
clear();
|
|
||||||
} else {
|
|
||||||
for (const id of submittedFileIdsStillPresent) {
|
|
||||||
remove(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (usingProvider && providerTextRef.current === text) {
|
|
||||||
controller.textInput.clear();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Promise.all(
|
Promise.all(
|
||||||
files.map(async ({ id, ...item }) => {
|
files.map(async ({ id, ...item }) => {
|
||||||
if (item.file instanceof File) {
|
if (item.file instanceof File) {
|
||||||
@@ -815,14 +793,20 @@ export const PromptInput = ({
|
|||||||
if (result instanceof Promise) {
|
if (result instanceof Promise) {
|
||||||
result
|
result
|
||||||
.then(() => {
|
.then(() => {
|
||||||
clearSubmittedState();
|
clear();
|
||||||
|
if (usingProvider) {
|
||||||
|
controller.textInput.clear();
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
// Don't clear on error - user may want to retry
|
// Don't clear on error - user may want to retry
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Sync function completed without throwing, clear attachments
|
// Sync function completed without throwing, clear attachments
|
||||||
clearSubmittedState();
|
clear();
|
||||||
|
if (usingProvider) {
|
||||||
|
controller.textInput.clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
// Don't clear on error - user may want to retry
|
// Don't clear on error - user may want to retry
|
||||||
|
|||||||
@@ -110,7 +110,6 @@ export function InputBox({
|
|||||||
threadId,
|
threadId,
|
||||||
initialValue,
|
initialValue,
|
||||||
onContextChange,
|
onContextChange,
|
||||||
onFollowupsVisibilityChange,
|
|
||||||
onSubmit,
|
onSubmit,
|
||||||
onStop,
|
onStop,
|
||||||
...props
|
...props
|
||||||
@@ -143,8 +142,7 @@ export function InputBox({
|
|||||||
reasoning_effort?: "minimal" | "low" | "medium" | "high";
|
reasoning_effort?: "minimal" | "low" | "medium" | "high";
|
||||||
},
|
},
|
||||||
) => void;
|
) => void;
|
||||||
onFollowupsVisibilityChange?: (visible: boolean) => void;
|
onSubmit?: (message: PromptInputMessage) => void;
|
||||||
onSubmit?: (message: PromptInputMessage) => void | Promise<void>;
|
|
||||||
onStop?: () => void;
|
onStop?: () => void;
|
||||||
}) {
|
}) {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
@@ -253,12 +251,12 @@ export function InputBox({
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleSubmit = useCallback(
|
const handleSubmit = useCallback(
|
||||||
(message: PromptInputMessage) => {
|
async (message: PromptInputMessage) => {
|
||||||
if (status === "streaming") {
|
if (status === "streaming") {
|
||||||
onStop?.();
|
onStop?.();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!message.text.trim() && message.files.length === 0) {
|
if (!message.text) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
setFollowups([]);
|
setFollowups([]);
|
||||||
@@ -276,14 +274,11 @@ export function InputBox({
|
|||||||
selectedModel?.supports_thinking ?? false,
|
selectedModel?.supports_thinking ?? false,
|
||||||
),
|
),
|
||||||
});
|
});
|
||||||
return new Promise<void>((resolve, reject) => {
|
setTimeout(() => onSubmit?.(message), 0);
|
||||||
setTimeout(() => {
|
return;
|
||||||
Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject);
|
|
||||||
}, 0);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return onSubmit?.(message);
|
onSubmit?.(message);
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
context,
|
context,
|
||||||
@@ -353,14 +348,6 @@ export function InputBox({
|
|||||||
!followupsHidden &&
|
!followupsHidden &&
|
||||||
(followupsLoading || followups.length > 0);
|
(followupsLoading || followups.length > 0);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
onFollowupsVisibilityChange?.(showFollowups);
|
|
||||||
}, [onFollowupsVisibilityChange, showFollowups]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return () => onFollowupsVisibilityChange?.(false);
|
|
||||||
}, [onFollowupsVisibilityChange]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
messagesRef.current = thread.messages;
|
messagesRef.current = thread.messages;
|
||||||
}, [thread.messages]);
|
}, [thread.messages]);
|
||||||
|
|||||||
@@ -12,11 +12,13 @@ function TokenUsageSummary({
|
|||||||
inputTokens,
|
inputTokens,
|
||||||
outputTokens,
|
outputTokens,
|
||||||
totalTokens,
|
totalTokens,
|
||||||
|
unavailable = false,
|
||||||
}: {
|
}: {
|
||||||
className?: string;
|
className?: string;
|
||||||
inputTokens?: number;
|
inputTokens?: number;
|
||||||
outputTokens?: number;
|
outputTokens?: number;
|
||||||
totalTokens?: number;
|
totalTokens?: number;
|
||||||
|
unavailable?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
|
|
||||||
@@ -31,15 +33,21 @@ function TokenUsageSummary({
|
|||||||
<CoinsIcon className="size-3" />
|
<CoinsIcon className="size-3" />
|
||||||
{t.tokenUsage.label}
|
{t.tokenUsage.label}
|
||||||
</span>
|
</span>
|
||||||
<span>
|
{!unavailable ? (
|
||||||
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
|
<>
|
||||||
</span>
|
<span>
|
||||||
<span>
|
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
|
||||||
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
|
</span>
|
||||||
</span>
|
<span>
|
||||||
<span className="font-medium">
|
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
|
||||||
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
|
</span>
|
||||||
</span>
|
<span className="font-medium">
|
||||||
|
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
|
||||||
|
</span>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<span>{t.tokenUsage.unavailableShort}</span>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -47,7 +55,7 @@ function TokenUsageSummary({
|
|||||||
export function MessageTokenUsageList({
|
export function MessageTokenUsageList({
|
||||||
className,
|
className,
|
||||||
enabled = false,
|
enabled = false,
|
||||||
isLoading: _isLoading = false,
|
isLoading = false,
|
||||||
messages,
|
messages,
|
||||||
}: {
|
}: {
|
||||||
className?: string;
|
className?: string;
|
||||||
@@ -55,7 +63,7 @@ export function MessageTokenUsageList({
|
|||||||
isLoading?: boolean;
|
isLoading?: boolean;
|
||||||
messages: Message[];
|
messages: Message[];
|
||||||
}) {
|
}) {
|
||||||
if (!enabled) {
|
if (!enabled || isLoading) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,16 +75,13 @@ export function MessageTokenUsageList({
|
|||||||
|
|
||||||
const usage = accumulateUsage(aiMessages);
|
const usage = accumulateUsage(aiMessages);
|
||||||
|
|
||||||
if (!usage) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TokenUsageSummary
|
<TokenUsageSummary
|
||||||
className={className}
|
className={className}
|
||||||
inputTokens={usage.inputTokens}
|
inputTokens={usage?.inputTokens}
|
||||||
outputTokens={usage.outputTokens}
|
outputTokens={usage?.outputTokens}
|
||||||
totalTokens={usage.totalTokens}
|
totalTokens={usage?.totalTokens}
|
||||||
|
unavailable={!usage}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import { Input } from "@/components/ui/input";
|
|||||||
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
|
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
|
||||||
import { useAuth } from "@/core/auth/AuthProvider";
|
import { useAuth } from "@/core/auth/AuthProvider";
|
||||||
import { parseAuthError } from "@/core/auth/types";
|
import { parseAuthError } from "@/core/auth/types";
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
import { useI18n } from "@/core/i18n/hooks";
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
|
|
||||||
import { SettingsSection } from "./settings-section";
|
import { SettingsSection } from "./settings-section";
|
||||||
@@ -38,17 +39,20 @@ export function AccountSettingsPage() {
|
|||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const res = await fetch("/api/v1/auth/change-password", {
|
const res = await fetch(
|
||||||
method: "POST",
|
`${getBackendBaseURL()}/api/v1/auth/change-password`,
|
||||||
headers: {
|
{
|
||||||
"Content-Type": "application/json",
|
method: "POST",
|
||||||
...getCsrfHeaders(),
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
...getCsrfHeaders(),
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
current_password: currentPassword,
|
||||||
|
new_password: newPassword,
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
);
|
||||||
current_password: currentPassword,
|
|
||||||
new_password: newPassword,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import React, {
|
|||||||
type ReactNode,
|
type ReactNode,
|
||||||
} from "react";
|
} from "react";
|
||||||
|
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
import { type User, buildLoginUrl } from "./types";
|
import { type User, buildLoginUrl } from "./types";
|
||||||
|
|
||||||
// Re-export for consumers
|
// Re-export for consumers
|
||||||
@@ -56,7 +58,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
|
|||||||
const refreshUser = useCallback(async () => {
|
const refreshUser = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
setIsLoading(true);
|
setIsLoading(true);
|
||||||
const res = await fetch("/api/v1/auth/me", {
|
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/me`, {
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -88,7 +90,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
|
|||||||
setUser(null);
|
setUser(null);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await fetch("/api/v1/auth/logout", {
|
await fetch(`${getBackendBaseURL()}/api/v1/auth/logout`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
|
|||||||
return hasUsage ? cumulative : null;
|
return hasUsage ? cumulative : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function hasNonZeroUsage(
|
function hasNonZeroUsage(
|
||||||
usage: TokenUsage | null | undefined,
|
usage: TokenUsage | null | undefined,
|
||||||
): usage is TokenUsage {
|
): usage is TokenUsage {
|
||||||
return (
|
return (
|
||||||
@@ -75,7 +75,7 @@ export function hasNonZeroUsage(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
|
function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
|
||||||
return {
|
return {
|
||||||
inputTokens: base.inputTokens + delta.inputTokens,
|
inputTokens: base.inputTokens + delta.inputTokens,
|
||||||
outputTokens: base.outputTokens + delta.outputTokens,
|
outputTokens: base.outputTokens + delta.outputTokens,
|
||||||
|
|||||||
@@ -26,13 +26,6 @@ export type MessageGroup =
|
|||||||
| AssistantClarificationGroup
|
| AssistantClarificationGroup
|
||||||
| AssistantSubagentGroup;
|
| AssistantSubagentGroup;
|
||||||
|
|
||||||
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
|
|
||||||
"summary",
|
|
||||||
"loop_warning",
|
|
||||||
"todo_reminder",
|
|
||||||
"todo_completion_reminder",
|
|
||||||
]);
|
|
||||||
|
|
||||||
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||||
if (messages.length === 0) {
|
if (messages.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
@@ -60,6 +53,10 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (message.name === "todo_reminder") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (message.type === "human") {
|
if (message.type === "human") {
|
||||||
groups.push({ id: message.id, type: "human", messages: [message] });
|
groups.push({ id: message.id, type: "human", messages: [message] });
|
||||||
continue;
|
continue;
|
||||||
@@ -371,8 +368,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
|
|||||||
export function isHiddenFromUIMessage(message: Message) {
|
export function isHiddenFromUIMessage(message: Message) {
|
||||||
return (
|
return (
|
||||||
message.additional_kwargs?.hide_from_ui === true ||
|
message.additional_kwargs?.hide_from_ui === true ||
|
||||||
(typeof message.name === "string" &&
|
message.name === "summary" ||
|
||||||
HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
|
message.name === "loop_warning"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,60 +45,15 @@ type SendMessageOptions = {
|
|||||||
additionalKwargs?: Record<string, unknown>;
|
additionalKwargs?: Record<string, unknown>;
|
||||||
};
|
};
|
||||||
|
|
||||||
function isNonEmptyString(value: string | undefined): value is string {
|
function mergeMessages(
|
||||||
return typeof value === "string" && value.length > 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
function messageIdentity(message: Message): string | undefined {
|
|
||||||
if (
|
|
||||||
"tool_call_id" in message &&
|
|
||||||
typeof message.tool_call_id === "string" &&
|
|
||||||
message.tool_call_id.length > 0
|
|
||||||
) {
|
|
||||||
return `tool:${message.tool_call_id}`;
|
|
||||||
}
|
|
||||||
if (typeof message.id === "string" && message.id.length > 0) {
|
|
||||||
return `message:${message.id}`;
|
|
||||||
}
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
|
|
||||||
const lastIndexByIdentity = new Map<string, number>();
|
|
||||||
|
|
||||||
messages.forEach((message, index) => {
|
|
||||||
const identity = messageIdentity(message);
|
|
||||||
if (identity) {
|
|
||||||
lastIndexByIdentity.set(identity, index);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return messages.filter((message, index) => {
|
|
||||||
const identity = messageIdentity(message);
|
|
||||||
return !identity || lastIndexByIdentity.get(identity) === index;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function findLatestUnloadedRunIndex(
|
|
||||||
runs: Run[],
|
|
||||||
loadedRunIds: ReadonlySet<string>,
|
|
||||||
): number {
|
|
||||||
for (let i = runs.length - 1; i >= 0; i--) {
|
|
||||||
const run = runs[i];
|
|
||||||
if (run && !loadedRunIds.has(run.run_id)) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function mergeMessages(
|
|
||||||
historyMessages: Message[],
|
historyMessages: Message[],
|
||||||
threadMessages: Message[],
|
threadMessages: Message[],
|
||||||
optimisticMessages: Message[],
|
optimisticMessages: Message[],
|
||||||
): Message[] {
|
): Message[] {
|
||||||
const threadMessageIds = new Set(
|
const threadMessageIds = new Set(
|
||||||
threadMessages.map(messageIdentity).filter(isNonEmptyString),
|
threadMessages
|
||||||
|
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
|
||||||
|
.filter(Boolean),
|
||||||
);
|
);
|
||||||
|
|
||||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||||
@@ -110,19 +65,28 @@ export function mergeMessages(
|
|||||||
if (!msg) {
|
if (!msg) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const identity = messageIdentity(msg);
|
if (
|
||||||
if (identity && threadMessageIds.has(identity)) {
|
(msg?.id && threadMessageIds.has(msg.id)) ||
|
||||||
|
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
|
||||||
|
) {
|
||||||
cutoff = i;
|
cutoff = i;
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return dedupeMessagesByIdentity([
|
return [
|
||||||
...historyMessages.slice(0, cutoff),
|
...historyMessages.slice(0, cutoff),
|
||||||
...threadMessages,
|
...threadMessages,
|
||||||
...optimisticMessages,
|
...optimisticMessages,
|
||||||
]);
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
function messageIdentity(message: Message): string | undefined {
|
||||||
|
if ("tool_call_id" in message) {
|
||||||
|
return message.tool_call_id;
|
||||||
|
}
|
||||||
|
return message.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getMessagesAfterBaseline(
|
function getMessagesAfterBaseline(
|
||||||
@@ -332,11 +296,7 @@ export function useThreadStream({
|
|||||||
onError(error) {
|
onError(error) {
|
||||||
setOptimisticMessages([]);
|
setOptimisticMessages([]);
|
||||||
toast.error(getStreamErrorMessage(error));
|
toast.error(getStreamErrorMessage(error));
|
||||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||||
messagesRef.current
|
|
||||||
.map(messageIdentity)
|
|
||||||
.filter((id): id is string => Boolean(id)),
|
|
||||||
);
|
|
||||||
if (threadIdRef.current && !isMock) {
|
if (threadIdRef.current && !isMock) {
|
||||||
void queryClient.invalidateQueries({
|
void queryClient.invalidateQueries({
|
||||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||||
@@ -345,11 +305,7 @@ export function useThreadStream({
|
|||||||
},
|
},
|
||||||
onFinish(state) {
|
onFinish(state) {
|
||||||
listeners.current.onFinish?.(state.values);
|
listeners.current.onFinish?.(state.values);
|
||||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||||
messagesRef.current
|
|
||||||
.map(messageIdentity)
|
|
||||||
.filter((id): id is string => Boolean(id)),
|
|
||||||
);
|
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
if (threadIdRef.current && !isMock) {
|
if (threadIdRef.current && !isMock) {
|
||||||
void queryClient.invalidateQueries({
|
void queryClient.invalidateQueries({
|
||||||
@@ -383,11 +339,7 @@ export function useThreadStream({
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
startedRef.current = false;
|
startedRef.current = false;
|
||||||
sendInFlightRef.current = false;
|
sendInFlightRef.current = false;
|
||||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||||
messagesRef.current
|
|
||||||
.map(messageIdentity)
|
|
||||||
.filter((id): id is string => Boolean(id)),
|
|
||||||
);
|
|
||||||
prevHumanMsgCountRef.current =
|
prevHumanMsgCountRef.current =
|
||||||
latestMessageCountsRef.current.humanMessageCount;
|
latestMessageCountsRef.current.humanMessageCount;
|
||||||
}, [threadId]);
|
}, [threadId]);
|
||||||
@@ -663,105 +615,48 @@ export function useThreadHistory(threadId: string) {
|
|||||||
const runsRef = useRef(runs.data ?? []);
|
const runsRef = useRef(runs.data ?? []);
|
||||||
const indexRef = useRef(-1);
|
const indexRef = useRef(-1);
|
||||||
const loadingRef = useRef(false);
|
const loadingRef = useRef(false);
|
||||||
const pendingLoadRef = useRef(false);
|
|
||||||
const loadingRunIdRef = useRef<string | null>(null);
|
|
||||||
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [messages, setMessages] = useState<Message[]>([]);
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
|
|
||||||
|
loadingRef.current = loading;
|
||||||
const loadMessages = useCallback(async () => {
|
const loadMessages = useCallback(async () => {
|
||||||
if (loadingRef.current) {
|
|
||||||
const pendingRunIndex = findLatestUnloadedRunIndex(
|
|
||||||
runsRef.current,
|
|
||||||
loadedRunIdsRef.current,
|
|
||||||
);
|
|
||||||
const pendingRun = runsRef.current[pendingRunIndex];
|
|
||||||
if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) {
|
|
||||||
pendingLoadRef.current = true;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (runsRef.current.length === 0) {
|
if (runsRef.current.length === 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
const run = runsRef.current[indexRef.current];
|
||||||
loadingRef.current = true;
|
if (!run || loadingRef.current) {
|
||||||
setLoading(true);
|
return;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
do {
|
setLoading(true);
|
||||||
pendingLoadRef.current = false;
|
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||||
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||||
const nextRunIndex = findLatestUnloadedRunIndex(
|
{
|
||||||
runsRef.current,
|
method: "GET",
|
||||||
loadedRunIdsRef.current,
|
headers: {
|
||||||
);
|
"Content-Type": "application/json",
|
||||||
indexRef.current = nextRunIndex;
|
|
||||||
|
|
||||||
const run = runsRef.current[nextRunIndex];
|
|
||||||
if (!run) {
|
|
||||||
indexRef.current = -1;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const requestThreadId = threadIdRef.current;
|
|
||||||
loadingRunIdRef.current = run.run_id;
|
|
||||||
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
|
||||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
|
||||||
{
|
|
||||||
method: "GET",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
credentials: "include",
|
|
||||||
},
|
},
|
||||||
).then((res) => {
|
credentials: "include",
|
||||||
return res.json();
|
},
|
||||||
});
|
).then((res) => {
|
||||||
const _messages = result.data
|
return res.json();
|
||||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
});
|
||||||
.map((m) => m.content);
|
const _messages = result.data
|
||||||
if (threadIdRef.current !== requestThreadId) {
|
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||||
return;
|
.map((m) => m.content);
|
||||||
}
|
setMessages((prev) => [..._messages, ...prev]);
|
||||||
setMessages((prev) =>
|
indexRef.current -= 1;
|
||||||
dedupeMessagesByIdentity([..._messages, ...prev]),
|
|
||||||
);
|
|
||||||
loadedRunIdsRef.current.add(run.run_id);
|
|
||||||
indexRef.current = findLatestUnloadedRunIndex(
|
|
||||||
runsRef.current,
|
|
||||||
loadedRunIdsRef.current,
|
|
||||||
);
|
|
||||||
} while (pendingLoadRef.current);
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
} finally {
|
} finally {
|
||||||
loadingRef.current = false;
|
|
||||||
loadingRunIdRef.current = null;
|
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const threadChanged = threadIdRef.current !== threadId;
|
|
||||||
threadIdRef.current = threadId;
|
threadIdRef.current = threadId;
|
||||||
|
|
||||||
if (threadChanged) {
|
|
||||||
runsRef.current = [];
|
|
||||||
indexRef.current = -1;
|
|
||||||
pendingLoadRef.current = false;
|
|
||||||
loadingRunIdRef.current = null;
|
|
||||||
loadedRunIdsRef.current = new Set();
|
|
||||||
loadingRef.current = false;
|
|
||||||
setLoading(false);
|
|
||||||
setMessages([]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (runs.data && runs.data.length > 0) {
|
if (runs.data && runs.data.length > 0) {
|
||||||
runsRef.current = runs.data ?? [];
|
runsRef.current = runs.data ?? [];
|
||||||
indexRef.current = findLatestUnloadedRunIndex(
|
indexRef.current = runs.data.length - 1;
|
||||||
runs.data,
|
|
||||||
loadedRunIdsRef.current,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
loadMessages().catch(() => {
|
loadMessages().catch(() => {
|
||||||
toast.error("Failed to load thread history.");
|
toast.error("Failed to load thread history.");
|
||||||
@@ -770,7 +665,7 @@ export function useThreadHistory(threadId: string) {
|
|||||||
|
|
||||||
const appendMessages = useCallback((_messages: Message[]) => {
|
const appendMessages = useCallback((_messages: Message[]) => {
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
return [...prev, ..._messages];
|
||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
const hasMore = indexRef.current >= 0 || !runs.data;
|
const hasMore = indexRef.current >= 0 || !runs.data;
|
||||||
|
|||||||
@@ -48,66 +48,4 @@ test.describe("Chat workspace", () => {
|
|||||||
timeout: 10_000,
|
timeout: 10_000,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
test("keeps attachments visible while upload submit is pending", async ({
|
|
||||||
page,
|
|
||||||
}) => {
|
|
||||||
let releaseUpload!: () => void;
|
|
||||||
const uploadCanFinish = new Promise<void>((resolve) => {
|
|
||||||
releaseUpload = resolve;
|
|
||||||
});
|
|
||||||
let uploadStarted!: () => void;
|
|
||||||
const uploadStartedPromise = new Promise<void>((resolve) => {
|
|
||||||
uploadStarted = resolve;
|
|
||||||
});
|
|
||||||
|
|
||||||
await page.route("**/api/threads/*/uploads", async (route) => {
|
|
||||||
uploadStarted();
|
|
||||||
await uploadCanFinish;
|
|
||||||
return route.fulfill({
|
|
||||||
status: 200,
|
|
||||||
contentType: "application/json",
|
|
||||||
body: JSON.stringify({
|
|
||||||
success: true,
|
|
||||||
message: "Uploaded",
|
|
||||||
files: [
|
|
||||||
{
|
|
||||||
filename: "report.docx",
|
|
||||||
size: 12,
|
|
||||||
path: "report.docx",
|
|
||||||
virtual_path: "/mnt/user-data/uploads/report.docx",
|
|
||||||
artifact_url: "/api/threads/test/uploads/report.docx",
|
|
||||||
extension: ".docx",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
await page.goto("/workspace/chats/new");
|
|
||||||
|
|
||||||
const textarea = page.getByPlaceholder(/how can i assist you/i);
|
|
||||||
await expect(textarea).toBeVisible({ timeout: 15_000 });
|
|
||||||
const promptForm = page.locator("form").filter({ has: textarea });
|
|
||||||
|
|
||||||
await page.getByLabel("Upload files").setInputFiles({
|
|
||||||
name: "report.docx",
|
|
||||||
mimeType:
|
|
||||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
|
||||||
buffer: Buffer.from("fake docx"),
|
|
||||||
});
|
|
||||||
await expect(promptForm.getByText("report.docx")).toBeVisible();
|
|
||||||
|
|
||||||
await textarea.fill("Summarize this document");
|
|
||||||
await textarea.press("Enter");
|
|
||||||
|
|
||||||
await uploadStartedPromise;
|
|
||||||
await expect(promptForm.getByText("report.docx")).toBeVisible();
|
|
||||||
|
|
||||||
releaseUpload();
|
|
||||||
await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({
|
|
||||||
timeout: 10_000,
|
|
||||||
});
|
|
||||||
await expect(promptForm.getByText("report.docx")).toBeHidden();
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -63,37 +63,3 @@ test("aggregates token usage messages once per assistant turn", () => {
|
|||||||
),
|
),
|
||||||
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
||||||
});
|
});
|
||||||
|
|
||||||
test("hides internal todo reminder messages from message groups", () => {
|
|
||||||
const messages = [
|
|
||||||
{
|
|
||||||
id: "human-1",
|
|
||||||
type: "human",
|
|
||||||
content: "Audit the middleware",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "todo-reminder-1",
|
|
||||||
type: "human",
|
|
||||||
name: "todo_completion_reminder",
|
|
||||||
content: "<system_reminder>finish todos</system_reminder>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "todo-reminder-2",
|
|
||||||
type: "human",
|
|
||||||
name: "todo_reminder",
|
|
||||||
content: "<system_reminder>remember todos</system_reminder>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "ai-1",
|
|
||||||
type: "ai",
|
|
||||||
content: "Done",
|
|
||||||
},
|
|
||||||
] as Message[];
|
|
||||||
|
|
||||||
const groups = getMessageGroups(messages);
|
|
||||||
|
|
||||||
expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]);
|
|
||||||
expect(
|
|
||||||
groups.flatMap((group) => group.messages).map((message) => message.id),
|
|
||||||
).toEqual(["human-1", "ai-1"]);
|
|
||||||
});
|
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
import type { Message } from "@langchain/langgraph-sdk";
|
|
||||||
import { expect, test } from "vitest";
|
|
||||||
|
|
||||||
import { mergeMessages } from "@/core/threads/hooks";
|
|
||||||
|
|
||||||
test("mergeMessages removes duplicate messages already present in history", () => {
|
|
||||||
const human = {
|
|
||||||
id: "human-1",
|
|
||||||
type: "human",
|
|
||||||
content: "Design an agent",
|
|
||||||
} as Message;
|
|
||||||
const ai = {
|
|
||||||
id: "ai-1",
|
|
||||||
type: "ai",
|
|
||||||
content: "Let's design it.",
|
|
||||||
} as Message;
|
|
||||||
|
|
||||||
expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]);
|
|
||||||
});
|
|
||||||
|
|
||||||
test("mergeMessages lets live thread messages replace overlapping history", () => {
|
|
||||||
const oldHuman = {
|
|
||||||
id: "human-1",
|
|
||||||
type: "human",
|
|
||||||
content: "old",
|
|
||||||
} as Message;
|
|
||||||
const liveHuman = {
|
|
||||||
id: "human-1",
|
|
||||||
type: "human",
|
|
||||||
content: "live",
|
|
||||||
} as Message;
|
|
||||||
const oldAi = {
|
|
||||||
id: "ai-1",
|
|
||||||
type: "ai",
|
|
||||||
content: "old",
|
|
||||||
} as Message;
|
|
||||||
const liveAi = {
|
|
||||||
id: "ai-1",
|
|
||||||
type: "ai",
|
|
||||||
content: "live",
|
|
||||||
} as Message;
|
|
||||||
|
|
||||||
expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([
|
|
||||||
liveHuman,
|
|
||||||
liveAi,
|
|
||||||
]);
|
|
||||||
});
|
|
||||||
|
|
||||||
test("mergeMessages deduplicates tool messages by tool_call_id", () => {
|
|
||||||
const oldTool = {
|
|
||||||
id: "tool-message-old",
|
|
||||||
type: "tool",
|
|
||||||
tool_call_id: "call-1",
|
|
||||||
content: "old",
|
|
||||||
} as Message;
|
|
||||||
const liveTool = {
|
|
||||||
id: "tool-message-live",
|
|
||||||
type: "tool",
|
|
||||||
tool_call_id: "call-1",
|
|
||||||
content: "live",
|
|
||||||
} as Message;
|
|
||||||
|
|
||||||
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
|
|
||||||
});
|
|
||||||
@@ -72,7 +72,6 @@ def find_config_file() -> Path | None:
|
|||||||
|
|
||||||
|
|
||||||
_SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$")
|
_SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$")
|
||||||
_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$")
|
|
||||||
_KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$")
|
_KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$")
|
||||||
|
|
||||||
|
|
||||||
@@ -142,84 +141,6 @@ def section_value(lines: list[str], section: str, key: str) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None:
|
|
||||||
"""Return the value of a nested YAML key like ``channels.discord.enabled``.
|
|
||||||
|
|
||||||
Handles two levels of nesting:
|
|
||||||
channels:
|
|
||||||
discord:
|
|
||||||
enabled: true
|
|
||||||
"""
|
|
||||||
parts = section_path.split(".")
|
|
||||||
if len(parts) != 2:
|
|
||||||
return None
|
|
||||||
parent_section, child_section = parts
|
|
||||||
|
|
||||||
inside_parent = False
|
|
||||||
inside_child = False
|
|
||||||
parent_indent: int | None = None
|
|
||||||
child_indent: int | None = None
|
|
||||||
|
|
||||||
for raw in lines:
|
|
||||||
line = _strip_comment(raw)
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
stripped = line.lstrip()
|
|
||||||
indent = len(line) - len(stripped)
|
|
||||||
|
|
||||||
# Top-level section match
|
|
||||||
sect_match = _SECTION_RE.match(line)
|
|
||||||
if sect_match:
|
|
||||||
if indent == 0:
|
|
||||||
inside_parent = sect_match.group(1) == parent_section
|
|
||||||
inside_child = False
|
|
||||||
parent_indent = None
|
|
||||||
child_indent = None
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not inside_parent:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Track parent indent from first child
|
|
||||||
if parent_indent is None and indent > 0:
|
|
||||||
parent_indent = indent
|
|
||||||
|
|
||||||
# If indent goes back to 0, we left the parent section
|
|
||||||
if indent == 0:
|
|
||||||
inside_parent = False
|
|
||||||
inside_child = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if we're at the parent's child level (subsection)
|
|
||||||
if parent_indent is not None and indent == parent_indent:
|
|
||||||
# This could be a subsection or a direct key of parent
|
|
||||||
sub_match = _INDENTED_SECTION_RE.match(line)
|
|
||||||
if sub_match and sub_match.group(1) == child_section:
|
|
||||||
inside_child = True
|
|
||||||
child_indent = None
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
inside_child = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not inside_child:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# We're inside the subsection — track child indent
|
|
||||||
if child_indent is None and indent > (parent_indent or 0):
|
|
||||||
child_indent = indent
|
|
||||||
|
|
||||||
if child_indent is not None and indent != child_indent:
|
|
||||||
continue
|
|
||||||
|
|
||||||
key_match = _KEY_RE.match(line)
|
|
||||||
if key_match and key_match.group(1) == key:
|
|
||||||
return _unquote(key_match.group(2).strip())
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def detect_from_config(path: Path) -> list[str]:
|
def detect_from_config(path: Path) -> list[str]:
|
||||||
try:
|
try:
|
||||||
text = path.read_text(encoding="utf-8", errors="replace")
|
text = path.read_text(encoding="utf-8", errors="replace")
|
||||||
@@ -231,8 +152,6 @@ def detect_from_config(path: Path) -> list[str]:
|
|||||||
extras.add("postgres")
|
extras.add("postgres")
|
||||||
if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres":
|
if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres":
|
||||||
extras.add("postgres")
|
extras.add("postgres")
|
||||||
if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true":
|
|
||||||
extras.add("discord")
|
|
||||||
return sorted(extras)
|
return sorted(extras)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user