diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 51b834b4f..ceebba99c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -185,9 +185,9 @@ If you need to start services individually: 1. **Start backend service**: ```bash - # Terminal 1: Start Gateway API and embedded LangGraph-compatible runtime (port 8001) + # Terminal 1: Start Gateway API + embedded agent runtime (port 8001) cd backend - make gateway + make dev # Terminal 2: Start Frontend (port 3000) cd frontend @@ -207,7 +207,7 @@ If you need to start services individually: The nginx configuration provides: - Unified entry point on port 2026 -- Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx +- Rewrites `/api/langgraph/*` to Gateway's LangGraph-compatible API (8001) - Routes other `/api/*` endpoints to Gateway API (8001) - Routes non-API requests to Frontend (3000) - Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist @@ -231,7 +231,7 @@ deer-flow/ ├── backend/ # Backend application │ ├── src/ │ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001) -│ │ ├── agents/ # LangGraph agent definitions +│ │ ├── agents/ # LangGraph agent runtime used by Gateway │ │ ├── mcp/ # Model Context Protocol integration │ │ ├── skills/ # Skills system │ │ └── sandbox/ # Sandbox execution diff --git a/README.md b/README.md index 9ff1d501b..8248e8fe4 100644 --- a/README.md +++ b/README.md @@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl Complex tasks rarely fit in a single pass. DeerFlow decomposes them. -The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. +The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step. This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands. diff --git a/README_fr.md b/README_fr.md index 3b8dc3d41..f144d8bc5 100644 --- a/README_fr.md +++ b/README_fr.md @@ -228,7 +228,7 @@ make down # Stop and remove containers ``` > [!NOTE] -> Le serveur d'agents LangGraph fonctionne actuellement via `langgraph dev` (le serveur CLI open source). +> Le runtime d'agent s'exécute actuellement dans la Gateway. nginx réécrit `/api/langgraph/*` vers l'API compatible LangGraph servie par la Gateway. Accès : http://localhost:2026 @@ -296,8 +296,8 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca ```yaml channels: - # LangGraph Server URL (default: http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL (default: http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/README_ja.md b/README_ja.md index d2ba81750..2bf060799 100644 --- a/README_ja.md +++ b/README_ja.md @@ -181,7 +181,7 @@ make down # コンテナを停止して削除 ``` > [!NOTE] -> LangGraphエージェントサーバーは現在`langgraph dev`(オープンソースCLIサーバー)経由で実行されます。 +> Agentランタイムは現在Gateway内で実行されます。`/api/langgraph/*`はnginxによってGatewayのLangGraph-compatible APIへ書き換えられます。 アクセス: http://localhost:2026 @@ -249,8 +249,8 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート ```yaml channels: - # LangGraphサーバーURL(デフォルト: http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL(デフォルト: http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL(デフォルト: http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/README_zh.md b/README_zh.md index d5317082e..ec67b95d6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -184,7 +184,7 @@ make down # 停止并移除容器 ``` > [!NOTE] -> 当前 LangGraph agent server 通过开源 CLI 服务 `langgraph dev` 运行。 +> 当前 Agent 运行时嵌入在 Gateway 中运行,`/api/langgraph/*` 会由 nginx 重写到 Gateway 的 LangGraph-compatible API。 访问地址:http://localhost:2026 @@ -254,8 +254,8 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应 ```yaml channels: - # LangGraph Server URL(默认:http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL(默认:http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL(默认:http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 67ee9cc7e..5e0aebfdb 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) -11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional) +11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) diff --git a/backend/CONTRIBUTING.md b/backend/CONTRIBUTING.md index 322710e74..f7ef58447 100644 --- a/backend/CONTRIBUTING.md +++ b/backend/CONTRIBUTING.md @@ -56,11 +56,8 @@ export OPENAI_API_KEY="your-api-key" ### Run the Development Server ```bash -# Terminal 1: LangGraph server +# Gateway API + embedded agent runtime make dev - -# Terminal 2: Gateway API -make gateway ``` ## Project Structure diff --git a/backend/README.md b/backend/README.md index 9b4d26fb1..8c61e2db2 100644 --- a/backend/README.md +++ b/backend/README.md @@ -11,34 +11,26 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent │ Nginx (Port 2026) │ │ Unified reverse proxy │ └───────┬──────────────────┬───────────┘ - │ │ - /api/langgraph/* │ │ /api/* (other) - ▼ ▼ - ┌──────────────────────────────────────────────┐ - │ Gateway API (8001) │ - │ FastAPI REST + LangGraph-compatible runtime │ - │ │ - │ Models, MCP, Skills, Memory, Uploads, │ - │ Artifacts, Threads, Runs, Streaming │ - │ │ - │ ┌────────────────┐ │ - │ │ Lead Agent │ │ - │ │ ┌──────────┐ │ │ - │ │ │Middleware│ │ │ - │ │ │ Chain │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │ Tools │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │Subagents │ │ │ - │ │ └──────────┘ │ │ - │ └────────────────┘ │ - └──────────────────────────────────────────────┘ + │ + /api/langgraph/* │ /api/* (other) + rewritten to /api/* │ + ▼ + ┌────────────────────────────────────────┐ + │ Gateway API (8001) │ + │ FastAPI REST + agent runtime │ + │ │ + │ Models, MCP, Skills, Memory, Uploads, │ + │ Artifacts, Threads, Runs, Streaming │ + │ │ + │ ┌────────────────────────────────────┐ │ + │ │ Lead Agent │ │ + │ │ Middleware Chain, Tools, Subagents │ │ + │ └────────────────────────────────────┘ │ + └────────────────────────────────────────┘ ``` **Request Routing** (via Nginx): -- `/api/langgraph/*` → Gateway API - LangGraph-compatible agent interactions, threads, runs, and streaming translated to native `/api/*` routers +- `/api/langgraph/*` → Gateway LangGraph-compatible API - agent interactions, threads, streaming - `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup - `/` (non-API) → Frontend - Next.js web interface @@ -196,7 +188,7 @@ export OPENAI_API_KEY="your-api-key-here" **Full Application** (from project root): ```bash -make dev # Starts LangGraph + Gateway + Frontend + Nginx +make dev # Starts Gateway + Frontend + Nginx ``` Access at: http://localhost:2026 @@ -204,14 +196,11 @@ Access at: http://localhost:2026 **Backend Only** (from backend directory): ```bash -# Terminal 1: LangGraph server +# Gateway API + embedded agent runtime make dev - -# Terminal 2: Gateway API -make gateway ``` -Direct access: LangGraph at http://localhost:2024, Gateway at http://localhost:8001 +Direct access: Gateway at http://localhost:8001 --- @@ -247,12 +236,16 @@ backend/ │ └── utils/ # Utilities ├── docs/ # Documentation ├── tests/ # Test suite -├── langgraph.json # LangGraph server configuration +├── langgraph.json # LangGraph graph registry for tooling/Studio compatibility ├── pyproject.toml # Python dependencies ├── Makefile # Development commands └── Dockerfile # Container build ``` +`langgraph.json` is not the default service entrypoint. The scripts and Docker +deployments run the Gateway embedded runtime; the file is kept for LangGraph +tooling, Studio, or direct LangGraph Server compatibility. + --- ## Configuration @@ -365,8 +358,8 @@ If a provider is explicitly enabled but required credentials are missing, or the ```bash make install # Install dependencies -make dev # Run LangGraph server (port 2024) -make gateway # Run Gateway API (port 8001) +make dev # Run Gateway API + embedded agent runtime (port 8001) +make gateway # Run Gateway API without reload (port 8001) make lint # Run linter (ruff) make format # Format code (ruff) ``` diff --git a/backend/app/channels/discord.py b/backend/app/channels/discord.py index 2d2889126..3b113c28d 100644 --- a/backend/app/channels/discord.py +++ b/backend/app/channels/discord.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio +import json import logging import threading +from pathlib import Path from typing import Any from app.channels.base import Channel @@ -21,6 +23,12 @@ class DiscordChannel(Channel): Configuration keys (in ``config.yaml`` under ``channels.discord``): - ``bot_token``: Discord Bot token. - ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all. + - ``mention_only``: (optional) If true, only respond when the bot is mentioned. + - ``allowed_channels``: (optional) List of channel IDs where messages are always accepted + (even when mention_only is true). Use for channels where you want the bot to respond + without mentions. Empty = mention_only applies everywhere. + - ``thread_mode``: (optional) If true, group a channel conversation into a thread. + Default: same as ``mention_only``. """ def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None: @@ -32,6 +40,29 @@ class DiscordChannel(Channel): self._allowed_guilds.add(int(guild_id)) except (TypeError, ValueError): continue + self._mention_only: bool = bool(config.get("mention_only", False)) + self._thread_mode: bool = config.get("thread_mode", self._mention_only) + self._allowed_channels: set[str] = set() + for channel_id in config.get("allowed_channels", []): + self._allowed_channels.add(str(channel_id)) + + # Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON). + # Uses a dedicated JSON file separate from ChannelStore, which maps IM + # conversations to DeerFlow thread IDs — a different concern. + self._active_threads: dict[str, str] = {} + # Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()). + self._active_thread_ids: set[str] = set() + # Lock protecting _active_threads and the JSON file from concurrent access. + # _run_client (Discord loop thread) and the main thread both read/write. + self._thread_store_lock = threading.Lock() + store = config.get("channel_store") + if store is not None: + self._thread_store_path = store._path.parent / "discord_threads.json" + else: + self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json" + + # Typing indicator management + self._typing_tasks: dict[str, asyncio.Task] = {} self._client = None self._thread: threading.Thread | None = None @@ -75,12 +106,56 @@ class DiscordChannel(Channel): self._thread = threading.Thread(target=self._run_client, daemon=True) self._thread.start() + self._load_active_threads() logger.info("Discord channel started") + def _load_active_threads(self) -> None: + """Restore Discord thread mappings from the dedicated JSON file on startup.""" + with self._thread_store_lock: + try: + if not self._thread_store_path.exists(): + logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path) + return + data = json.loads(self._thread_store_path.read_text()) + self._active_threads.clear() + self._active_thread_ids.clear() + for channel_id, thread_id in data.items(): + self._active_threads[channel_id] = thread_id + self._active_thread_ids.add(thread_id) + if self._active_threads: + logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path) + except Exception: + logger.exception("[Discord] failed to load thread mappings") + + def _save_thread(self, channel_id: str, thread_id: str) -> None: + """Persist a Discord thread mapping to the dedicated JSON file.""" + with self._thread_store_lock: + try: + data: dict[str, str] = {} + if self._thread_store_path.exists(): + data = json.loads(self._thread_store_path.read_text()) + old_id = data.get(channel_id) + data[channel_id] = thread_id + # Update reverse-lookup set + if old_id: + self._active_thread_ids.discard(old_id) + self._active_thread_ids.add(thread_id) + self._thread_store_path.parent.mkdir(parents=True, exist_ok=True) + self._thread_store_path.write_text(json.dumps(data, indent=2)) + except Exception: + logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id) + async def stop(self) -> None: self._running = False self.bus.unsubscribe_outbound(self._on_outbound) + # Cancel all active typing indicator tasks + for target_id, task in list(self._typing_tasks.items()): + if not task.done(): + task.cancel() + logger.debug("[Discord] cancelled typing task for target %s", target_id) + self._typing_tasks.clear() + if self._client and self._discord_loop and self._discord_loop.is_running(): close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop) try: @@ -100,6 +175,10 @@ class DiscordChannel(Channel): logger.info("Discord channel stopped") async def send(self, msg: OutboundMessage) -> None: + # Stop typing indicator once we're sending the response + stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop) + await asyncio.wrap_future(stop_future) + target = await self._resolve_target(msg) if target is None: logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts) @@ -111,6 +190,9 @@ class DiscordChannel(Channel): await asyncio.wrap_future(send_future) async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop) + await asyncio.wrap_future(stop_future) + target = await self._resolve_target(msg) if target is None: logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts) @@ -130,6 +212,41 @@ class DiscordChannel(Channel): logger.exception("[Discord] failed to upload file: %s", attachment.filename) return False + async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None: + """Starts a loop to send periodic typing indicators.""" + target_id = thread_ts or chat_id + if target_id in self._typing_tasks: + return # Already typing for this target + + async def _typing_loop(): + try: + while True: + try: + await channel.trigger_typing() + except Exception: + pass + await asyncio.sleep(10) + except asyncio.CancelledError: + pass + + task = asyncio.create_task(_typing_loop()) + self._typing_tasks[target_id] = task + + async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None: + """Stops the typing loop for a specific target.""" + target_id = thread_ts or chat_id + task = self._typing_tasks.pop(target_id, None) + if task and not task.done(): + task.cancel() + logger.debug("[Discord] stopped typing indicator for target %s", target_id) + + async def _add_reaction(self, message) -> None: + """Add a checkmark reaction to acknowledge the message was received.""" + try: + await message.add_reaction("✅") + except Exception: + logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True) + async def _on_message(self, message) -> None: if not self._running or not self._client: return @@ -152,15 +269,143 @@ class DiscordChannel(Channel): if self._discord_module is None: return - if isinstance(message.channel, self._discord_module.Thread): - chat_id = str(message.channel.parent_id or message.channel.id) - thread_id = str(message.channel.id) + # Determine whether the bot is mentioned in this message + user = self._client.user if self._client else None + if user: + bot_mention = user.mention # <@ID> + alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant) + standard_mention = f"<@{user.id}>" else: - thread = await self._create_thread(message) - if thread is None: + bot_mention = None + alt_mention = None + standard_mention = "" + has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content) + + # Strip mention from text for processing + if has_mention: + text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip() + # Don't return early if text is empty — still process the mention (e.g., create thread) + + # --- Determine thread/channel routing and typing target --- + thread_id = None + chat_id = None + typing_target = None # The Discord object to type into + + if isinstance(message.channel, self._discord_module.Thread): + # --- Message already inside a thread --- + thread_obj = message.channel + thread_id = str(thread_obj.id) + chat_id = str(thread_obj.parent_id or thread_obj.id) + typing_target = thread_obj + + # If this is a known active thread, process normally + if thread_id in self._active_thread_ids: + msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT + inbound = self._make_inbound( + chat_id=chat_id, + user_id=str(message.author.id), + text=text, + msg_type=msg_type, + thread_ts=thread_id, + metadata={ + "guild_id": str(guild.id) if guild else None, + "channel_id": str(message.channel.id), + "message_id": str(message.id), + }, + ) + inbound.topic_id = thread_id + self._publish(inbound) + # Start typing indicator in the thread + if typing_target: + asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id)) + asyncio.create_task(self._add_reaction(message)) return - chat_id = str(message.channel.id) - thread_id = str(thread.id) + + # Thread not tracked (orphaned) — create new thread and handle below + logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id) + thread_id = None + typing_target = None + + # At this point we're guaranteed to be in a channel, not a thread + # (the Thread case is handled above). Apply mention_only for all + # non-thread messages — no special case needed. + channel_id = str(message.channel.id) + + # Check if there's an active thread for this channel + if channel_id in self._active_threads: + # respect mention_only: if enabled, only process messages that mention the bot + # (unless the channel is in allowed_channels) + # Messages within a thread are always allowed through (continuation). + # At this code point we know the message is in a channel, not a thread + # (Thread case handled above), so always apply the check. + if self._mention_only and not has_mention and channel_id not in self._allowed_channels: + logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id) + return + # mention_only + fresh @ → create new thread instead of routing to existing one + if self._mention_only and has_mention: + thread_obj = await self._create_thread(message) + if thread_obj is not None: + target_thread_id = str(thread_obj.id) + self._active_threads[channel_id] = target_thread_id + self._save_thread(channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = thread_obj + logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id) + else: + logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id) + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel + else: + # Existing session → route to the existing thread + target_thread_id = self._active_threads[channel_id] + logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = await self._get_channel_or_thread(target_thread_id) + elif self._mention_only and not has_mention and channel_id not in self._allowed_channels: + # Not mentioned and not in an allowed channel → skip + logger.debug("[Discord] skipping message without mention in channel %s", channel_id) + return + elif self._mention_only and has_mention: + # First mention in this channel → create thread + thread_obj = await self._create_thread(message) + if thread_obj is not None: + target_thread_id = str(thread_obj.id) + self._active_threads[channel_id] = target_thread_id + self._save_thread(channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = thread_obj # Type into the new thread + logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name) + else: + # Fallback: thread creation failed (disabled/permissions), reply in channel + logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id) + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel # Type into the channel + elif self._thread_mode: + # thread_mode but mention_only is False → create thread anyway for conversation grouping + thread_obj = await self._create_thread(message) + if thread_obj is None: + # Thread creation failed (disabled/permissions), fall back to channel replies + logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id) + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel # Type into the channel + else: + target_thread_id = str(thread_obj.id) + self._active_threads[channel_id] = target_thread_id + self._save_thread(channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = thread_obj # Type into the new thread + else: + # No threading — reply directly in channel + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel # Type into the channel msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT inbound = self._make_inbound( @@ -177,6 +422,15 @@ class DiscordChannel(Channel): ) inbound.topic_id = thread_id + # Start typing indicator in the correct target (thread or channel) + if typing_target: + asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id)) + + self._publish(inbound) + asyncio.create_task(self._add_reaction(message)) + + def _publish(self, inbound) -> None: + """Publish an inbound message to the main event loop.""" if self._main_loop and self._main_loop.is_running(): future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop) future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None) @@ -198,14 +452,40 @@ class DiscordChannel(Channel): async def _create_thread(self, message): try: + if self._discord_module is None: + return None + + # Only TextChannel (type 0) and NewsChannel (type 10) support threads + channel_type = message.channel.type + if channel_type not in ( + self._discord_module.ChannelType.text, + self._discord_module.ChannelType.news, + ): + logger.info( + "[Discord] channel type %s (%s) does not support threads", + channel_type.value, + channel_type.name, + ) + return None + thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100] return await message.create_thread(name=thread_name) + except self._discord_module.errors.HTTPException as exc: + if exc.code == 50024: + logger.info( + "[Discord] cannot create thread in channel %s (error code 50024): %s", + message.channel.id, + channel_type.name if (channel_type := message.channel.type) else "unknown", + ) + else: + logger.exception( + "[Discord] failed to create thread for message=%s (HTTPException %s)", + message.id, + exc.code, + ) + return None except Exception: logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id) - try: - await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.") - except Exception: - pass return None async def _resolve_target(self, msg: OutboundMessage): diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index e59dbcf2c..aa52fa298 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -787,13 +787,22 @@ class ChannelManager: return logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) - result = await client.runs.wait( - thread_id, - assistant_id, - input={"messages": [{"role": "human", "content": msg.text}]}, - config=run_config, - context=run_context, - ) + try: + result = await client.runs.wait( + thread_id, + assistant_id, + input={"messages": [{"role": "human", "content": msg.text}]}, + config=run_config, + context=run_context, + multitask_strategy="reject", + ) + except Exception as exc: + if _is_thread_busy_error(exc): + logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id) + await self._send_error(msg, THREAD_BUSY_MESSAGE) + return + else: + raise response_text = _extract_response_text(result) artifacts = _extract_artifacts(result) diff --git a/backend/app/channels/service.py b/backend/app/channels/service.py index 4a3df9060..1b9526297 100644 --- a/backend/app/channels/service.py +++ b/backend/app/channels/service.py @@ -167,6 +167,8 @@ class ChannelService: return False try: + config = dict(config) + config["channel_store"] = self.store channel = channel_cls(bus=self.bus, config=config) self._channels[name] = channel await channel.start() diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 8848f473e..2c13f571c 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -62,7 +62,7 @@ async def _ensure_admin_user(app: FastAPI) -> None: Subsequent boots (admin already exists): - Runs the one-time "no-auth → with-auth" orphan thread migration for - existing LangGraph thread metadata that has no owner_id. + existing LangGraph thread metadata that has no user_id. No SQL persistence migration is needed: the four user_id columns (threads_meta, runs, run_events, feedback) only come into existence @@ -177,7 +177,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async with langgraph_runtime(app): logger.info("LangGraph runtime initialised") - # Ensure admin user exists (auto-create on first boot) + # Check admin bootstrap state and migrate orphan threads after admin exists. # Must run AFTER langgraph_runtime so app.state.store is available for thread migration await _ensure_admin_user(app) diff --git a/backend/app/gateway/auth/config.py b/backend/app/gateway/auth/config.py index 4734f0897..27c1984f1 100644 --- a/backend/app/gateway/auth/config.py +++ b/backend/app/gateway/auth/config.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, Field logger = logging.getLogger(__name__) +_SECRET_FILE = ".jwt_secret" + class AuthConfig(BaseModel): """JWT and auth-related configuration. Parsed once at startup. @@ -30,6 +32,32 @@ class AuthConfig(BaseModel): _auth_config: AuthConfig | None = None +def _load_or_create_secret() -> str: + """Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one.""" + from deerflow.config.paths import get_paths + + paths = get_paths() + secret_file = paths.base_dir / _SECRET_FILE + + try: + if secret_file.exists(): + secret = secret_file.read_text(encoding="utf-8").strip() + if secret: + return secret + except OSError as exc: + raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc + + secret = secrets.token_urlsafe(32) + try: + secret_file.parent.mkdir(parents=True, exist_ok=True) + fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(fd, "w", encoding="utf-8") as fh: + fh.write(secret) + except OSError as exc: + raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc + return secret + + def get_auth_config() -> AuthConfig: """Get the global AuthConfig instance. Parses from env on first call.""" global _auth_config @@ -39,11 +67,11 @@ def get_auth_config() -> AuthConfig: load_dotenv() jwt_secret = os.environ.get("AUTH_JWT_SECRET") if not jwt_secret: - jwt_secret = secrets.token_urlsafe(32) + jwt_secret = _load_or_create_secret() os.environ["AUTH_JWT_SECRET"] = jwt_secret logger.warning( - "⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. " - "Sessions will be invalidated on restart. " + "⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret " + "persisted to .jwt_secret. Sessions will survive restarts. " "For production, add AUTH_JWT_SECRET to your .env file: " 'python -c "import secrets; print(secrets.token_urlsafe(32))"' ) diff --git a/backend/app/gateway/auth/models.py b/backend/app/gateway/auth/models.py index d8f9b954a..25c6476fe 100644 --- a/backend/app/gateway/auth/models.py +++ b/backend/app/gateway/auth/models.py @@ -28,7 +28,7 @@ class User(BaseModel): oauth_id: str | None = Field(None, description="User ID from OAuth provider") # Auth lifecycle - needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes") + needs_setup: bool = Field(default=False, description="True when a reset account must complete setup") token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs") diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py index 38e020150..202fab2d5 100644 --- a/backend/app/gateway/langgraph_auth.py +++ b/backend/app/gateway/langgraph_auth.py @@ -1,8 +1,12 @@ -"""LangGraph Server auth handler — shares JWT logic with Gateway. +"""LangGraph compatibility auth handler — shares JWT logic with Gateway. -Loaded by LangGraph Server via langgraph.json ``auth.path``. -Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway, -so both modes validate tokens with the same secret and rules. +The default DeerFlow runtime is embedded in the FastAPI Gateway; scripts and +Docker deployments do not load this module. It is retained for LangGraph +tooling, Studio, or direct LangGraph Server compatibility through +``langgraph.json``'s ``auth.path``. + +When that compatibility path is used, this module reuses the same JWT and CSRF +rules as Gateway so both modes validate sessions consistently. Two layers: 1. @auth.authenticate — validates JWT cookie, extracts user_id, diff --git a/backend/app/gateway/routers/artifacts.py b/backend/app/gateway/routers/artifacts.py index 78ea5fa00..a2cc5b02b 100644 --- a/backend/app/gateway/routers/artifacts.py +++ b/backend/app/gateway/routers/artifacts.py @@ -20,6 +20,9 @@ ACTIVE_CONTENT_MIME_TYPES = { "image/svg+xml", } +MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024 +_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024 + def _build_content_disposition(disposition_type: str, filename: str) -> str: """Build an RFC 5987 encoded Content-Disposition header value.""" @@ -44,6 +47,22 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool: return False +def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes: + """Read a .skill archive member while enforcing an uncompressed size cap.""" + if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES: + raise HTTPException(status_code=413, detail="Skill archive member is too large to preview") + + chunks: list[bytes] = [] + total_read = 0 + with zip_ref.open(info, "r") as src: + while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE): + total_read += len(chunk) + if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES: + raise HTTPException(status_code=413, detail="Skill archive member is too large to preview") + chunks.append(chunk) + return b"".join(chunks) + + def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None: """Extract a file from a .skill ZIP archive. @@ -60,16 +79,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte try: with zipfile.ZipFile(zip_path, "r") as zip_ref: # List all files in the archive - namelist = zip_ref.namelist() + infos_by_name = {info.filename: info for info in zip_ref.infolist()} # Try direct path first - if internal_path in namelist: - return zip_ref.read(internal_path) + if internal_path in infos_by_name: + return _read_skill_archive_member(zip_ref, infos_by_name[internal_path]) # Try with any top-level directory prefix (e.g., "skill-name/SKILL.md") - for name in namelist: + for name, info in infos_by_name.items(): if name.endswith("/" + internal_path) or name == internal_path: - return zip_ref.read(name) + return _read_skill_archive_member(zip_ref, info) # Not found return None diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py index 3a41e13eb..6192456fb 100644 --- a/backend/app/gateway/routers/auth.py +++ b/backend/app/gateway/routers/auth.py @@ -305,7 +305,7 @@ async def login_local( async def register(request: Request, response: Response, body: RegisterRequest): """Register a new user account (always 'user' role). - Admin is auto-created on first boot. This endpoint creates regular users. + The first admin is created explicitly through /initialize. This endpoint creates regular users. Auto-login by setting the session cookie. """ try: diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index cb048152e..e6f4fa2ae 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -90,6 +90,28 @@ class ThreadSearchRequest(BaseModel): offset: int = Field(default=0, ge=0, description="Pagination offset") status: str | None = Field(default=None, description="Filter by thread status") + @field_validator("metadata") + @classmethod + def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]: + """Reject filter entries the SQL backend cannot compile. + + Enforces consistent behaviour across SQL and memory backends. + See ``deerflow.persistence.json_compat`` for the shared validators. + """ + if not v: + return v + from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value + + bad_entries: list[str] = [] + for key, value in v.items(): + if not validate_metadata_filter_key(key): + bad_entries.append(f"{key!r} (unsafe key)") + elif not validate_metadata_filter_value(value): + bad_entries.append(f"{key!r} (unsupported value type {type(value).__name__})") + if bad_entries: + raise ValueError(f"Invalid metadata filter entries: {', '.join(bad_entries)}") + return v + class ThreadStateResponse(BaseModel): """Response model for thread state.""" @@ -294,14 +316,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th (SQL-backed for sqlite/postgres, Store-backed for memory mode). """ from app.gateway.deps import get_thread_store + from deerflow.persistence.thread_meta import InvalidMetadataFilterError repo = get_thread_store(request) - rows = await repo.search( - metadata=body.metadata or None, - status=body.status, - limit=body.limit, - offset=body.offset, - ) + try: + rows = await repo.search( + metadata=body.metadata or None, + status=body.status, + limit=body.limit, + offset=body.offset, + ) + except InvalidMetadataFilterError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc return [ ThreadResponse( thread_id=r["thread_id"], diff --git a/backend/docs/API.md b/backend/docs/API.md index 293c1ebd1..762a135c4 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -535,14 +535,28 @@ All APIs return errors in a consistent format: ## Authentication -Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials. +DeerFlow enforces authentication for all non-public HTTP routes. Public routes are limited to health/docs metadata and these public auth endpoints: -Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers. +- `POST /api/v1/auth/initialize` creates the first admin account when no admin exists. +- `POST /api/v1/auth/login/local` logs in with email/password and sets an HttpOnly `access_token` cookie. +- `POST /api/v1/auth/register` creates a regular `user` account and sets the session cookie. +- `POST /api/v1/auth/logout` clears the session cookie. +- `GET /api/v1/auth/setup-status` reports whether the first admin still needs to be created. -For production deployments, it is recommended to: -1. Use Nginx for basic auth or OAuth integration -2. Deploy behind a VPN or private network -3. Implement custom authentication middleware +The authenticated auth endpoints are: + +- `GET /api/v1/auth/me` returns the current user. +- `POST /api/v1/auth/change-password` changes password, optionally changes email during setup, increments `token_version`, and reissues the cookie. + +Protected state-changing requests also require the CSRF double-submit token: send the `csrf_token` cookie value as the `X-CSRF-Token` header. Login/register/initialize/logout are bootstrap auth endpoints: they are exempt from the double-submit token but still reject hostile browser `Origin` headers. + +User isolation is enforced from the authenticated user context: + +- Thread metadata is scoped by `threads_meta.user_id`; search/read/write/delete APIs only expose the current user's threads. +- Thread files live under `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/` and are exposed inside the sandbox as `/mnt/user-data/`. +- Memory and custom agents are stored under `{base_dir}/users/{user_id}/...`. + +Note: MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers; that is separate from DeerFlow API authentication. --- @@ -561,12 +575,13 @@ location /api/ { --- -## WebSocket Support +## Streaming Support -The LangGraph server supports WebSocket connections for real-time streaming. Connect to: +Gateway's LangGraph-compatible API streams run events with Server-Sent Events (SSE): -``` -ws://localhost:2026/api/langgraph/threads/{thread_id}/runs/stream +```http +POST /api/langgraph/threads/{thread_id}/runs/stream +Accept: text/event-stream ``` --- @@ -602,13 +617,21 @@ const response = await fetch('/api/models'); const data = await response.json(); console.log(data.models); -// Using EventSource for streaming -const eventSource = new EventSource( - `/api/langgraph/threads/${threadId}/runs/stream` -); -eventSource.onmessage = (event) => { - console.log(JSON.parse(event.data)); -}; +// Create a run and stream SSE events +const streamResponse = await fetch(`/api/langgraph/threads/${threadId}/runs/stream`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ + input: { messages: [{ role: "user", content: "Hello" }] }, + stream_mode: ["values", "messages-tuple", "custom"], + }), +}); + +const reader = streamResponse.body?.getReader(); +// Decode and parse SSE frames from reader in your client code. ``` ### cURL Examples diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index e6fdbe217..47859cc9c 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -20,24 +20,22 @@ This document provides a comprehensive overview of the DeerFlow backend architec │ └────────────────────────────────────────────────────────────────────┘ │ └─────────────────────────────────┬────────────────────────────────────────┘ │ - ┌───────────────────────┼───────────────────────┐ - │ │ │ - ▼ ▼ ▼ -┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ -│ Embedded Runtime │ │ Gateway API │ │ Frontend │ -│ (inside Gateway) │ │ (Port 8001) │ │ (Port 3000) │ -│ │ │ │ │ │ -│ - Agent Runtime │ │ - Models API │ │ - Next.js App │ -│ - Thread Mgmt │ │ - MCP Config │ │ - React UI │ -│ - SSE Streaming │ │ - Skills Mgmt │ │ - Chat Interface │ -│ - Checkpointing │ │ - File Uploads │ │ │ -│ │ │ - Thread Cleanup │ │ │ -│ │ │ - Artifacts │ │ │ -└─────────────────────┘ └─────────────────────┘ └─────────────────────┘ - │ │ - │ ┌─────────────────┘ - │ │ - ▼ ▼ + ┌───────────────────────┴───────────────────────┐ + │ │ + ▼ ▼ +┌─────────────────────────────────────────────┐ ┌─────────────────────┐ +│ Gateway API │ │ Frontend │ +│ (Port 8001) │ │ (Port 3000) │ +│ │ │ │ +│ - LangGraph-compatible runs/threads API │ │ - Next.js App │ +│ - Embedded Agent Runtime │ │ - React UI │ +│ - SSE Streaming │ │ - Chat Interface │ +│ - Checkpointing │ │ │ +│ - Models, MCP, Skills, Uploads, Artifacts │ │ │ +│ - Thread Cleanup │ │ │ +└─────────────────────────────────────────────┘ └─────────────────────┘ + │ + ▼ ┌──────────────────────────────────────────────────────────────────────────┐ │ Shared Configuration │ │ ┌─────────────────────────┐ ┌────────────────────────────────────────┐ │ @@ -52,9 +50,9 @@ This document provides a comprehensive overview of the DeerFlow backend architec ## Component Details -### Embedded LangGraph Runtime +### Gateway Embedded Agent Runtime -The LangGraph-compatible runtime runs inside the Gateway process and is built on LangGraph for robust multi-agent workflow orchestration. +The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for robust multi-agent workflow orchestration. Nginx rewrites `/api/langgraph/*` to Gateway's native `/api/*` routes, so the public API remains compatible with LangGraph SDK clients without running a separate LangGraph server. **Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent` @@ -65,7 +63,8 @@ The LangGraph-compatible runtime runs inside the Gateway process and is built on - Tool execution orchestration - SSE streaming for real-time responses -**Configuration**: `langgraph.json` +**Graph registry**: `langgraph.json` remains available for tooling, Studio, or direct LangGraph Server compatibility. +It is not the default service entrypoint; scripts and Docker deployments run the Gateway embedded runtime. ```json { @@ -84,6 +83,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl **Routers**: - `models.py` - `/api/models` - Model listing and details +- `thread_runs.py` / `runs.py` - `/api/threads/{id}/runs`, `/api/runs/*` - LangGraph-compatible runs and streaming - `mcp.py` - `/api/mcp` - MCP server configuration - `skills.py` - `/api/skills` - Skills management - `uploads.py` - `/api/threads/{id}/uploads` - File upload @@ -91,7 +91,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl - `artifacts.py` - `/api/threads/{id}/artifacts` - Artifact serving - `suggestions.py` - `/api/threads/{id}/suggestions` - Follow-up suggestion generation -The web conversation delete flow is now split across both backend surfaces: LangGraph handles `DELETE /api/langgraph/threads/{thread_id}` for thread state, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`. +The web conversation delete flow first deletes Gateway-managed thread state through the LangGraph-compatible route, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`. ### Agent Architecture @@ -354,9 +354,9 @@ SKILL.md Format: {"input": {"messages": [{"role": "user", "content": "Hello"}]}} 2. Nginx → Gateway API (8001) - Routes `/api/langgraph/*` to the Gateway's LangGraph-compatible runtime + `/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes -3. Embedded LangGraph runtime +3. Gateway embedded runtime a. Load/create thread state b. Execute middleware chain: - ThreadDataMiddleware: Set up paths @@ -412,7 +412,7 @@ SKILL.md Format: ### Thread Cleanup Flow ``` -1. Client deletes conversation via LangGraph +1. Client deletes conversation via the LangGraph-compatible Gateway route DELETE /api/langgraph/threads/{thread_id} 2. Web UI follows up with Gateway cleanup diff --git a/backend/docs/AUTH_DESIGN.md b/backend/docs/AUTH_DESIGN.md new file mode 100644 index 000000000..9a740871d --- /dev/null +++ b/backend/docs/AUTH_DESIGN.md @@ -0,0 +1,331 @@ +# 用户认证与隔离设计 + +本文档描述 DeerFlow 当前内置认证模块的设计,而不是历史 RFC。它覆盖浏览器登录、API 认证、CSRF、用户隔离、首次初始化、密码重置、内部调用和升级迁移。 + +## 设计目标 + +认证模块的核心目标是把 DeerFlow 从“本地单用户工具”提升为“可多用户部署的 agent runtime”,并让用户身份贯穿 HTTP API、LangGraph-compatible runtime、文件系统、memory、自定义 agent 和反馈数据。 + +设计约束: + +- 默认强制认证:除健康检查、文档和 auth bootstrap 端点外,HTTP 路由都必须有有效 session。 +- 服务端持有所有权:客户端 metadata 不能声明 `user_id` 或 `owner_id`。 +- 隔离默认开启:repository(仓储)、文件路径、memory、agent 配置默认按当前用户解析。 +- 旧数据可升级:无认证版本留下的 thread 可以在 admin 存在后迁移到 admin。 +- 密码不进日志:首次初始化由操作者设置密码;`reset_admin` 只写 0600 凭据文件。 + +非目标: + +- 当前 OAuth 端点只是占位,尚未实现第三方登录。 +- 当前用户角色只有 `admin` 和 `user`,尚未实现细粒度 RBAC。 +- 当前登录限速是进程内字典,多 worker 下不是全局精确限速。 + +## 核心模型 + +```mermaid +graph TB + classDef actor fill:#D8CFC4,stroke:#6E6259,color:#2F2A26; + classDef api fill:#C9D7D2,stroke:#5D706A,color:#21302C; + classDef state fill:#D7D3E8,stroke:#6B6680,color:#29263A; + classDef data fill:#E5D2C4,stroke:#806A5B,color:#30251E; + + Browser["Browser — access_token cookie and csrf_token cookie"]:::actor + AuthMiddleware["AuthMiddleware — strict session gate"]:::api + CSRFMiddleware["CSRFMiddleware — double-submit token and Origin check"]:::api + AuthRoutes["Auth routes — initialize login register logout me change-password"]:::api + UserContext["Current user ContextVar — request-scoped identity"]:::state + Repositories["Repositories — AUTO resolves user_id from context"]:::state + Files["Filesystem — users/{user_id}/threads/{thread_id}/user-data"]:::data + Memory["Memory and agents — users/{user_id}/memory.json and agents"]:::data + + Browser --> AuthMiddleware + Browser --> CSRFMiddleware + AuthMiddleware --> AuthRoutes + AuthMiddleware --> UserContext + UserContext --> Repositories + UserContext --> Files + UserContext --> Memory +``` + +### 用户表 + +用户记录定义在 `app.gateway.auth.models.User`,持久化到 `users` 表。关键字段: + +| 字段 | 语义 | +|---|---| +| `id` | 用户主键,JWT `sub` 使用该值 | +| `email` | 唯一登录名 | +| `password_hash` | bcrypt hash,OAuth 用户可为空 | +| `system_role` | `admin` 或 `user` | +| `needs_setup` | reset 后要求用户完成邮箱 / 密码设置 | +| `token_version` | 改密码或 reset 时递增,用于废弃旧 JWT | + +### 运行时身份 + +认证成功后,`AuthMiddleware` 把用户同时写入: + +- `request.state.user` +- `request.state.auth` +- `deerflow.runtime.user_context` 的 `ContextVar` + +`ContextVar` 是这里的核心边界。上层 Gateway 负责写入身份,下层 persistence / file path 只读取结构化的当前用户,不反向依赖 `app.gateway.auth` 具体类型。 + +可以把 repository 调用的用户参数理解成一个三态 ADT: + +```scala +enum UserScope: + case AutoFromContext + case Explicit(userId: String) + case BypassForMigration +``` + +对应 Python 实现是 `AUTO | str | None`: + +- `AUTO`:从 `ContextVar` 解析当前用户;没有上下文则抛错。 +- `str`:显式指定用户,主要用于测试或管理脚本。 +- `None`:跳过用户过滤,只允许迁移脚本或 admin CLI 使用。 + +## 登录与初始化流程 + +### 首次初始化 + +首次启动时,如果没有 admin,服务不会自动创建账号,只记录日志提示访问 `/setup`。 + +流程: + +1. 用户访问 `/setup`。 +2. 前端调用 `GET /api/v1/auth/setup-status`。 +3. 如果返回 `{"needs_setup": true}`,前端展示创建 admin 表单。 +4. 表单提交 `POST /api/v1/auth/initialize`。 +5. 服务端确认当前没有 admin,创建 `system_role="admin"`、`needs_setup=false` 的用户。 +6. 服务端设置 `access_token` HttpOnly cookie,用户进入 workspace。 + +`/api/v1/auth/initialize` 只在没有 admin 时可用。并发初始化由数据库唯一约束兜底,失败方返回 409。 + +### 普通登录 + +`POST /api/v1/auth/login/local` 使用 `OAuth2PasswordRequestForm`: + +- `username` 是邮箱。 +- `password` 是密码。 +- 成功后签发 JWT,放入 `access_token` HttpOnly cookie。 +- 响应体只返回 `expires_in` 和 `needs_setup`,不返回 token。 + +登录失败会按客户端 IP 计数。IP 解析只在 TCP peer 属于 `AUTH_TRUSTED_PROXIES` 时信任 `X-Real-IP`,不使用 `X-Forwarded-For`。 + +### 注册 + +`POST /api/v1/auth/register` 创建普通 `user`,并自动登录。 + +当前实现允许在没有 admin 时注册普通用户,但 `setup-status` 仍会返回 `needs_setup=true`,因为 admin 仍不存在。这是当前产品策略边界:如果后续要求“必须先初始化 admin 才能注册普通用户”,需要在 `/register` 增加 admin-exists gate。 + +### 改密码与 reset setup + +`POST /api/v1/auth/change-password` 需要当前密码和新密码: + +- 校验当前密码。 +- 更新 bcrypt hash。 +- `token_version += 1`,使旧 JWT 立即失效。 +- 重新签发 cookie。 +- 如果 `needs_setup=true` 且传了 `new_email`,则更新邮箱并清除 `needs_setup`。 + +`python -m app.gateway.auth.reset_admin` 会: + +- 找到 admin 或指定邮箱用户。 +- 生成随机密码。 +- 更新密码 hash。 +- `token_version += 1`。 +- 设置 `needs_setup=true`。 +- 写入 `.deer-flow/admin_initial_credentials.txt`,权限 `0600`。 + +命令行只输出凭据文件路径,不输出明文密码。 + +## HTTP 认证边界 + +`AuthMiddleware` 是 fail-closed(默认拒绝)的全局认证门。 + +公开路径: + +- `/health` +- `/docs` +- `/redoc` +- `/openapi.json` +- `/api/v1/auth/login/local` +- `/api/v1/auth/register` +- `/api/v1/auth/logout` +- `/api/v1/auth/setup-status` +- `/api/v1/auth/initialize` + +其余路径都要求有效 `access_token` cookie。存在 cookie 但 JWT 无效、过期、用户不存在或 `token_version` 不匹配时,直接返回 401,而不是让请求穿透到业务路由。 + +路由级别的 owner check 由 `require_permission(..., owner_check=True)` 完成: + +- 读类请求允许旧的未追踪 legacy thread 兼容读取。 +- 写 / 删除类请求使用 `require_existing=True`,要求 thread row 存在且属于当前用户,避免删除后缺 row 导致其他用户误通过。 + +## CSRF 设计 + +DeerFlow 使用 Double Submit Cookie: + +- 服务端设置 `csrf_token` cookie。 +- 前端 state-changing 请求发送同值 `X-CSRF-Token` header。 +- 服务端用 `secrets.compare_digest` 比较 cookie/header。 + +需要 CSRF 的方法: + +- `POST` +- `PUT` +- `DELETE` +- `PATCH` + +auth bootstrap 端点(login/register/initialize/logout)不要求 double-submit token,因为首次调用时浏览器还没有 token;但这些端点会校验 browser `Origin`,拒绝 hostile Origin,避免 login CSRF / session fixation。 + +## 用户隔离 + +### Thread metadata + +Thread metadata 存在 `threads_meta`,关键隔离字段是 `user_id`。 + +创建 thread 时: + +- 客户端传入的 `metadata.user_id` 和 `metadata.owner_id` 会被剥离。 +- `ThreadMetaRepository.create(..., user_id=AUTO)` 从 `ContextVar` 解析真实用户。 +- `/api/threads/search` 默认只返回当前用户的 thread。 + +读取 / 修改 / 删除时: + +- `get()` 默认按当前用户过滤。 +- `check_access()` 用于路由 owner check。 +- 对其他用户的 thread 返回 404,避免泄露资源存在性。 + +### 文件系统 + +当前线程文件布局: + +```text +{base_dir}/users/{user_id}/threads/{thread_id}/user-data/ +├── workspace/ +├── uploads/ +└── outputs/ +``` + +agent 在 sandbox 内看到统一虚拟路径: + +```text +/mnt/user-data/workspace +/mnt/user-data/uploads +/mnt/user-data/outputs +``` + +`ThreadDataMiddleware` 使用 `get_effective_user_id()` 解析当前用户并生成线程路径。没有认证上下文时会落到 `default` 用户桶,主要用于内部调用、嵌入式 client 或无 HTTP 的本地执行路径。 + +### Memory + +默认 memory 存储: + +```text +{base_dir}/users/{user_id}/memory.json +{base_dir}/users/{user_id}/agents/{agent_name}/memory.json +``` + +有用户上下文时,空或相对 `memory.storage_path` 都使用上述 per-user 默认路径;只有绝对 `memory.storage_path` 会视为显式 opt-out(退出) per-user isolation,所有用户共享该路径。无用户上下文的 legacy 路径仍会把相对 `storage_path` 解析到 `Paths.base_dir` 下。 + +### 自定义 agent + +用户自定义 agent 写入: + +```text +{base_dir}/users/{user_id}/agents/{agent_name}/ +├── config.yaml +├── SOUL.md +└── memory.json +``` + +旧布局 `{base_dir}/agents/{agent_name}/` 只作为只读兼容回退。更新或删除旧共享 agent 会要求先运行迁移脚本。 + +## 内部调用与 IM 渠道 + +IM channel worker 不是浏览器用户,不持有浏览器 cookie。它们通过 Gateway 内部认证: + +- 请求带 `X-DeerFlow-Internal-Token`。 +- 同时带匹配的 CSRF cookie/header。 +- 服务端识别为内部用户,`id="default"`、`system_role="internal"`。 + +这意味着 channel 产生的数据默认进入 `default` 用户桶。这个选择适合“平台级 bot 身份”,但不是“每个 IM 用户单独隔离”。如果后续要做到外部 IM 用户隔离,需要把外部 platform user 映射到 DeerFlow user,并让 channel manager 设置对应的 scoped identity。 + +## LangGraph-compatible 认证 + +Gateway 内嵌 runtime 路径由 `AuthMiddleware` 和 `CSRFMiddleware` 保护。 + +仓库仍保留 `app.gateway.langgraph_auth`,用于 LangGraph Server 直连模式: + +- `@auth.authenticate` 校验 JWT cookie、CSRF、用户存在性和 `token_version`。 +- `@auth.on` 在写入 metadata 时注入 `user_id`,并在读路径返回 `{"user_id": current_user}` 过滤条件。 + +这保证 Gateway 路由和 LangGraph-compatible 直连模式使用同一 JWT 语义。 + +## 升级与迁移 + +从无认证版本升级时,可能存在没有 `user_id` 的历史 thread。 + +当前策略: + +1. 首次启动如果没有 admin,只提示访问 `/setup`,不迁移。 +2. 操作者创建 admin。 +3. 后续启动时,`_ensure_admin_user()` 找到 admin,并把 LangGraph store 中缺少 `metadata.user_id` 的 thread 迁移到 admin。 + +文件系统旧布局迁移由脚本处理: + +```bash +cd backend +PYTHONPATH=. python scripts/migrate_user_isolation.py --dry-run +PYTHONPATH=. python scripts/migrate_user_isolation.py --user-id +``` + +迁移脚本覆盖 legacy `memory.json`、`threads/` 和 `agents/` 到 per-user layout。 + +## 安全不变量 + +必须长期保持的不变量: + +- JWT 只在 HttpOnly cookie 中传输,不出现在响应 JSON。 +- 任何非 public HTTP 路由都不能只靠“cookie 存在”放行,必须严格验证 JWT。 +- `token_version` 不匹配必须拒绝,保证改密码 / reset 后旧 session 失效。 +- 客户端 metadata 中的 `user_id` / `owner_id` 必须剥离。 +- repository 默认 `AUTO` 必须从当前用户上下文解析,不能静默退化成全局查询。 +- 只有迁移脚本和 admin CLI 可以显式传 `user_id=None` 绕过隔离。 +- 本地文件路径必须通过 `Paths` 和 sandbox path validation 解析,不能拼接未校验的用户输入。 +- 捕获认证、迁移、后台任务异常必须记录日志;不能空 catch。 + +## 已知边界 + +| 边界 | 当前行为 | 后续方向 | +|---|---|---| +| 无 admin 时注册普通用户 | 允许注册普通 `user` | 如产品要求先初始化 admin,给 `/register` 加 gate | +| 登录限速 | 进程内 dict,单 worker 精确,多 worker 近似 | Redis / DB-backed rate limiter | +| OAuth | 端点占位,未实现 | 接入 provider 并统一 `token_version` / role 语义 | +| IM 用户隔离 | channel 使用 `default` 内部用户 | 建立外部用户到 DeerFlow user 的映射 | +| 绝对 memory path | 显式共享 memory | UI / docs 明确提示 opt-out 风险 | + +## 相关文件 + +| 文件 | 职责 | +|---|---| +| `app/gateway/auth_middleware.py` | 全局认证门、JWT 严格验证、写入 user context | +| `app/gateway/csrf_middleware.py` | CSRF double-submit 和 auth Origin 校验 | +| `app/gateway/routers/auth.py` | initialize/login/register/logout/me/change-password | +| `app/gateway/auth/jwt.py` | JWT 创建与解析 | +| `app/gateway/auth/reset_admin.py` | 密码 reset CLI | +| `app/gateway/auth/credential_file.py` | 0600 凭据文件写入 | +| `app/gateway/authz.py` | 路由权限与 owner check | +| `deerflow/runtime/user_context.py` | 当前用户 ContextVar 与 `AUTO` sentinel | +| `deerflow/persistence/thread_meta/` | thread metadata owner filter | +| `deerflow/config/paths.py` | per-user filesystem layout | +| `deerflow/agents/middlewares/thread_data_middleware.py` | run 时解析用户线程目录 | +| `deerflow/agents/memory/storage.py` | per-user memory storage | +| `deerflow/config/agents_config.py` | per-user custom agents | +| `app/channels/manager.py` | IM channel 内部认证调用 | +| `scripts/migrate_user_isolation.py` | legacy 数据迁移到 per-user layout | +| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库,包含 users / threads_meta / runs / feedback 等表 | +| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | +| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | diff --git a/backend/docs/AUTH_TEST_DOCKER_GAP.md b/backend/docs/AUTH_TEST_DOCKER_GAP.md index adf4916a3..969aad92c 100644 --- a/backend/docs/AUTH_TEST_DOCKER_GAP.md +++ b/backend/docs/AUTH_TEST_DOCKER_GAP.md @@ -24,11 +24,11 @@ All other test plan sections were executed against either: | Case | Title | What it covers | Why not run | |---|---|---|---| -| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` | +| TC-DOCKER-01 | `deerflow.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` | | TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` | | TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container | -| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` | -| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | +| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` | +| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | | TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` | ## Coverage already provided by non-Docker tests @@ -41,8 +41,8 @@ the test cases that ran on sg_dev or local: | TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between | | TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` | | TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs | -| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP | -| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | +| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies | +| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | | TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change | ## Reproduction steps when Docker becomes available @@ -72,6 +72,6 @@ Then run TC-DOCKER-01..06 from the test plan as written. about *container packaging* details (bind mounts, multi-worker, log collection), not about whether the auth code paths work. - **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect - the post-simplify reality (credentials file → 0600 file, no log leak). + the current reset flow (`reset_admin` → 0600 credentials file, no log leak). The old "grep 'Password:' in docker logs" expectation would have failed silently and given a false sense of coverage. diff --git a/backend/docs/AUTH_TEST_PLAN.md b/backend/docs/AUTH_TEST_PLAN.md index 15b20494a..e5245d60b 100644 --- a/backend/docs/AUTH_TEST_PLAN.md +++ b/backend/docs/AUTH_TEST_PLAN.md @@ -19,7 +19,7 @@ ```bash # 清除已有数据 -rm -f backend/.deer-flow/users.db +rm -f backend/.deer-flow/data/deerflow.db # 选择模式启动 make dev # 标准模式 @@ -28,10 +28,11 @@ make dev-pro # Gateway 模式 ``` **验证点:** -- [ ] 控制台输出 admin 邮箱和随机密码 -- [ ] 密码格式为 `secrets.token_urlsafe(16)` 的 22 字符字符串 -- [ ] 邮箱为 `admin@deerflow.dev` -- [ ] 提示 `Change it after login: Settings -> Account` +- [ ] 控制台不输出 admin 邮箱或明文密码 +- [ ] 控制台提示 `First boot detected — no admin account exists.` +- [ ] 控制台提示访问 `/setup` 完成 admin 创建 +- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": true}` +- [ ] 前端访问 `/login` 会跳转 `/setup` ### 1.2 非首次启动 @@ -42,7 +43,8 @@ make dev **验证点:** - [ ] 控制台不输出密码 -- [ ] 如果 admin 仍 `needs_setup=True`,控制台有 warning 提示 +- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": false}` +- [ ] 已登录用户如果 `needs_setup=True`,访问 workspace 会被引导到 `/setup` 完成改邮箱 / 改密码流程 ### 1.3 环境变量配置 @@ -76,19 +78,22 @@ make dev curl -s $BASE/api/v1/auth/setup-status | jq . ``` -**预期:** 返回 `{"needs_setup": false}`(admin 在启动时已自动创建,`count_users() > 0`)。仅在启动完成前的极短窗口内可能返回 `true`。 +**预期:** +- 干净数据库且尚未初始化 admin:返回 `{"needs_setup": true}` +- 已存在 admin:返回 `{"needs_setup": false}` -#### TC-API-02: Admin 首次登录 +#### TC-API-02: 首次初始化 Admin ```bash -curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<控制台密码>" \ +curl -s -X POST $BASE/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ -c cookies.txt | jq . ``` **预期:** -- 状态码 200 -- Body: `{"expires_in": 604800, "needs_setup": true}` +- 状态码 201 +- Body: `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}` - `cookies.txt` 包含 `access_token`(HttpOnly)和 `csrf_token`(非 HttpOnly) #### TC-API-03: 获取当前用户 @@ -97,9 +102,9 @@ curl -s -X POST $BASE/api/v1/auth/login/local \ curl -s $BASE/api/v1/auth/me -b cookies.txt | jq . ``` -**预期:** `{"id": "...", "email": "admin@deerflow.dev", "system_role": "admin", "needs_setup": true}` +**预期:** `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}` -#### TC-API-04: Setup 流程(改邮箱 + 改密码) +#### TC-API-04: 改密码流程 ```bash CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}') @@ -107,13 +112,36 @@ curl -s -X POST $BASE/api/v1/auth/change-password \ -b cookies.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"current_password":"<控制台密码>","new_password":"NewPass123!","new_email":"admin@example.com"}' | jq . + -d '{"current_password":"AdminPass1!","new_password":"NewPass123!"}' | jq . ``` **预期:** - 状态码 200 - `{"message": "Password changed successfully"}` -- 再调 `/auth/me` 邮箱变为 `admin@example.com`,`needs_setup` 变为 `false` +- 再调 `/auth/me` 仍为 `admin@example.com`,`needs_setup` 仍为 `false` + +#### TC-API-04a: reset_admin 后的 Setup 流程(改邮箱 + 改密码) + +```bash +cd backend +python -m app.gateway.auth.reset_admin --email admin@example.com +# 从 .deer-flow/admin_initial_credentials.txt 读取 reset 后密码 + +curl -s -X POST $BASE/api/v1/auth/login/local \ + -d "username=admin@example.com&password=<凭据文件密码>" \ + -c cookies.txt | jq . + +CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}') +curl -s -X POST $BASE/api/v1/auth/change-password \ + -b cookies.txt \ + -H "Content-Type: application/json" \ + -H "X-CSRF-Token: $CSRF" \ + -d '{"current_password":"<凭据文件密码>","new_password":"AdminPass2!","new_email":"admin2@example.com"}' | jq . +``` + +**预期:** +- 登录返回 `{"expires_in": 604800, "needs_setup": true}` +- `change-password` 后 `/auth/me` 邮箱变为 `admin2@example.com`,`needs_setup` 变为 `false` #### TC-API-05: 普通用户注册 @@ -493,7 +521,7 @@ curl -s -X POST $BASE/api/v1/auth/register \ ```bash # 检查数据库 -sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMIT 3;" +sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM users LIMIT 3;" ``` **预期:** `password_hash` 以 `$2b$` 开头(bcrypt 格式) @@ -506,24 +534,25 @@ sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMI ### 4.1 首次登录流程 -#### TC-UI-01: 访问首页跳转登录 +#### TC-UI-01: 无 admin 时访问 workspace 跳转 setup 1. 打开 `http://localhost:2026/workspace` -2. **预期:** 自动跳转到 `/login` +2. **预期:** 自动跳转到 `/setup` -#### TC-UI-02: Login 页面 +#### TC-UI-02: Setup 页面创建 admin -1. 输入 admin 邮箱和控制台密码 -2. 点击 Login -3. **预期:** 跳转到 `/setup`(因为 `needs_setup=true`) - -#### TC-UI-03: Setup 页面 - -1. 输入新邮箱、控制台密码(current)、新密码、确认密码 -2. 点击 Complete Setup +1. 输入 admin 邮箱、密码、确认密码 +2. 点击 Create Admin Account 3. **预期:** 跳转到 `/workspace` 4. 刷新页面不跳回 `/setup` +#### TC-UI-03: 已初始化后 Login 页面 + +1. 退出登录后访问 `/login` +2. 输入 admin 邮箱和密码 +3. 点击 Login +4. **预期:** 跳转到 `/workspace` + #### TC-UI-04: Setup 密码不匹配 1. 新密码和确认密码不一致 @@ -602,7 +631,7 @@ sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMI #### TC-UI-15: reset_admin 后重新登录 1. 执行 `cd backend && python -m app.gateway.auth.reset_admin` -2. 使用新密码登录 +2. 从 `.deer-flow/admin_initial_credentials.txt` 读取新密码并登录 3. **预期:** 跳转到 `/setup` 页面(`needs_setup` 被重置为 true) 4. 旧 session 已失效 @@ -645,18 +674,28 @@ make install make dev ``` -#### TC-UPG-01: 首次启动创建 admin +#### TC-UPG-01: 首次启动等待 admin 初始化 **预期:** -- [ ] 控制台输出 admin 邮箱(`admin@deerflow.dev`)和随机密码 +- [ ] 控制台不输出 admin 邮箱或随机密码 +- [ ] 访问 `/setup` 可创建第一个 admin - [ ] 无报错,正常启动 #### TC-UPG-02: 旧 Thread 迁移到 admin ```bash +# 创建第一个 admin +curl -s -X POST http://localhost:2026/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ + -c cookies.txt + +# 重启一次:启动迁移只在已有 admin 的启动路径执行 +make stop && make dev + # 登录 admin curl -s -X POST http://localhost:2026/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<控制台密码>" \ + -d "username=admin@example.com&password=AdminPass1!" \ -c cookies.txt # 查看 thread 列表 @@ -670,8 +709,8 @@ curl -s -X POST http://localhost:2026/api/threads/search \ **预期:** - [ ] 返回的 thread 数量 ≥ 旧版创建的数量 -- [ ] 控制台日志有 `Migrated N orphaned thread(s) to admin` -- [ ] 每个 thread 的 `metadata.owner_id` 都已被设为 admin 的 ID +- [ ] 控制台日志有 `Migrated N orphan LangGraph thread(s) to admin` +- [ ] 旧 thread 只对 admin 可见 #### TC-UPG-03: 旧 Thread 内容完整 @@ -683,7 +722,7 @@ curl -s http://localhost:2026/api/threads/ \ **预期:** - [ ] `metadata.title` 保留原值(如 `old-thread-1`) -- [ ] `metadata.owner_id` 已填充 +- [ ] 响应不回显服务端保留的 `user_id` / `owner_id` #### TC-UPG-04: 新用户看不到旧 Thread @@ -706,18 +745,19 @@ curl -s -X POST http://localhost:2026/api/threads/search \ ### 5.3 数据库 Schema 兼容 -#### TC-UPG-05: 无 users.db 时自动创建 +#### TC-UPG-05: 无 deerflow.db 时创建 schema 但不创建默认用户 ```bash -ls -la backend/.deer-flow/users.db +ls -la backend/.deer-flow/data/deerflow.db +sqlite3 backend/.deer-flow/data/deerflow.db "SELECT COUNT(*) FROM users;" ``` -**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列 +**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列;未调用 `/initialize` 前用户数为 0 -#### TC-UPG-06: users.db WAL 模式 +#### TC-UPG-06: deerflow.db WAL 模式 ```bash -sqlite3 backend/.deer-flow/users.db "PRAGMA journal_mode;" +sqlite3 backend/.deer-flow/data/deerflow.db "PRAGMA journal_mode;" ``` **预期:** 返回 `wal` @@ -768,9 +808,9 @@ make dev ``` **预期:** -- [ ] 服务正常启动(忽略 `users.db`,无 auth 相关代码不报错) +- [ ] 服务正常启动(忽略 `deerflow.db`,无 auth 相关代码不报错) - [ ] 旧对话数据仍然可访问 -- [ ] `users.db` 文件残留但不影响运行 +- [ ] `deerflow.db` 文件残留但不影响运行 #### TC-UPG-12: 再次升级到 auth 分支 @@ -781,51 +821,47 @@ make dev ``` **预期:** -- [ ] 识别已有 `users.db`,不重新创建 admin -- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `users.db`) +- [ ] 识别已有 `deerflow.db`,不重新创建 admin +- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `deerflow.db`) -### 5.7 休眠 Admin(初始密码未使用/未更改) +### 5.7 Admin 初始化与 reset_admin -> 首次启动生成 admin + 随机密码,但运维未登录、未改密码。 -> 密码只在首次启动的控制台闪过一次,后续启动不再显示。 +> 首次启动不生成默认 admin,也不在日志输出密码。忘记密码时走 `reset_admin`,新密码写入 0600 凭据文件。 -#### TC-UPG-13: 重启后自动重置密码并打印 +#### TC-UPG-13: 未初始化 admin 时重启不创建默认账号 ```bash -# 首次启动,记录密码 -rm -f backend/.deer-flow/users.db +rm -f backend/.deer-flow/data/deerflow.db make dev -# 控制台输出密码 P0,不登录 make stop -# 隔了几天,再次启动 make dev -# 控制台输出新密码 P1 +curl -s $BASE/api/v1/auth/setup-status | jq . ``` **预期:** -- [ ] 控制台输出 `Admin account setup incomplete — password reset` -- [ ] 输出新密码 P1(P0 已失效) -- [ ] 用 P1 可以登录,P0 不可以 -- [ ] 登录后 `needs_setup=true`,跳转 `/setup` -- [ ] `token_version` 递增(旧 session 如有也失效) +- [ ] 控制台不输出密码 +- [ ] `setup-status` 仍为 `{"needs_setup": true}` +- [ ] 访问 `/setup` 仍可创建第一个 admin -#### TC-UPG-14: 密码丢失 — 无需 CLI,重启即可 +#### TC-UPG-14: 密码丢失 — reset_admin 写入凭据文件 ```bash -# 忘记了控制台密码 → 直接重启服务 -make stop && make dev -# 控制台自动输出新密码 +python -m app.gateway.auth.reset_admin --email admin@example.com +ls -la backend/.deer-flow/admin_initial_credentials.txt +cat backend/.deer-flow/admin_initial_credentials.txt ``` **预期:** -- [ ] 无需 `reset_admin`,重启服务即可拿到新密码 -- [ ] `reset_admin` CLI 仍然可用作手动备选方案 +- [ ] 命令行只输出凭据文件路径,不输出明文密码 +- [ ] 凭据文件权限为 `0600` +- [ ] 凭据文件包含 email + password 行 +- [ ] 该用户下次登录返回 `needs_setup=true` -#### TC-UPG-15: 休眠 admin 期间普通用户注册 +#### TC-UPG-15: 未初始化 admin 期间普通用户注册策略边界 ```bash -# admin 存在但从未登录,普通用户先注册 +# admin 尚不存在,普通用户尝试注册 curl -s -X POST $BASE/api/v1/auth/register \ -H "Content-Type: application/json" \ -d '{"email":"earlybird@example.com","password":"EarlyPass1!"}' \ @@ -833,11 +869,11 @@ curl -s -X POST $BASE/api/v1/auth/register \ ``` **预期:** -- [ ] 注册成功(201),角色为 `user` -- [ ] 无法提权为 admin -- [ ] 普通用户的数据与 admin 隔离 +- [ ] 当前代码允许注册普通用户并自动登录(201,角色为 `user`) +- [ ] 但 `setup-status` 仍为 `{"needs_setup": true}`,因为 admin 仍不存在 +- [ ] 这是一个产品策略边界:若要求“必须先有 admin”,需要在 `/register` 增加 admin-exists gate -#### TC-UPG-16: 休眠 admin 不影响后续操作 +#### TC-UPG-16: 普通用户数据与后续 admin 隔离 ```bash # 普通用户正常创建 thread、发消息 @@ -849,14 +885,13 @@ curl -s -X POST $BASE/api/threads \ -d '{"metadata":{}}' | jq .thread_id ``` -**预期:** 正常创建,不受休眠 admin 影响 +**预期:** 普通用户正常创建 thread;后续 admin 创建后,搜索不到该普通用户 thread -#### TC-UPG-17: 休眠 admin 最终完成 Setup +#### TC-UPG-17: reset_admin 后完成 Setup ```bash -# 运维终于登录 curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=" \ + -d "username=admin@example.com&password=<凭据文件密码>" \ -c admin.txt | jq .needs_setup # 预期: true @@ -866,7 +901,7 @@ curl -s -X POST $BASE/api/v1/auth/change-password \ -b admin.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"current_password":"<密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \ + -d '{"current_password":"<凭据文件密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \ -c admin.txt # 验证 @@ -876,7 +911,7 @@ curl -s $BASE/api/v1/auth/me -b admin.txt | jq '{email, needs_setup}' **预期:** - [ ] `email` 变为 `admin@real.com` - [ ] `needs_setup` 变为 `false` -- [ ] 后续重启控制台不再有 warning +- [ ] 后续登录使用新密码 #### TC-UPG-18: 长期未用后 JWT 密钥轮换 @@ -890,8 +925,8 @@ make stop && make dev **预期:** - [ ] 服务正常启动 -- [ ] 旧密码仍可登录(密码存在 DB,与 JWT 密钥无关) -- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)— 但因为从未登录过也没有旧 token +- [ ] 账号密码仍可登录(密码存在 DB,与 JWT 密钥无关) +- [ ] 旧的 JWT token 失效(密钥变了签名不匹配) --- @@ -910,7 +945,7 @@ for i in 1 2 3; do done # 检查 admin 数量 -sqlite3 backend/.deer-flow/users.db \ +sqlite3 backend/.deer-flow/data/deerflow.db \ "SELECT COUNT(*) FROM users WHERE system_role='admin';" ``` @@ -1055,7 +1090,7 @@ curl -s -X POST $BASE/api/v1/auth/register \ wait # 检查用户数 -sqlite3 backend/.deer-flow/users.db \ +sqlite3 backend/.deer-flow/data/deerflow.db \ "SELECT COUNT(*) FROM users WHERE email='race@example.com';" ``` @@ -1165,13 +1200,16 @@ curl -s -w "%{http_code}" -X DELETE "$BASE/api/threads/$TID" \ ```bash cd backend python -m app.gateway.auth.reset_admin -# 记录密码 P1 +cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p1.txt +P1=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p1.txt) python -m app.gateway.auth.reset_admin -# 记录密码 P2 +cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p2.txt +P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt) ``` **预期:** +- [ ] `.deer-flow/admin_initial_credentials.txt` 每次都会被重写,文件权限为 `0600` - [ ] P1 ≠ P2(每次生成新随机密码) - [ ] P1 不可用,只有 P2 有效 - [ ] `token_version` 递增了 2 @@ -1324,7 +1362,8 @@ done ```bash GW=http://localhost:8001 -for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local /api/v1/auth/register; do +for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local \ + /api/v1/auth/register /api/v1/auth/initialize /api/v1/auth/logout; do echo "$path: $(curl -s -w '%{http_code}' -o /dev/null $GW$path)" done # 预期: 200 或 405/422(方法不对但不是 401) @@ -1399,9 +1438,9 @@ done > > 前置条件: > - `.env` 中设置 `AUTH_JWT_SECRET`(否则每次容器重启 session 全部失效) -> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `users.db`) +> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `deerflow.db`) -#### TC-DOCKER-01: users.db 通过 volume 持久化 +#### TC-DOCKER-01: deerflow.db 通过 volume 持久化 ```bash # 启动容器 @@ -1416,13 +1455,13 @@ curl -s -X POST $BASE/api/v1/auth/register \ -H "Content-Type: application/json" \ -d '{"email":"docker-test@example.com","password":"DockerTest1!"}' -w "\nHTTP %{http_code}" -# 检查宿主机上的 users.db -ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db -sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db \ +# 检查宿主机上的 deerflow.db +ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db +sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db \ "SELECT email FROM users WHERE email='docker-test@example.com';" ``` -**预期:** users.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。 +**预期:** deerflow.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。 #### TC-DOCKER-02: 重启容器后 session 保持 @@ -1466,22 +1505,24 @@ done **已知限制:** In-process rate limiter 不跨 worker 共享。生产环境如需精确限速,需要 Redis 等外部存储。 -#### TC-DOCKER-04: IM 渠道不经过 auth +#### TC-DOCKER-04: IM 渠道使用内部认证 ```bash -# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 通信 -# 不走 nginx,不经过 AuthMiddleware +# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 调 Gateway +# 请求携带 process-local internal auth header,并带匹配的 CSRF cookie/header # 验证方式:检查 gateway 日志中 channel manager 的请求不包含 auth 错误 docker logs deer-flow-gateway 2>&1 | grep -E "ChannelManager|channel" | head -10 ``` -**预期:** 无 auth 相关错误。渠道通过 `langgraph-sdk` 直连 LangGraph Server(`http://langgraph:2024`),不走 auth 层。 +**预期:** 无 auth 相关错误。渠道不依赖浏览器 cookie;服务端通过内部认证头把请求归入 `default` 用户桶。 -#### TC-DOCKER-05: admin 密码写入 0600 凭证文件(不再走日志) +#### TC-DOCKER-05: reset_admin 密码写入 0600 凭证文件(不再走日志) ```bash -# 凭证文件写在挂载到宿主机的 DEER_FLOW_HOME 下 +# 首次启动不会自动生成 admin 密码。先重置已有 admin,凭据文件写在挂载到宿主机的 DEER_FLOW_HOME 下。 +docker exec deer-flow-gateway python -m app.gateway.auth.reset_admin --email docker-test@example.com + ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/admin_initial_credentials.txt # 预期文件权限: -rw------- (0600) @@ -1512,14 +1553,15 @@ sleep 15 docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l # 预期: 0 -# auth 流程正常 +# auth 流程正常:未登录受保护接口返回 401 curl -s -w "%{http_code}" -o /dev/null $BASE/api/models # 预期: 401 -curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<日志密码>" \ +curl -s -X POST $BASE/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ -c cookies.txt -w "\nHTTP %{http_code}" -# 预期: 200 +# 预期: 201 ``` ### 7.4 补充边界用例 @@ -1587,13 +1629,15 @@ curl -s -D - -X POST $BASE/api/v1/auth/login/local \ #### TC-EDGE-05: HTTP 无 max_age / HTTPS 有 max_age ```bash +GW=http://localhost:8001 + # HTTP -curl -s -D - -X POST $BASE/api/v1/auth/login/local \ +curl -s -D - -X POST $GW/api/v1/auth/login/local \ -d "username=admin@example.com&password=正确密码" 2>/dev/null \ | grep "access_token=" | grep -oi "max-age=[0-9]*" || echo "NO max-age (HTTP session cookie)" -# HTTPS -curl -s -D - -X POST $BASE/api/v1/auth/login/local \ +# HTTPS:直连 Gateway 才能用 X-Forwarded-Proto 模拟 HTTPS;nginx 会覆盖该 header +curl -s -D - -X POST $GW/api/v1/auth/login/local \ -H "X-Forwarded-Proto: https" \ -d "username=admin@example.com&password=正确密码" 2>/dev/null \ | grep "access_token=" | grep -oi "max-age=[0-9]*" @@ -1712,10 +1756,10 @@ curl -s -X POST $BASE/api/threads \ -b cookies.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"metadata":{"owner_id":"victim-user-id"}}' | jq .metadata.owner_id + -d '{"metadata":{"owner_id":"victim-user-id","user_id":"victim-user-id"}}' | jq .metadata ``` -**预期:** 返回的 `metadata.owner_id` 应为当前登录用户的 ID,不是请求中注入的 `victim-user-id`。服务端应覆盖客户端提供的 `user_id`。 +**预期:** 返回的 `metadata` 不包含 `owner_id` 或 `user_id`。真实所有权写入 `threads_meta.user_id`,不从客户端 metadata 接收,也不通过 metadata 回显。 #### 7.5.6 HTTP Method 探测 @@ -1796,6 +1840,6 @@ cd backend && PYTHONPATH=. uv run pytest \ # 核心接口冒烟 curl -s $BASE/health # 200 curl -s $BASE/api/models # 401 (无 cookie) -curl -s -X POST $BASE/api/v1/auth/setup-status # 200 +curl -s $BASE/api/v1/auth/setup-status # 200 curl -s $BASE/api/v1/auth/me -b cookies.txt # 200 (有 cookie) ``` diff --git a/backend/docs/AUTH_UPGRADE.md b/backend/docs/AUTH_UPGRADE.md index 344c488c4..b54283d24 100644 --- a/backend/docs/AUTH_UPGRADE.md +++ b/backend/docs/AUTH_UPGRADE.md @@ -2,13 +2,16 @@ DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。 +完整设计见 [AUTH_DESIGN.md](AUTH_DESIGN.md)。 + ## 核心概念 认证模块采用**始终强制**策略: -- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志 +- 首次启动时不会自动创建账号;首次访问 `/setup` 时由操作者创建第一个 admin 账号 - 认证从一开始就是强制的,无竞争窗口 -- 历史对话(升级前创建的 thread)自动迁移到 admin 名下 +- 已有 admin 后,服务启动时会把历史对话(升级前创建且缺少 `user_id` 的 thread)迁移到 admin 名下 +- 新数据按用户隔离:thread、workspace/uploads/outputs、memory、自定义 agent 都归属当前用户 ## 升级步骤 @@ -25,39 +28,41 @@ cd backend && make install make dev ``` -控制台会输出: +如果没有 admin 账号,控制台只会提示: ``` ============================================================ - Admin account created on first boot - Email: admin@deerflow.dev - Password: aB3xK9mN_pQ7rT2w - Change it after login: Settings → Account + First boot detected — no admin account exists. + Visit /setup to complete admin account creation. ============================================================ ``` -如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。 +首次启动不会在日志里打印随机密码,也不会写入默认 admin。这样避免启动日志泄露凭据,也避免在操作者创建账号前出现可被猜测的默认身份。 -### 3. 登录 +### 3. 创建 admin -访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。 +访问 `http://localhost:2026/setup`,填写邮箱和密码创建第一个 admin 账号。创建成功后会自动登录并进入 workspace。 -### 4. 修改密码 +如果这是从无认证版本升级,创建 admin 后重启一次服务,让启动迁移把缺少 `user_id` 的历史 thread 归属到 admin。 -登录后进入 Settings → Account → Change Password。 +### 4. 登录 + +后续访问 `http://localhost:2026/login`,使用已创建的邮箱和密码登录。 ### 5. 添加用户(可选) -其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。 +其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话、上传文件、输出文件、memory 和自定义 agent。 ## 安全机制 | 机制 | 说明 | |------|------| | JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 | -| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` | +| CSRF Double Submit Cookie | 受保护的 POST/PUT/PATCH/DELETE 请求需携带 `X-CSRF-Token`;登录/注册/初始化/登出走 auth 端点 Origin 校验 | | bcrypt 密码哈希 | 密码不以明文存储 | -| 多租户隔离 | 用户只能访问自己的 thread | +| Thread owner filter | `threads_meta.user_id` 由服务端认证上下文写入,搜索、读取、更新、删除默认按当前用户过滤 | +| 文件系统隔离 | 线程数据写入 `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/`,sandbox 内统一映射为 `/mnt/user-data/` | +| Memory / agent 隔离 | 用户 memory 和自定义 agent 写入 `{base_dir}/users/{user_id}/...`;旧共享 agent 只作为只读兼容回退 | | HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 | ## 常见操作 @@ -74,23 +79,27 @@ python -m app.gateway.auth.reset_admin python -m app.gateway.auth.reset_admin --email user@example.com ``` -会输出新的随机密码。 +会把新的随机密码写入 `.deer-flow/admin_initial_credentials.txt`,文件权限为 `0600`。命令行只输出文件路径,不输出明文密码。 ### 完全重置 -删除用户数据库,重启后自动创建新 admin: +删除统一 SQLite 数据库,重启后重新访问 `/setup` 创建新 admin: ```bash -rm -f backend/.deer-flow/users.db -# 重启服务,控制台输出新密码 +rm -f backend/.deer-flow/data/deerflow.db +# 重启服务后访问 http://localhost:2026/setup ``` ## 数据存储 | 文件 | 内容 | |------|------| -| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) | -| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) | +| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库(users、threads_meta、runs、feedback 等应用数据) | +| `.deer-flow/users/{user_id}/threads/{thread_id}/user-data/` | 用户线程的 workspace、uploads、outputs | +| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory | +| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | +| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | +| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持) | ### 生产环境建议 @@ -111,19 +120,21 @@ python -c "import secrets; print(secrets.token_urlsafe(32))" | `/api/v1/auth/me` | GET | 获取当前用户信息 | | `/api/v1/auth/change-password` | POST | 修改密码 | | `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 | +| `/api/v1/auth/initialize` | POST | 首次初始化第一个 admin(仅无 admin 时可调用) | ## 兼容性 -- **标准模式**(`make dev`):完全兼容,admin 自动创建 +- **标准模式**(`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化 - **Gateway 模式**(`make dev-pro`):完全兼容 -- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载 -- **IM 渠道**(Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层 +- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载 +- **IM 渠道**(Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶 - **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响 ## 故障排查 | 症状 | 原因 | 解决 | |------|------|------| -| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` | +| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` | +| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin | | 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 | -| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 | +| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 | diff --git a/backend/docs/README.md b/backend/docs/README.md index da566005d..27e33f854 100644 --- a/backend/docs/README.md +++ b/backend/docs/README.md @@ -8,6 +8,7 @@ This directory contains detailed documentation for the DeerFlow backend. |----------|-------------| | [ARCHITECTURE.md](ARCHITECTURE.md) | System architecture overview | | [API.md](API.md) | Complete API reference | +| [AUTH_DESIGN.md](AUTH_DESIGN.md) | User authentication, CSRF, and per-user isolation design | | [CONFIGURATION.md](CONFIGURATION.md) | Configuration options | | [SETUP.md](SETUP.md) | Quick setup guide | @@ -42,6 +43,7 @@ docs/ ├── README.md # This file ├── ARCHITECTURE.md # System architecture ├── API.md # API reference +├── AUTH_DESIGN.md # User authentication and isolation design ├── CONFIGURATION.md # Configuration guide ├── SETUP.md # Setup instructions ├── FILE_UPLOAD.md # File upload feature diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index b2a147bce..129a28c66 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -40,6 +40,15 @@ class MemoryUpdateQueue: self._timer: threading.Timer | None = None self._processing = False + @staticmethod + def _queue_key( + thread_id: str, + user_id: str | None, + agent_name: str | None, + ) -> tuple[str, str | None, str | None]: + """Return the debounce identity for a memory update target.""" + return (thread_id, user_id, agent_name) + def add( self, thread_id: str, @@ -115,8 +124,9 @@ class MemoryUpdateQueue: correction_detected: bool, reinforcement_detected: bool, ) -> None: + queue_key = self._queue_key(thread_id, user_id, agent_name) existing_context = next( - (context for context in self._queue if context.thread_id == thread_id), + (context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key), None, ) merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) @@ -130,7 +140,7 @@ class MemoryUpdateQueue: reinforcement_detected=merged_reinforcement_detected, ) - self._queue = [c for c in self._queue if c.thread_id != thread_id] + self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key] self._queue.append(context) def _reset_timer(self) -> None: diff --git a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py index dafa7d977..307548e0a 100644 --- a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py +++ b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py @@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_ from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent from deerflow.config.memory_config import get_memory_config +from deerflow.runtime.user_context import resolve_runtime_user_id def memory_flush_hook(event: SummarizationEvent) -> None: @@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None: correction_detected = detect_correction(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) + user_id = resolve_runtime_user_id(event.runtime) queue = get_memory_queue() queue.add_nowait( thread_id=event.thread_id, messages=filtered_messages, agent_name=event.agent_name, + user_id=user_id, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, ) diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 5bb54f3e5..000ca51a2 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): return "[Tool call was interrupted and did not return a result.]" def _build_patched_messages(self, messages: list) -> list | None: - """Return a new message list with patches inserted at the correct positions. + """Return messages with tool results grouped after their tool-call AIMessage. - For each AIMessage with dangling tool_calls (no corresponding ToolMessage), - a synthetic ToolMessage is inserted immediately after that AIMessage. - Returns None if no patches are needed. + This normalizes model-bound causal order before provider serialization while + preserving already-valid transcripts unchanged. """ - # Collect IDs of all existing ToolMessages - existing_tool_msg_ids: set[str] = set() + tool_messages_by_id: dict[str, ToolMessage] = {} for msg in messages: if isinstance(msg, ToolMessage): - existing_tool_msg_ids.add(msg.tool_call_id) + tool_messages_by_id.setdefault(msg.tool_call_id, msg) - # Check if any patching is needed - needs_patch = False + tool_call_ids: set[str] = set() for msg in messages: if getattr(msg, "type", None) != "ai": continue for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if tc_id and tc_id not in existing_tool_msg_ids: - needs_patch = True - break - if needs_patch: - break + if tc_id: + tool_call_ids.add(tc_id) - if not needs_patch: - return None - - # Build new list with patches inserted right after each dangling AIMessage patched: list = [] - patched_ids: set[str] = set() + consumed_tool_msg_ids: set[str] = set() patch_count = 0 for msg in messages: + if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: + continue + patched.append(msg) if getattr(msg, "type", None) != "ai": continue + for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: + if not tc_id or tc_id in consumed_tool_msg_ids: + continue + + existing_tool_msg = tool_messages_by_id.get(tc_id) + if existing_tool_msg is not None: + patched.append(existing_tool_msg) + consumed_tool_msg_ids.add(tc_id) + else: patched.append( ToolMessage( content=self._synthetic_tool_message_content(tc), @@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): status="error", ) ) - patched_ids.add(tc_id) + consumed_tool_msg_ids.add(tc_id) patch_count += 1 - logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") + if patched == messages: + return None + + if patch_count: + logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") return patched @override diff --git a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py index b8cd10884..9215aefc5 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py @@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list. Additionally, this middleware prevents the agent from exiting the loop while there are still incomplete todo items. When the model produces a final response -(no tool calls) but todos are not yet complete, the middleware injects a reminder -and jumps back to the model node to force continued engagement. +(no tool calls) but todos are not yet complete, the middleware queues a reminder +for the next model request and jumps back to the model node to force continued +engagement. The completion reminder is injected via ``wrap_model_call`` instead +of being persisted into graph state as a normal user-visible message. """ from __future__ import annotations +import threading +from collections.abc import Awaitable, Callable from typing import Any, override from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware.todo import PlanningState, Todo -from langchain.agents.middleware.types import hook_config +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config from langchain_core.messages import AIMessage, HumanMessage from langgraph.runtime import Runtime @@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str: return "\n".join(lines) +def _format_completion_reminder(todos: list[Todo]) -> str: + """Format a completion reminder for incomplete todo items.""" + incomplete = [t for t in todos if t.get("status") != "completed"] + incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) + return ( + "\n" + "You have incomplete todo items that must be finished before giving your final response:\n\n" + f"{incomplete_text}\n\n" + "Please continue working on these tasks. Call `write_todos` to mark items as completed " + "as you finish them, and only respond when all items are done.\n" + "" + ) + + +_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"} + + +def _has_tool_call_intent_or_error(message: AIMessage) -> bool: + """Return True when an AIMessage is not a clean final answer. + + Todo completion reminders should only fire when the model has produced a + plain final response. Provider/tool parsing details have moved across + LangChain versions and integrations, so keep all tool-intent/error signals + behind this helper instead of checking one concrete field at the call site. + """ + if message.tool_calls: + return True + + if getattr(message, "invalid_tool_calls", None): + return True + + # Backward/provider compatibility: some integrations preserve raw or legacy + # tool-call intent in additional_kwargs even when structured tool_calls is + # empty. If this helper changes, update the matching sentinel test + # `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`; + # if that test fails after a LangChain upgrade, review this helper so new + # tool-call/error fields are not silently treated as clean final answers. + additional_kwargs = getattr(message, "additional_kwargs", {}) or {} + if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"): + return True + + response_metadata = getattr(message, "response_metadata", {}) or {} + return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS + + class TodoMiddleware(TodoListMiddleware): """Extends TodoListMiddleware with `write_todos` context-loss detection. @@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware): formatted = _format_todos(todos) reminder = HumanMessage( name="todo_reminder", + additional_kwargs={"hide_from_ui": True}, content=( "\n" "Your todo list from earlier is no longer visible in the current context window, " @@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware): # Maximum number of completion reminders before allowing the agent to exit. # This prevents infinite loops when the agent cannot make further progress. _MAX_COMPLETION_REMINDERS = 2 + # Hard cap for per-run reminder bookkeeping in long-lived middleware instances. + _MAX_COMPLETION_REMINDER_KEYS = 4096 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._lock = threading.Lock() + self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {} + self._completion_reminder_counts: dict[tuple[str, str], int] = {} + self._completion_reminder_touch_order: dict[tuple[str, str], int] = {} + self._completion_reminder_next_order = 0 + + @staticmethod + def _get_thread_id(runtime: Runtime) -> str: + context = getattr(runtime, "context", None) + thread_id = context.get("thread_id") if context else None + return str(thread_id) if thread_id else "default" + + @staticmethod + def _get_run_id(runtime: Runtime) -> str: + context = getattr(runtime, "context", None) + run_id = context.get("run_id") if context else None + return str(run_id) if run_id else "default" + + def _pending_key(self, runtime: Runtime) -> tuple[str, str]: + return self._get_thread_id(runtime), self._get_run_id(runtime) + + def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None: + self._completion_reminder_next_order += 1 + self._completion_reminder_touch_order[key] = self._completion_reminder_next_order + + def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]: + keys = set(self._pending_completion_reminders) + keys.update(self._completion_reminder_counts) + keys.update(self._completion_reminder_touch_order) + return keys + + def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None: + self._pending_completion_reminders.pop(key, None) + self._completion_reminder_counts.pop(key, None) + self._completion_reminder_touch_order.pop(key, None) + + def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None: + keys = self._completion_reminder_keys_locked() + overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS + if overflow <= 0: + return + + candidates = [key for key in keys if key != protected_key] + candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0)) + for key in candidates[:overflow]: + self._drop_completion_reminder_key_locked(key) + + def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None: + key = self._pending_key(runtime) + with self._lock: + self._pending_completion_reminders.setdefault(key, []).append(reminder) + self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1 + self._touch_completion_reminder_key_locked(key) + self._prune_completion_reminder_state_locked(protected_key=key) + + def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int: + key = self._pending_key(runtime) + with self._lock: + return self._completion_reminder_counts.get(key, 0) + + def _drain_completion_reminders(self, runtime: Runtime) -> list[str]: + key = self._pending_key(runtime) + with self._lock: + reminders = self._pending_completion_reminders.pop(key, []) + if reminders or key in self._completion_reminder_counts: + self._touch_completion_reminder_key_locked(key) + return reminders + + def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None: + thread_id, current_run_id = self._pending_key(runtime) + with self._lock: + for key in self._completion_reminder_keys_locked(): + if key[0] == thread_id and key[1] != current_run_id: + self._drop_completion_reminder_key_locked(key) + + def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None: + key = self._pending_key(runtime) + with self._lock: + self._drop_completion_reminder_key_locked(key) + + @override + def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_other_run_completion_reminders(runtime) + return None + + @override + async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_other_run_completion_reminders(runtime) + return None @hook_config(can_jump_to=["model"]) @override @@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware): if base_result is not None: return base_result - # 2. Only intervene when the agent wants to exit (no tool calls). + # 2. Only intervene when the agent wants to exit cleanly. Tool-call + # intent or tool-call parse errors should be handled by the tool path + # instead of being masked by todo reminders. messages = state.get("messages") or [] last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None) - if not last_ai or last_ai.tool_calls: + if not last_ai or _has_tool_call_intent_or_error(last_ai): return None # 3. Allow exit when all todos are completed or there are no todos. @@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware): return None # 4. Enforce a reminder cap to prevent infinite re-engagement loops. - if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS: + if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS: return None - # 5. Inject a reminder and force the agent back to the model. - incomplete = [t for t in todos if t.get("status") != "completed"] - incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) - reminder = HumanMessage( - name="todo_completion_reminder", - content=( - "\n" - "You have incomplete todo items that must be finished before giving your final response:\n\n" - f"{incomplete_text}\n\n" - "Please continue working on these tasks. Call `write_todos` to mark items as completed " - "as you finish them, and only respond when all items are done.\n" - "" - ), - ) - return {"jump_to": "model", "messages": [reminder]} + # 5. Queue a reminder for the next model request and jump back. We must + # not persist this control prompt as a normal HumanMessage, otherwise it + # can leak into user-visible message streams and saved transcripts. + self._queue_completion_reminder(runtime, _format_completion_reminder(todos)) + return {"jump_to": "model"} @override @hook_config(can_jump_to=["model"]) @@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware): ) -> dict[str, Any] | None: """Async version of after_model.""" return self.after_model(state, runtime) + + @staticmethod + def _format_pending_completion_reminders(reminders: list[str]) -> str: + return "\n\n".join(dict.fromkeys(reminders)) + + def _augment_request(self, request: ModelRequest) -> ModelRequest: + reminders = self._drain_completion_reminders(request.runtime) + if not reminders: + return request + new_messages = [ + *request.messages, + HumanMessage( + content=self._format_pending_completion_reminders(reminders), + name="todo_completion_reminder", + additional_kwargs={"hide_from_ui": True}, + ), + ] + return request.override(messages=new_messages) + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(self._augment_request(request)) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(self._augment_request(request)) + + @override + def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_current_run_completion_reminders(runtime) + return None + + @override + async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_current_run_completion_reminders(runtime) + return None diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py index f59e7f2b7..0d3607faf 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py @@ -9,7 +9,7 @@ from typing import Any, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.todo import Todo -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str: return "thinking" +def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool: + """Return True if the AIMessage contains a tool_call with the given id.""" + for tc in message.tool_calls or []: + if isinstance(tc, dict): + if tc.get("id") == tool_call_id: + return True + elif hasattr(tc, "id") and tc.id == tool_call_id: + return True + return False + + def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]: tool_calls = getattr(message, "tool_calls", None) or [] actions: list[dict[str, Any]] = [] @@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware): if not messages: return None + # Annotate subagent token usage onto the AIMessage that dispatched it. + # When a task tool completes, its usage is cached by tool_call_id. Detect + # the ToolMessage → search backward for the corresponding AIMessage → merge. + # Walk backward through consecutive ToolMessages before the new AIMessage + # so that multiple concurrent task tool calls all get their subagent tokens + # written back to the same dispatch message (merging into one update). + state_updates: dict[int, AIMessage] = {} + if len(messages) >= 2: + from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage + + idx = len(messages) - 2 + while idx >= 0: + tool_msg = messages[idx] + if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id: + break + + subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id) + if subagent_usage: + # Search backward from the ToolMessage to find the AIMessage + # that dispatched it. A single model response can dispatch + # multiple task tool calls, so we can't assume a fixed offset. + dispatch_idx = idx - 1 + while dispatch_idx >= 0: + candidate = messages[dispatch_idx] + if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id): + # Accumulate into an existing update for the same + # AIMessage (multiple task calls in one response), + # or merge fresh from the original message. + existing_update = state_updates.get(dispatch_idx) + prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {}) + merged = { + **prev, + "input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"], + "output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"], + "total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"], + } + state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged}) + break + dispatch_idx -= 1 + idx -= 1 + last = messages[-1] if not isinstance(last, AIMessage): + if state_updates: + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} return None usage = getattr(last, "usage_metadata", None) @@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware): additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {}) if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution: - return None + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs}) - return {"messages": [updated_msg]} + state_updates[len(messages) - 1] = updated_msg + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: diff --git a/backend/packages/harness/deerflow/persistence/json_compat.py b/backend/packages/harness/deerflow/persistence/json_compat.py new file mode 100644 index 000000000..442b29e22 --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/json_compat.py @@ -0,0 +1,195 @@ +"""Dialect-aware JSON value matching for SQLAlchemy (SQLite + PostgreSQL).""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import BigInteger, Float, String, bindparam +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.expression import ColumnElement +from sqlalchemy.sql.visitors import InternalTraversal +from sqlalchemy.types import Boolean, TypeEngine + +# Key is interpolated into compiled SQL; restrict charset to prevent injection. +_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$") + +# Allowed value types for metadata filter values (same set accepted by JsonMatch). +ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str) + +# SQLite raises an overflow when binding values outside signed 64-bit range; +# PostgreSQL overflows during BIGINT cast. Reject at validation time instead. +_INT64_MIN = -(2**63) +_INT64_MAX = 2**63 - 1 + + +def validate_metadata_filter_key(key: object) -> bool: + """Return True if *key* is safe for use as a JSON metadata filter key. + + A key is "safe" when it is a string matching ``[A-Za-z0-9_-]+``. The + charset is restricted because the key is interpolated into the + compiled SQL path expression (``$.""`` / ``->`` literal), so any + laxer pattern would open a SQL/JSONPath injection surface. + """ + return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key)) + + +def validate_metadata_filter_value(value: object) -> bool: + """Return True if *value* is an allowed type for a JSON metadata filter. + + Matches the set of types ``_build_clause`` knows how to compile into + a dialect-portable predicate. Anything else (list/dict/bytes/...) is + intentionally rejected rather than silently coerced via ``str()`` — + silent coercion would (a) produce wrong matches and (b) break + SQLAlchemy's ``inherit_cache`` invariant when ``value`` is unhashable. + + Integer values are additionally restricted to the signed 64-bit range + ``[-2**63, 2**63 - 1]``: SQLite overflows when binding larger values + and PostgreSQL overflows during the ``BIGINT`` cast. + """ + if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES): + return False + if isinstance(value, int) and not isinstance(value, bool): + if not (_INT64_MIN <= value <= _INT64_MAX): + return False + return True + + +class JsonMatch(ColumnElement): + """Dialect-portable ``column[key] == value`` for JSON columns. + + Compiles to ``json_type``/``json_extract`` on SQLite and + ``json_typeof``/``->>`` on PostgreSQL, with type-safe comparison + that distinguishes bool vs int and NULL vs missing key. + + *key* must be a single literal key matching ``[A-Za-z0-9_-]+``. + *value* must be one of: ``None``, ``bool``, ``int`` (signed 64-bit), ``float``, ``str``. + """ + + inherit_cache = True + type = Boolean() + _is_implicitly_boolean = True + + _traverse_internals = [ + ("column", InternalTraversal.dp_clauseelement), + ("key", InternalTraversal.dp_string), + ("value", InternalTraversal.dp_plain_obj), + ] + + def __init__(self, column: ColumnElement, key: str, value: object) -> None: + if not validate_metadata_filter_key(key): + raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}") + if not validate_metadata_filter_value(value): + if isinstance(value, int) and not isinstance(value, bool): + raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}") + raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}") + self.column = column + self.key = key + self.value = value + super().__init__() + + +@dataclass(frozen=True) +class _Dialect: + """Per-dialect names used when emitting JSON type/value comparisons.""" + + null_type: str + num_types: tuple[str, ...] + num_cast: str + int_types: tuple[str, ...] + int_cast: str + # None for SQLite where json_type already returns 'integer'/'real'; + # regex literal for PostgreSQL where json_typeof returns 'number' for + # both ints and floats, so an extra guard prevents CAST errors on floats. + int_guard: str | None + string_type: str + bool_type: str | None + + +_SQLITE = _Dialect( + null_type="null", + num_types=("integer", "real"), + num_cast="REAL", + int_types=("integer",), + int_cast="INTEGER", + int_guard=None, + string_type="text", + bool_type=None, +) + +_PG = _Dialect( + null_type="null", + num_types=("number",), + num_cast="DOUBLE PRECISION", + int_types=("number",), + int_cast="BIGINT", + int_guard="'^-?[0-9]+$'", + string_type="string", + bool_type="boolean", +) + + +def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str: + param = bindparam(None, value, type_=sa_type) + return compiler.process(param, **kw) + + +def _type_check(typeof: str, types: tuple[str, ...]) -> str: + if len(types) == 1: + return f"{typeof} = '{types[0]}'" + quoted = ", ".join(f"'{t}'" for t in types) + return f"{typeof} IN ({quoted})" + + +def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str: + if value is None: + return f"{typeof} = '{dialect.null_type}'" + if isinstance(value, bool): + # bool check must precede int check — bool is a subclass of int in Python + bool_str = "true" if value else "false" + if dialect.bool_type is None: + return f"{typeof} = '{bool_str}'" + return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')" + if isinstance(value, int): + bp = _bind(compiler, value, BigInteger(), **kw) + if dialect.int_guard: + # CASE prevents CAST error when json_typeof = 'number' also matches floats + return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" + return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})" + if isinstance(value, float): + bp = _bind(compiler, value, Float(), **kw) + return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})" + bp = _bind(compiler, str(value), String(), **kw) + return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})" + + +@compiles(JsonMatch, "sqlite") +def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + path = f'$."{element.key}"' + typeof = f"json_type({col}, '{path}')" + extract = f"json_extract({col}, '{path}')" + return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw) + + +@compiles(JsonMatch, "postgresql") +def _compile_pg(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + typeof = f"json_typeof({col} -> '{element.key}')" + extract = f"({col} ->> '{element.key}')" + return _build_clause(compiler, typeof, extract, element.value, _PG, **kw) + + +@compiles(JsonMatch) +def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + raise NotImplementedError(f"JsonMatch supports only sqlite and postgresql; got dialect: {compiler.dialect.name}") + + +def json_match(column: ColumnElement, key: str, value: object) -> JsonMatch: + return JsonMatch(column, key, value) diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 430fbe4f6..5331451e3 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -223,10 +223,11 @@ class RunRepository(RunStore): """Aggregate token usage via a single SQL GROUP BY query.""" _completed = RunRow.status.in_(("success", "error")) _thread = RunRow.thread_id == thread_id + model_name = func.coalesce(RunRow.model_name, "unknown") stmt = ( select( - func.coalesce(RunRow.model_name, "unknown").label("model"), + model_name.label("model"), func.count().label("runs"), func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"), func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"), @@ -236,7 +237,7 @@ class RunRepository(RunStore): func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"), ) .where(_thread, _completed) - .group_by(func.coalesce(RunRow.model_name, "unknown")) + .group_by(model_name) ) async with self._sf() as session: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py index 080ce8093..b5231f0f9 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.sql import ThreadMetaRepository @@ -14,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker __all__ = [ + "InvalidMetadataFilterError", "MemoryThreadMetaStore", "ThreadMetaRepository", "ThreadMetaRow", diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index c87c10a16..ed55ade8e 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -15,10 +15,15 @@ three-state semantics (see :mod:`deerflow.runtime.user_context`): from __future__ import annotations import abc +from typing import Any from deerflow.runtime.user_context import AUTO, _AutoSentinel +class InvalidMetadataFilterError(ValueError): + """Raised when all client-supplied metadata filter keys are rejected.""" + + class ThreadMetaStore(abc.ABC): @abc.abstractmethod async def create( @@ -40,12 +45,12 @@ class ThreadMetaStore(abc.ABC): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: pass @abc.abstractmethod diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index fbe66fdaf..4f642a938 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -69,12 +69,12 @@ class MemoryThreadMetaStore(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") filter_dict: dict[str, Any] = {} if metadata: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 688fbb247..0d3f587de 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -2,16 +2,20 @@ from __future__ import annotations +import logging from datetime import UTC, datetime from typing import Any from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.json_compat import json_match +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +logger = logging.getLogger(__name__) + class ThreadMetaRepository(ThreadMetaStore): def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: @@ -20,7 +24,7 @@ class ThreadMetaRepository(ThreadMetaStore): @staticmethod def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]: d = row.to_dict() - d["metadata"] = d.pop("metadata_json", {}) + d["metadata"] = d.pop("metadata_json", None) or {} for key in ("created_at", "updated_at"): val = d.get(key) if isinstance(val, datetime): @@ -104,39 +108,43 @@ class ThreadMetaRepository(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: """Search threads with optional metadata and status filters. Owner filter is enforced by default: caller must be in a user context. Pass ``user_id=None`` to bypass (migration/CLI). """ resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search") - stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) + stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc(), ThreadMetaRow.thread_id.desc()) if resolved_user_id is not None: stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id) if status: stmt = stmt.where(ThreadMetaRow.status == status) if metadata: - # When metadata filter is active, fetch a larger window and filter - # in Python. TODO(Phase 2): use JSON DB operators (Postgres @>, - # SQLite json_extract) for server-side filtering. - stmt = stmt.limit(limit * 5 + offset) - async with self._sf() as session: - result = await session.execute(stmt) - rows = [self._row_to_dict(r) for r in result.scalars()] - rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())] - return rows[offset : offset + limit] - else: - stmt = stmt.limit(limit).offset(offset) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] + applied = 0 + for key, value in metadata.items(): + try: + stmt = stmt.where(json_match(ThreadMetaRow.metadata_json, key, value)) + applied += 1 + except (ValueError, TypeError) as exc: + logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc) + if applied == 0: + # Comma-separated plain string (no list repr / nested + # quoting) so the 400 detail surfaced by the Gateway is + # easy for clients to read. Sorted for determinism. + rejected_keys = ", ".join(sorted(str(k) for k in metadata)) + raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}") + + stmt = stmt.limit(limit).offset(offset) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool: """Return True if the row exists and is owned (or filter bypassed).""" diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 9374769f3..b7e54754f 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -11,7 +11,7 @@ import logging from datetime import UTC, datetime from typing import Any -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow @@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore): user = get_current_user() return str(user.id) if user is not None else None + @staticmethod + async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None: + """Return the current max seq while serializing writers per thread. + + PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate + results are not lockable rows. As a release-safe workaround, take a + transaction-level advisory lock keyed by thread_id before reading the + aggregate. Other dialects keep the existing row-locking statement. + """ + stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id) + bind = session.get_bind() + dialect_name = bind.dialect.name if bind is not None else "" + + if dialect_name == "postgresql": + await session.execute( + text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"), + {"thread_id": thread_id}, + ) + return await session.scalar(stmt) + + return await session.scalar(stmt.with_for_update()) + async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 """Write a single event — low-frequency path only. @@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore): user_id = self._user_id_from_context() async with self._sf() as session: async with session.begin(): - # Use FOR UPDATE to serialize seq assignment within a thread. - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = (max_seq or 0) + 1 row = RunEventRow( thread_id=thread_id, @@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore): async with self._sf() as session: async with session.begin(): # Get max seq for the thread (assume all events in batch belong to same thread). - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. thread_id = events[0]["thread_id"] - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = max_seq or 0 rows = [] for e in events: diff --git a/backend/packages/harness/deerflow/runtime/user_context.py b/backend/packages/harness/deerflow/runtime/user_context.py index ffe4be690..cfbb68c94 100644 --- a/backend/packages/harness/deerflow/runtime/user_context.py +++ b/backend/packages/harness/deerflow/runtime/user_context.py @@ -109,6 +109,34 @@ def get_effective_user_id() -> str: return str(user.id) +def resolve_runtime_user_id(runtime: object | None) -> str: + """Single source of truth for a tool/middleware's effective user_id. + + Resolution order (most authoritative first): + 1. ``runtime.context["user_id"]`` — set by ``inject_authenticated_user_context`` + in the gateway from the auth-validated ``request.state.user``. This is + the only source that survives boundaries where the contextvar may have + been lost (background tasks scheduled outside the request task, + worker pools that don't copy_context, future cross-process drivers). + 2. The ``_current_user`` ContextVar — set by the auth middleware at + request entry. Reliable for in-task work; copied by ``asyncio`` + child tasks and by ``ContextThreadPoolExecutor``. + 3. ``DEFAULT_USER_ID`` — last-resort fallback so unauthenticated + CLI / migration / test paths keep working without raising. + + Tools that persist user-scoped state (custom agents, memory, uploads) + MUST call this instead of ``get_effective_user_id()`` directly so they + benefit from the runtime.context channel that ``setup_agent`` already + relies on. + """ + context = getattr(runtime, "context", None) + if isinstance(context, dict): + ctx_user_id = context.get("user_id") + if ctx_user_id: + return str(ctx_user_id) + return get_effective_user_id() + + # --------------------------------------------------------------------------- # Sentinel-based user_id resolution # --------------------------------------------------------------------------- diff --git a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py index 2f796b005..dfbcf8b6e 100644 --- a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py @@ -7,19 +7,12 @@ from langgraph.types import Command from deerflow.config.agents_config import validate_agent_name from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.user_context import resolve_runtime_user_id from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) -def _get_runtime_user_id(runtime: Runtime) -> str: - context_user_id = runtime.context.get("user_id") if runtime.context else None - if context_user_id: - return str(context_user_id) - return get_effective_user_id() - - @tool(parse_docstring=True) def setup_agent( soul: str, @@ -45,7 +38,7 @@ def setup_agent( if agent_name: # Custom agents are persisted under the current user's bucket so # different users do not see each other's agents. - user_id = _get_runtime_user_id(runtime) + user_id = resolve_runtime_user_id(runtime) agent_dir = paths.user_agent_dir(user_id, agent_name) else: # Default agent (no agent_name): SOUL.md lives at the global base dir. diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 861c45b45..cf9281ff4 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -26,6 +26,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can +# write it back to the triggering AIMessage's usage_metadata. +_subagent_usage_cache: dict[str, dict[str, int]] = {} + + +def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool: + if app_config is None: + try: + app_config = get_app_config() + except FileNotFoundError: + return False + return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False)) + + +def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None: + if enabled and usage: + _subagent_usage_cache[tool_call_id] = usage + + +def pop_cached_subagent_usage(tool_call_id: str) -> dict | None: + return _subagent_usage_cache.pop(tool_call_id, None) + def _is_subagent_terminal(result: Any) -> bool: """Return whether a background subagent result is safe to clean up.""" @@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None: return None +def _summarize_usage(records: list[dict] | None) -> dict | None: + """Summarize token usage records into a compact dict for SSE events.""" + if not records: + return None + return { + "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records), + "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records), + "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records), + } + + def _report_subagent_usage(runtime: Any, result: Any) -> None: """Report subagent token usage to the parent RunJournal, if available. @@ -177,6 +210,7 @@ async def task_tool( subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. """ runtime_app_config = _get_runtime_app_config(runtime) + cache_token_usage = _token_usage_cache_enabled(runtime_app_config) available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names() # Get subagent configuration @@ -312,27 +346,32 @@ async def task_tool( last_message_count = current_message_count # Check if task completed, failed, or timed out + usage = _summarize_usage(getattr(result, "token_usage_records", None)) if result.status == SubagentStatus.COMPLETED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_completed", "task_id": task_id, "result": result.result}) + writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") cleanup_background_task(task_id) return f"Task Succeeded. Result: {result.result}" elif result.status == SubagentStatus.FAILED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_failed", "task_id": task_id, "error": result.error}) + writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage}) logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") cleanup_background_task(task_id) return f"Task failed. Error: {result.error}" elif result.status == SubagentStatus.CANCELLED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_cancelled", "task_id": task_id, "error": result.error}) + writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") cleanup_background_task(task_id) return "Task cancelled by user." elif result.status == SubagentStatus.TIMED_OUT: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) + writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage}) logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") cleanup_background_task(task_id) return f"Task timed out. Error: {result.error}" @@ -351,7 +390,9 @@ async def task_tool( timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id}) + usage = _summarize_usage(getattr(result, "token_usage_records", None)) + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) + writer({"type": "task_timed_out", "task_id": task_id, "usage": usage}) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. @@ -374,4 +415,8 @@ async def task_tool( cleanup_background_task(task_id) else: _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) + _subagent_usage_cache.pop(tool_call_id, None) + raise + except Exception: + _subagent_usage_cache.pop(tool_call_id, None) raise diff --git a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py index b2dc8ca72..18500a248 100644 --- a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py @@ -27,7 +27,7 @@ from langgraph.types import Command from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.app_config import get_app_config from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.user_context import resolve_runtime_user_id from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -118,9 +118,13 @@ def update_agent( return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.") # Resolve the active user so that updates only affect this user's agent. - # ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context - # is set (matching how memory and thread storage behave). - user_id = get_effective_user_id() + # ``resolve_runtime_user_id`` prefers ``runtime.context["user_id"]`` (set by + # the gateway from the auth-validated request) and falls back to the + # contextvar, then DEFAULT_USER_ID. This matches setup_agent so a user + # creating an agent and later refining it always touches the same files, + # even if the contextvar gets lost across an async/thread boundary + # (issue #2782 / #2862 class of bugs). + user_id = resolve_runtime_user_id(runtime) # Reject an unknown ``model`` *before* touching the filesystem. Otherwise # ``_resolve_model_name`` silently falls back to the default at runtime diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 01bfce43f..5c97962fc 100644 --- a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig from deerflow.reflection import resolve_variable from deerflow.sandbox.security import is_host_bash_allowed from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool -from deerflow.tools.builtins.tool_search import reset_deferred_registry +from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.sync import make_sync_tool_wrapper logger = logging.getLogger(__name__) @@ -116,8 +116,6 @@ def get_available_tools( # made through the Gateway API (which runs in a separate process) are immediately # reflected when loading MCP tools. mcp_tools = [] - # Reset deferred registry upfront to prevent stale state from previous calls - reset_deferred_registry() if include_mcp: try: from deerflow.config.extensions_config import ExtensionsConfig @@ -135,12 +133,51 @@ def get_available_tools( from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool - registry = DeferredToolRegistry() - for t in mcp_tools: - registry.register(t) - set_deferred_registry(registry) + # Reuse the existing registry if one is already set for + # this async context. ``get_available_tools`` is + # re-entered whenever a subagent is spawned + # (``task_tool`` calls it to build the child agent's + # toolset), and previously we used to unconditionally + # rebuild the registry — wiping out the parent agent's + # tool_search promotions. The + # ``DeferredToolFilterMiddleware`` then re-hid those + # tools from subsequent model calls, leaving the agent + # able to see a tool's name but unable to invoke it + # (issue #2884). ``contextvars`` already gives us the + # lifetime semantics we want: a fresh request / graph + # run starts in a new asyncio task with the + # ContextVar at its default of ``None``, so reuse is + # only triggered for re-entrant calls inside one run. + # + # Intentionally NOT reconciling against the current + # ``mcp_tools`` snapshot. The MCP cache only refreshes + # on ``extensions_config.json`` mtime changes, which + # in practice happens between graph runs — not inside + # one. And even if a refresh did happen mid-run, the + # already-built lead agent's ``ToolNode`` still holds + # the *previous* tool set (LangGraph binds tools at + # graph construction time), so a brand-new MCP tool + # couldn't actually be invoked anyway. The + # ``DeferredToolRegistry`` doesn't retain the names + # of previously-promoted tools (``promote()`` drops + # the entry entirely), so re-syncing the registry + # against a fresh ``mcp_tools`` list would + # mis-classify those promotions as new tools and + # re-register them as deferred — exactly the bug + # this fix exists to prevent. + existing_registry = get_deferred_registry() + if existing_registry is None: + registry = DeferredToolRegistry() + for t in mcp_tools: + registry.register(t) + set_deferred_registry(registry) + logger.info(f"Tool search active: {len(mcp_tools)} tools deferred") + else: + mcp_tool_names = {t.name for t in mcp_tools} + still_deferred = len(existing_registry) + promoted_count = max(0, len(mcp_tool_names) - still_deferred) + logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted") builtin_tools.append(tool_search_tool) - logger.info(f"Tool search active: {len(mcp_tools)} tools deferred") except ImportError: logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.") except Exception as e: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6d2edb0bb..082c3d07d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ [project.optional-dependencies] postgres = ["deerflow-harness[postgres]"] +discord = ["discord.py>=2.7.0"] [dependency-groups] dev = [ diff --git a/backend/tests/_agent_e2e_helpers.py b/backend/tests/_agent_e2e_helpers.py new file mode 100644 index 000000000..2f28390a9 --- /dev/null +++ b/backend/tests/_agent_e2e_helpers.py @@ -0,0 +1,68 @@ +"""Shared helpers for user-isolation e2e tests on the custom-agent tooling. + +Centralises the small fake-LLM shim and a few test-data builders that the +three e2e files in this PR (``test_setup_agent_e2e_user_isolation``, +``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``) +all need. The shim is what lets a real ``langchain.agents.create_agent`` +graph run without an API key — every other layer in those tests is real +production code, which is the entire point of the test design. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage +from langchain_core.runnables import Runnable + + +class FakeToolCallingModel(FakeMessagesListChatModel): + """FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent. + + ``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to + expose the tool schemas to the model; the upstream fake raises + ``NotImplementedError`` there. We just return ``self`` because we + drive deterministic tool_call output via ``responses=...``, no schema + handling needed. + """ + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + +def build_single_tool_call_model( + *, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str = "call_e2e_1", + final_text: str = "done", +) -> FakeToolCallingModel: + """Build a fake model that emits exactly one tool_call then finishes. + + Two-turn behaviour, identical across our e2e tests: + turn 1 → AIMessage with a single tool_call for *tool_name* + turn 2 → AIMessage with *final_text* (terminates the agent loop) + """ + return FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": tool_args, + "id": tool_call_id, + "type": "tool_call", + } + ], + ), + AIMessage(content=final_text), + ] + ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index a357a3962..9bc8d4884 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import issues when unit-testing lightweight config/registry code in isolation. """ +from __future__ import annotations + import importlib.util import sys from pathlib import Path @@ -11,11 +13,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io # Make 'app' and 'deerflow' importable from any working directory sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts")) +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT) +_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector" + # Break the circular import chain that exists in production code: # deerflow.subagents.__init__ # -> .executor (SubagentExecutor, SubagentResult) @@ -56,6 +63,92 @@ def provisioner_module(): return module +@pytest.fixture() +def blocking_io_detector(): + """Fail a focused test if blocking calls run on the event loop thread.""" + with detect_blocking_io(fail_on_exit=True) as detector: + yield detector + + +def pytest_addoption(parser: pytest.Parser) -> None: + group = parser.getgroup("blocking-io") + group.addoption( + "--detect-blocking-io", + action="store_true", + default=False, + help="Collect blocking calls made while an asyncio event loop is running and report a summary.", + ) + group.addoption( + "--detect-blocking-io-fail", + action="store_true", + default=False, + help="Set a failing exit status when --detect-blocking-io records violations.", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe") + + +def pytest_sessionstart(session: pytest.Session) -> None: + if _blocking_io_probe_enabled(session.config): + _blocking_io_probe.clear() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_call(item: pytest.Item): + if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item): + yield + return + + detector = detect_blocking_io(fail_on_exit=False, stack_limit=18) + detector.__enter__() + setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector) + yield + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item: pytest.Item): + yield + + detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None) + if detector is None: + return + + try: + detector.__exit__(None, None, None) + _blocking_io_probe.record(item.nodeid, detector.violations) + finally: + delattr(item, _BLOCKING_IO_DETECTOR_ATTR) + + +def pytest_sessionfinish(session: pytest.Session) -> None: + if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK: + session.exitstatus = pytest.ExitCode.TESTS_FAILED + + +def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None: + if not _blocking_io_probe_enabled(terminalreporter.config): + return + + header, *details = _blocking_io_probe.format_summary().splitlines() + terminalreporter.write_sep("=", header) + for line in details: + terminalreporter.write_line(line) + + +def _blocking_io_probe_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_fail_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_probe_skipped(item: pytest.Item) -> bool: + return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None + + # --------------------------------------------------------------------------- # Auto-set user context for every test unless marked no_auto_user # --------------------------------------------------------------------------- diff --git a/backend/tests/support/__init__.py b/backend/tests/support/__init__.py new file mode 100644 index 000000000..38361eaf5 --- /dev/null +++ b/backend/tests/support/__init__.py @@ -0,0 +1 @@ +"""Shared test support helpers.""" diff --git a/backend/tests/support/detectors/__init__.py b/backend/tests/support/detectors/__init__.py new file mode 100644 index 000000000..cf9568cb6 --- /dev/null +++ b/backend/tests/support/detectors/__init__.py @@ -0,0 +1 @@ +"""Runtime and static detectors used by tests.""" diff --git a/backend/tests/support/detectors/blocking_io.py b/backend/tests/support/detectors/blocking_io.py new file mode 100644 index 000000000..c1adfd55a --- /dev/null +++ b/backend/tests/support/detectors/blocking_io.py @@ -0,0 +1,287 @@ +"""Test helper for detecting blocking calls on an asyncio event loop. + +The detector is intentionally test-only. It monkeypatches a small set of +well-known blocking entry points and their already-loaded module-level aliases, +then records calls only when they happen on a thread that is currently running +an asyncio event loop. Aliases captured in closures or default arguments remain +out of scope. +""" + +from __future__ import annotations + +import asyncio +import importlib +import sys +import traceback +from collections import Counter +from collections.abc import Callable, Iterable, Iterator +from contextlib import AbstractContextManager +from dataclasses import dataclass +from functools import wraps +from pathlib import Path +from types import TracebackType +from typing import Any + +BlockingCallable = Callable[..., Any] + + +@dataclass(frozen=True) +class BlockingCallSpec: + """Describes one blocking callable to wrap during a detector run.""" + + name: str + target: str + record_on_iteration: bool = False + + +@dataclass(frozen=True) +class BlockingCall: + """One blocking call observed on an asyncio event loop thread.""" + + name: str + target: str + stack: tuple[traceback.FrameSummary, ...] + + +DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = ( + BlockingCallSpec("time.sleep", "time:sleep"), + BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"), + BlockingCallSpec("httpx.Client.request", "httpx:Client.request"), + BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True), + BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"), + BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"), + BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"), +) + + +def _is_event_loop_thread() -> bool: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return False + return loop.is_running() + + +def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]: + module_name, attr_path = target.split(":", maxsplit=1) + owner: object = importlib.import_module(module_name) + parts = attr_path.split(".") + for part in parts[:-1]: + owner = getattr(owner, part) + + attr_name = parts[-1] + original = getattr(owner, attr_name) + return owner, attr_name, original + + +def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]: + return tuple(frame for frame in stack if frame.filename != __file__) + + +class BlockingIODetector(AbstractContextManager["BlockingIODetector"]): + """Record blocking calls made from async runtime code. + + By default the detector reports violations but does not fail on context + exit. Tests can set ``fail_on_exit=True`` or call + ``assert_no_blocking_calls()`` explicitly. + """ + + def __init__( + self, + specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS, + *, + fail_on_exit: bool = False, + patch_loaded_aliases: bool = True, + stack_limit: int = 12, + ) -> None: + self._specs = tuple(specs) + self._fail_on_exit = fail_on_exit + self._patch_loaded_aliases_enabled = patch_loaded_aliases + self._stack_limit = stack_limit + self._patches: list[tuple[object, str, BlockingCallable]] = [] + self._patch_keys: set[tuple[int, str]] = set() + self.violations: list[BlockingCall] = [] + self._active = False + + def __enter__(self) -> BlockingIODetector: + try: + self._active = True + alias_replacements: dict[int, BlockingCallable] = {} + for spec in self._specs: + owner, attr_name, original = _resolve_target(spec.target) + wrapper = self._wrap(spec, original) + self._patch_attribute(owner, attr_name, original, wrapper) + alias_replacements[id(original)] = wrapper + + if self._patch_loaded_aliases_enabled: + self._patch_loaded_module_aliases(alias_replacements) + except Exception: + self._restore() + self._active = False + raise + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback_value: TracebackType | None, + ) -> bool | None: + self._restore() + self._active = False + if exc_type is None and self._fail_on_exit: + self.assert_no_blocking_calls() + return None + + def _restore(self) -> None: + for owner, attr_name, original in reversed(self._patches): + setattr(owner, attr_name, original) + self._patches.clear() + self._patch_keys.clear() + + def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None: + key = (id(owner), attr_name) + if key in self._patch_keys: + return + setattr(owner, attr_name, replacement) + self._patches.append((owner, attr_name, original)) + self._patch_keys.add(key) + + def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None: + for module in tuple(sys.modules.values()): + namespace = getattr(module, "__dict__", None) + if not isinstance(namespace, dict): + continue + + for attr_name, value in tuple(namespace.items()): + replacement = replacements_by_id.get(id(value)) + if replacement is not None: + self._patch_attribute(module, attr_name, value, replacement) + + def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable: + @wraps(original) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if spec.record_on_iteration: + result = original(*args, **kwargs) + return self._wrap_iteration(spec, result) + self._record_if_blocking(spec) + return original(*args, **kwargs) + + return wrapper + + def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]: + iterator = iter(iterable) + reported = False + + while True: + if not reported: + reported = self._record_if_blocking(spec) + try: + yield next(iterator) + except StopIteration: + return + + def _record_if_blocking(self, spec: BlockingCallSpec) -> bool: + if self._active and _is_event_loop_thread(): + stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit)) + self.violations.append(BlockingCall(spec.name, spec.target, stack)) + return True + return False + + def assert_no_blocking_calls(self) -> None: + if self.violations: + raise AssertionError(format_blocking_calls(self.violations)) + + +class BlockingIOProbe: + """Collect detector output across tests and format a compact summary.""" + + def __init__(self, project_root: Path) -> None: + self._project_root = project_root.resolve() + self._observed: list[tuple[str, BlockingCall]] = [] + + @property + def violation_count(self) -> int: + return len(self._observed) + + @property + def test_count(self) -> int: + return len({nodeid for nodeid, _violation in self._observed}) + + def clear(self) -> None: + self._observed.clear() + + def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None: + for violation in violations: + self._observed.append((nodeid, violation)) + + def format_summary(self, *, limit: int = 30) -> str: + if not self._observed: + return "blocking io probe: no violations" + + call_sites: Counter[tuple[str, str, int, str, str]] = Counter() + for _nodeid, violation in self._observed: + frame = self._local_call_site(violation.stack) + if frame is None: + call_sites[(violation.name, "", 0, "", "")] += 1 + continue + + call_sites[ + ( + violation.name, + self._relative(frame.filename), + frame.lineno, + frame.name, + (frame.line or "").strip(), + ) + ] += 1 + + lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"] + for (name, filename, lineno, function, line), count in call_sites.most_common(limit): + lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}") + return "\n".join(lines) + + def _relative(self, filename: str) -> str: + try: + return str(Path(filename).resolve().relative_to(self._project_root)) + except ValueError: + return filename + + def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None: + local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")] + if local_frames: + return local_frames[-1] + + test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename] + return test_frames[-1] if test_frames else None + + +def detect_blocking_io( + specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS, + *, + fail_on_exit: bool = False, + patch_loaded_aliases: bool = True, + stack_limit: int = 12, +) -> BlockingIODetector: + """Create a detector context manager for a focused test scope.""" + + return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit) + + +def format_blocking_calls(violations: Iterable[BlockingCall]) -> str: + """Format detector output with enough stack context to locate call sites.""" + + lines = ["Blocking calls were executed on an asyncio event loop thread:"] + for index, violation in enumerate(violations, start=1): + lines.append(f"{index}. {violation.name} ({violation.target})") + lines.extend(_format_stack(violation.stack)) + return "\n".join(lines) + + +def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]: + for frame in stack: + location = f"{frame.filename}:{frame.lineno}" + lines = [f" at {frame.name} ({location})"] + if frame.line: + lines.append(f" {frame.line.strip()}") + yield from lines diff --git a/backend/tests/test_artifacts_router.py b/backend/tests/test_artifacts_router.py index df32e45dc..f0627ff7b 100644 --- a/backend/tests/test_artifacts_router.py +++ b/backend/tests/test_artifacts_router.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest from _router_auth_helpers import call_unwrapped, make_authed_test_app +from fastapi import HTTPException from fastapi.testclient import TestClient from starlette.requests import Request from starlette.responses import FileResponse @@ -102,3 +103,17 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path assert response.status_code == 200 assert response.text == "hello" assert response.headers.get("content-disposition", "").startswith("attachment;") + + +def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None: + skill_path = tmp_path / "sample.skill" + payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1) + with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref: + zip_ref.writestr("SKILL.md", payload) + + assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + + with pytest.raises(HTTPException) as exc_info: + artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md") + + assert exc_info.value.status_code == 413 diff --git a/backend/tests/test_auth_config.py b/backend/tests/test_auth_config.py index 21b8bd81b..61d1d7d2e 100644 --- a/backend/tests/test_auth_config.py +++ b/backend/tests/test_auth_config.py @@ -5,28 +5,26 @@ from unittest.mock import patch import pytest -from app.gateway.auth.config import AuthConfig +import app.gateway.auth.config as cfg def test_auth_config_defaults(): - config = AuthConfig(jwt_secret="test-secret-key-123") + config = cfg.AuthConfig(jwt_secret="test-secret-key-123") assert config.token_expiry_days == 7 def test_auth_config_token_expiry_range(): - AuthConfig(jwt_secret="s", token_expiry_days=1) - AuthConfig(jwt_secret="s", token_expiry_days=30) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=1) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=30) with pytest.raises(Exception): - AuthConfig(jwt_secret="s", token_expiry_days=0) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=0) with pytest.raises(Exception): - AuthConfig(jwt_secret="s", token_expiry_days=31) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=31) def test_auth_config_from_env(): env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"} with patch.dict(os.environ, env, clear=False): - import app.gateway.auth.config as cfg - old = cfg._auth_config cfg._auth_config = None try: @@ -36,19 +34,57 @@ def test_auth_config_from_env(): cfg._auth_config = old -def test_auth_config_missing_secret_generates_ephemeral(caplog): +def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog): import logging - import app.gateway.auth.config as cfg + from deerflow.config.paths import Paths old = cfg._auth_config cfg._auth_config = None + secret_file = tmp_path / ".jwt_secret" try: with patch.dict(os.environ, {}, clear=True): os.environ.pop("AUTH_JWT_SECRET", None) - with caplog.at_level(logging.WARNING): + with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING): config = cfg.get_auth_config() assert config.jwt_secret assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) + assert secret_file.exists() + assert secret_file.read_text().strip() == config.jwt_secret + finally: + cfg._auth_config = old + + +def test_auth_config_reuses_persisted_secret(tmp_path): + from deerflow.config.paths import Paths + + old = cfg._auth_config + cfg._auth_config = None + persisted = "persisted-secret-from-file-min-32-chars!!" + (tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8") + try: + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("AUTH_JWT_SECRET", None) + with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)): + config = cfg.get_auth_config() + assert config.jwt_secret == persisted + finally: + cfg._auth_config = old + + +def test_auth_config_empty_secret_file_generates_new(tmp_path): + from deerflow.config.paths import Paths + + old = cfg._auth_config + cfg._auth_config = None + (tmp_path / ".jwt_secret").write_text("", encoding="utf-8") + try: + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("AUTH_JWT_SECRET", None) + with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)): + config = cfg.get_auth_config() + assert config.jwt_secret + assert len(config.jwt_secret) > 20 + assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret finally: cfg._auth_config = old diff --git a/backend/tests/test_blocking_io_detector.py b/backend/tests/test_blocking_io_detector.py new file mode 100644 index 000000000..af44d746d --- /dev/null +++ b/backend/tests/test_blocking_io_detector.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio +import os +import time +from os import walk as imported_walk +from pathlib import Path +from time import sleep as imported_sleep + +import httpx +import pytest +import requests +from support.detectors.blocking_io import ( + BlockingCallSpec, + BlockingIOProbe, + detect_blocking_io, +) + +pytestmark = pytest.mark.asyncio + + +TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),) +REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),) +HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),) +OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),) +PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),) + + +async def test_records_time_sleep_on_event_loop() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_records_already_imported_sleep_alias_on_event_loop() -> None: + original_alias = imported_sleep + + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + imported_sleep(0) + + assert imported_sleep is original_alias + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_can_disable_loaded_alias_patching() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector: + imported_sleep(0) + + assert detector.violations == [] + + +async def test_does_not_record_time_sleep_offloaded_to_thread() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + await asyncio.to_thread(time.sleep, 0) + + assert detector.violations == [] + + +async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None: + await asyncio.to_thread(time.sleep, 0) + + assert blocking_io_detector.violations == [] + + +async def test_does_not_record_sync_call_without_running_event_loop() -> None: + def call_sleep() -> list[str]: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + return [violation.name for violation in detector.violations] + + assert await asyncio.to_thread(call_sleep) == [] + + +async def test_fail_on_exit_includes_call_site() -> None: + with pytest.raises(AssertionError) as exc_info: + with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True): + time.sleep(0) + + message = str(exc_info.value) + assert "time.sleep" in message + assert "test_fail_on_exit_includes_call_site" in message + + +async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str: + return f"{method}:{url}" + + monkeypatch.setattr(requests.sessions.Session, "request", fake_request) + + with detect_blocking_io(REQUESTS_ONLY) as detector: + assert requests.get("https://example.invalid") == "get:https://example.invalid" + + assert [violation.name for violation in detector.violations] == ["requests.Session.request"] + + +async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response: + return httpx.Response(200, request=httpx.Request(method, url)) + + monkeypatch.setattr(httpx.Client, "request", fake_request) + + with detect_blocking_io(HTTPX_ONLY) as detector: + with httpx.Client() as client: + response = client.get("https://example.invalid") + + assert response.status_code == 200 + assert [violation.name for violation in detector.violations] == ["httpx.Client.request"] + + +async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(os.walk(tmp_path)) + + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + original_alias = imported_walk + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(imported_walk(tmp_path)) + + assert imported_walk is original_alias + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None: + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + + assert list(walker) + assert detector.violations == [] + + +async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + assert await asyncio.to_thread(lambda: list(walker)) + + assert detector.violations == [] + + +async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None: + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector: + assert path.read_text(encoding="utf-8") == "content" + + assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"] + + +async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + summary = probe.format_summary() + + assert "blocking io probe: 1 violations across 1 tests" in summary + assert "pathlib.Path.read_text" in summary + + +async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + + assert probe.format_summary() == "blocking io probe: no violations" + + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + assert probe.violation_count == 1 + + probe.clear() + + assert probe.violation_count == 0 + assert probe.format_summary() == "blocking io probe: no violations" diff --git a/backend/tests/test_blocking_io_probe_integration.py b/backend/tests/test_blocking_io_probe_integration.py new file mode 100644 index 000000000..af7a31b9d --- /dev/null +++ b/backend/tests/test_blocking_io_probe_integration.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import time + +import pytest + +ORIGINAL_SLEEP = time.sleep + + +def replacement_sleep(seconds: float) -> None: + return None + + +def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(time, "sleep", replacement_sleep) + assert time.sleep is replacement_sleep + + +@pytest.mark.no_blocking_io_probe +def test_probe_restores_original_after_monkeypatch_teardown() -> None: + assert time.sleep is ORIGINAL_SLEEP + assert getattr(time.sleep, "__wrapped__", None) is None diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index d68701c4e..f85062a17 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -761,7 +761,7 @@ class TestChannelManager: history_by_checkpoint: dict[tuple[str, str], list[str]] = {} - async def _runs_wait(thread_id, assistant_id, *, input, config, context): + async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None): del assistant_id, context # unused in this test, kept for signature parity checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns") diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index b1d5c476a..f9f47369d 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching: assert patched[1].name == "bash" assert patched[1].status == "error" + def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self): + middleware = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + HumanMessage(content="interruption"), + _tool_msg("call_1", "bash"), + ] + patched = middleware._build_patched_messages(msgs) + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert isinstance(patched[2], HumanMessage) + + def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]), + HumanMessage(content="interruption"), + _tool_msg("call_2", "read"), + _tool_msg("call_1", "bash"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert isinstance(patched[2], ToolMessage) + assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"] + assert isinstance(patched[3], HumanMessage) + + def test_valid_adjacent_tool_results_are_unchanged(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + _tool_msg("call_1", "bash"), + HumanMessage(content="next"), + ] + + assert mw._build_patched_messages(msgs) is None + + def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + HumanMessage(content="interruption"), + _ai_with_tool_calls([_tc("read", "call_2")]), + _tool_msg("call_1", "bash"), + _tool_msg("call_2", "read"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert isinstance(patched[2], HumanMessage) + assert isinstance(patched[3], AIMessage) + assert isinstance(patched[4], ToolMessage) + assert patched[4].tool_call_id == "call_2" + + def test_orphan_tool_message_is_preserved_during_grouping(self): + mw = DanglingToolCallMiddleware() + orphan = _tool_msg("orphan_call", "orphan") + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + orphan, + HumanMessage(content="interruption"), + _tool_msg("call_1", "bash"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert orphan in patched + assert patched.count(orphan) == 1 + def test_invalid_tool_call_is_patched(self): mw = DanglingToolCallMiddleware() msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])] diff --git a/backend/tests/test_deferred_tool_promotion_real_llm.py b/backend/tests/test_deferred_tool_promotion_real_llm.py new file mode 100644 index 000000000..46ae24d41 --- /dev/null +++ b/backend/tests/test_deferred_tool_promotion_real_llm.py @@ -0,0 +1,222 @@ +"""Real-LLM end-to-end verification for issue #2884. + +Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI- +compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware`` +and the production ``get_available_tools`` pipeline. The only thing we mock is +the MCP tool source — we hand-roll two ``@tool``s and inject them through +``deerflow.mcp.cache.get_cached_mcp_tools``. + +The flow exercised: + 1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger`` + that re-enters ``get_available_tools`` on the same task — this is the + code path issue #2884 reports). It must call ``tool_search`` to + discover the deferred ``fake_calculator`` tool. + 2. Tool batch: ``tool_search`` promotes ``fake_calculator``; + ``fake_subagent_trigger`` re-enters ``get_available_tools``. + 3. Turn 2: the promoted ``fake_calculator`` schema must reach the model + so it can actually call it. Without this PR's fix, the re-entry wipes + the promotion and the model can no longer invoke the tool. + +Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every +test run. Run with:: + + ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \ + PYTHONPATH=. uv run pytest \ + tests/test_deferred_tool_promotion_real_llm.py -v -s +""" + +from __future__ import annotations + +import os + +import pytest +from langchain_core.messages import HumanMessage +from langchain_core.tools import tool as as_tool + +# --------------------------------------------------------------------------- +# Skip control: only run when explicitly opted in. +# --------------------------------------------------------------------------- + + +pytestmark = pytest.mark.skipif( + os.getenv("ONEAPI_E2E") != "1", + reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)", +) + + +# --------------------------------------------------------------------------- +# Fake "MCP" tools the agent should discover via tool_search. +# Keep them obviously synthetic so the model can pattern-match the search. +# --------------------------------------------------------------------------- + + +_calls: list[str] = [] + + +@as_tool +def fake_calculator(expression: str) -> str: + """Evaluate a tiny arithmetic expression like '2 + 2'. + + Reserved for the user — only call this if the user asks for arithmetic. + """ + _calls.append(f"fake_calculator:{expression}") + try: + # Trivially safe-eval just for the e2e check + allowed = set("0123456789+-*/() .") + if not set(expression) <= allowed: + return "expression contains disallowed characters" + return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307 + except Exception as e: + return f"error: {e}" + + +@as_tool +def fake_translator(text: str, target_lang: str) -> str: + """Translate text into the given language code. Decorative — not used.""" + _calls.append(f"fake_translator:{text}:{target_lang}") + return f"[{target_lang}] {text}" + + +# --------------------------------------------------------------------------- +# Pipeline wiring (same shape as the in-process tests). +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_registry_between_tests(): + from deerflow.tools.builtins.tool_search import reset_deferred_registry + + reset_deferred_registry() + yield + reset_deferred_registry() + + +def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: + from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig + + real_ext = ExtensionsConfig( + mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}, + ) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: real_ext), + ) + monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)) + + +def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Build a minimal mock AppConfig and patch the symbol — never call the + real loader, which would trigger ``_apply_singleton_configs`` and + permanently mutate cross-test singletons (memory, title, …).""" + from deerflow.config.app_config import AppConfig + from deerflow.config.tool_search_config import ToolSearchConfig + + mock_cfg = AppConfig.model_construct( + log_level="info", + models=[], + tools=[], + tool_groups=[], + sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"), + tool_search=ToolSearchConfig(enabled=True), + ) + monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg) + + +# --------------------------------------------------------------------------- +# Real-LLM e2e test +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch): + """End-to-end against a real OpenAI-compatible LLM. + + The model must: + Turn 1 — see ``tool_search`` (deferred tools aren't bound yet) and + batch-call BOTH ``tool_search(select:fake_calculator)`` AND + ``fake_subagent_trigger(...)``. + Turn 2 — call ``fake_calculator`` and finish. + + Pass criterion: ``fake_calculator`` actually gets invoked at the tool + layer — recorded in ``_calls`` — which proves the model received the + promoted schema after the re-entrant ``get_available_tools`` call. + """ + from langchain.agents import create_agent + from langchain_openai import ChatOpenAI + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator]) + _force_tool_search_enabled(monkeypatch) + _calls.clear() + + @as_tool + async def fake_subagent_trigger(prompt: str) -> str: + """Pretend to spawn a subagent. Internally rebuilds the toolset. + + Use this whenever the user asks you to delegate work — pass a short + description as ``prompt``. + """ + # ``task_tool`` does this internally. Whether the registry-reset that + # used to happen here actually leaks back to the parent task depends + # on asyncio's implicit context-copying semantics (gather creates + # child tasks with copied contexts, so reset_deferred_registry is + # task-local) — but the fix in this PR is what GUARANTEES the + # promotion sticks regardless of which integration path triggers a + # re-entrant ``get_available_tools`` call. + get_available_tools(subagent_enabled=False) + _calls.append(f"fake_subagent_trigger:{prompt}") + return "subagent completed" + + tools = get_available_tools() + [fake_subagent_trigger] + + model = ChatOpenAI( + model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"), + api_key=os.environ["OPENAI_API_KEY"], + base_url=os.environ["OPENAI_API_BASE"], + temperature=0, + max_retries=1, + ) + + system_prompt = ( + "You are a meticulous assistant. Available deferred tools include a " + "calculator and a translator — their schemas are hidden until you " + "search for them via tool_search.\n\n" + "Procedure for the user's request:\n" + " 1. Call tool_search with query 'select:fake_calculator' AND " + "in the SAME tool batch also call fake_subagent_trigger(prompt='go') " + "to delegate the side work. Put both tool_calls in your first response.\n" + " 2. After both tool messages come back, call fake_calculator with " + "the user's expression.\n" + " 3. Reply with just the numeric result." + ) + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt=system_prompt, + ) + + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]}, + config={"recursion_limit": 12}, + ) + + print("\n=== tool calls recorded ===") + for c in _calls: + print(f" {c}") + print("\n=== final message ===") + final_text = result["messages"][-1].content if result["messages"] else "(none)" + print(f" {final_text!r}") + + # The smoking-gun assertion: fake_calculator was actually invoked at the + # tool layer. This is only possible if the promoted schema reached the + # model in turn 2, despite the subagent-style re-entry in turn 1. + calc_calls = [c for c in _calls if c.startswith("fake_calculator:")] + assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}" + + # And the math should actually be done correctly (sanity that the LLM + # really used the result, not just hallucinated the answer). + assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}" diff --git a/backend/tests/test_deferred_tool_registry_promotion.py b/backend/tests/test_deferred_tool_registry_promotion.py new file mode 100644 index 000000000..23b7649ec --- /dev/null +++ b/backend/tests/test_deferred_tool_registry_promotion.py @@ -0,0 +1,390 @@ +"""Reproduce + regression-guard issue #2884. + +Hypothesis from the issue: + ``tools.tools.get_available_tools`` unconditionally calls + ``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry`` + every time it is invoked. If anything calls ``get_available_tools`` again + during the same async context (after the agent has promoted tools via + ``tool_search``), the promotion is wiped and the next model call hides the + tool's schema again. + +These tests pin two things: + +A. **At the unit boundary** — verify the failure mode directly. Promote a + tool in the registry, then call ``get_available_tools`` again and observe + that the ContextVar registry is reset and the promotion is lost. + +B. **At the graph-execution boundary** — drive a real ``create_agent`` graph + with the real ``DeferredToolFilterMiddleware`` through two model turns. + The first turn calls ``tool_search`` which promotes a tool. The second + turn must see that tool's schema in ``request.tools``. If + ``get_available_tools`` were to run again between the two turns and reset + the registry, the second turn's filter would strip the tool. + +Strategy: use the production ``deerflow.tools.tools.get_available_tools`` +unmodified; mock only the LLM and the MCP tool source. Patch +``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that +``get_available_tools`` resolves via lazy import) to return our fixture +tools so we don't need a real MCP server. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import tool as as_tool + + +class FakeToolCallingModel(FakeMessagesListChatModel): + """FakeMessagesListChatModel + no-op bind_tools so create_agent works.""" + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + +# --------------------------------------------------------------------------- +# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled +# --------------------------------------------------------------------------- + + +@as_tool +def fake_mcp_search(query: str) -> str: + """Pretend to search a knowledge base for the given query.""" + return f"results for {query}" + + +@as_tool +def fake_mcp_fetch(url: str) -> str: + """Pretend to fetch a page at the given URL.""" + return f"content of {url}" + + +@pytest.fixture(autouse=True) +def _supply_env(monkeypatch: pytest.MonkeyPatch): + """config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + +@pytest.fixture(autouse=True) +def _reset_deferred_registry_between_tests(): + """Each test must start with a clean ContextVar. + + The registry lives in a module-level ContextVar with no per-task isolation + in a synchronous test runner, so one test's promotion can leak into the + next and silently break filter assertions. + """ + from deerflow.tools.builtins.tool_search import reset_deferred_registry + + reset_deferred_registry() + yield + reset_deferred_registry() + + +def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: + """Make get_available_tools believe an MCP server is registered. + + Build a real ``ExtensionsConfig`` with one enabled MCP server entry so + that both ``AppConfig.from_file`` (which calls + ``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools`` + (which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``) + see a valid instance. Then point the MCP tool cache at our fixture tools. + """ + from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig + + real_ext = ExtensionsConfig( + mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}, + ) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: real_ext), + ) + monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)) + + +def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Force config.tool_search.enabled=True without touching the yaml. + + Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs`` + which permanently mutates module-level singletons (``_memory_config``, + ``_title_config``, …) to match the developer's ``config.yaml`` — even + after pytest restores our patch. That leaks across tests later in the + run that rely on those singletons' DEFAULTS (e.g. memory queue tests + require ``_memory_config.enabled = True``, which is the dataclass default + but FALSE in the actual yaml). + + Build a minimal mock AppConfig instead and never call the real loader. + """ + from deerflow.config.app_config import AppConfig + from deerflow.config.tool_search_config import ToolSearchConfig + + mock_cfg = AppConfig.model_construct( + log_level="info", + models=[], + tools=[], + tool_groups=[], + sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"), + tool_search=ToolSearchConfig(enabled=True), + ) + monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg) + + +# --------------------------------------------------------------------------- +# Section A — direct unit-level reproduction +# --------------------------------------------------------------------------- + + +def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch): + """Re-entrant ``get_available_tools()`` must preserve prior promotions. + + Step 1: call get_available_tools() — registers MCP tools as deferred. + Step 2: simulate the agent calling tool_search by promoting one tool. + Step 3: call get_available_tools() again (the same code path + ``task_tool`` exercises mid-run). + + Assertion: after step 3, the promoted tool is STILL promoted (not + re-deferred). On ``main`` before the fix, step 3's + ``reset_deferred_registry()`` wiped the promotion and re-registered + every MCP tool as deferred — this assertion fired with + ``REGRESSION (#2884)``. + """ + from deerflow.tools.builtins.tool_search import get_deferred_registry + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + # Step 1: first call — both MCP tools start deferred + get_available_tools() + reg1 = get_deferred_registry() + assert reg1 is not None + assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"} + + # Step 2: simulate tool_search promoting one of them + reg1.promote({"fake_mcp_search"}) + assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search" + + # Step 3: second call — registry must NOT silently undo the promotion + get_available_tools() + reg2 = get_deferred_registry() + assert reg2 is not None + deferred_after = {e.name for e in reg2.entries} + assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}" + + +# --------------------------------------------------------------------------- +# Section B — graph-execution reproduction +# --------------------------------------------------------------------------- + + +class _ToolSearchPromotingModel(FakeToolCallingModel): + """Two-turn model that: + + Turn 1 → emit a tool_call for ``tool_search`` (the real one) + Turn 2 → emit a tool_call for ``fake_mcp_search`` (the promoted tool) + + Records the tools it received on each turn so the test can inspect what + DeferredToolFilterMiddleware actually fed to ``bind_tools``. + """ + + bound_tools_per_turn: list[list[str]] = [] + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + # Record the tool names the model would see in this turn + names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools] + self.bound_tools_per_turn.append(names) + return self + + +def _build_promoting_model() -> _ToolSearchPromotingModel: + return _ToolSearchPromotingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "tool_search", + "args": {"query": "select:fake_mcp_search"}, + "id": "call_search_1", + "type": "tool_call", + } + ], + ), + AIMessage( + content="", + tool_calls=[ + { + "name": "fake_mcp_search", + "args": {"query": "hello"}, + "id": "call_mcp_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="all done"), + ] + ) + + +def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch): + """End-to-end: drive a real create_agent graph through two turns. + + Without the fix, the second-turn bind_tools call should NOT contain + fake_mcp_search (because DeferredToolFilterMiddleware sees it in the + registry and strips it). With the fix, the model sees the schema and can + invoke it. + """ + from langchain.agents import create_agent + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + tools = get_available_tools() + # Sanity: the assembled tool list includes the deferred tools (they're in + # bind_tools but DeferredToolFilterMiddleware strips deferred ones before + # they reach the model) + tool_names = {getattr(t, "name", "") for t in tools} + assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names + + model = _build_promoting_model() + model.bound_tools_per_turn = [] # reset class-level recorder + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt="bug-2884-repro", + ) + + graph.invoke({"messages": [HumanMessage(content="use the search tool")]}) + + # Turn 1: model should NOT see fake_mcp_search (it's deferred) + turn1 = set(model.bound_tools_per_turn[0]) + assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}" + assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}" + + # Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it. + # This is the load-bearing assertion for issue #2884. + assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}" + turn2 = set(model.bound_tools_per_turn[1]) + assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}" + + +# --------------------------------------------------------------------------- +# Section C — the actual issue #2884 trigger: a re-entrant +# get_available_tools call (e.g. when task_tool spawns a subagent) must not +# wipe the parent's promotion. +# --------------------------------------------------------------------------- + + +def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch): + """Issue #2884 in its real shape: a re-entrant get_available_tools call + (the same pattern that happens when ``task_tool`` builds a subagent's + toolset mid-run) must not wipe the parent agent's tool_search promotions. + + Turn 1's tool batch contains BOTH ``tool_search`` (which promotes + ``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls + ``get_available_tools`` again — exactly what ``task_tool`` does when it + builds a subagent's toolset). With the fix, turn 2's bind_tools sees the + promoted tool. Without the fix, the re-entry wipes the registry and + the filter re-hides it. + """ + from langchain.agents import create_agent + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + # The trigger tool simulates what task_tool does internally: rebuild the + # toolset by calling get_available_tools while the registry is live. + @as_tool + def fake_subagent_trigger(prompt: str) -> str: + """Pretend to spawn a subagent. Internally rebuilds the toolset.""" + get_available_tools(subagent_enabled=False) + return f"spawned subagent for: {prompt}" + + tools = get_available_tools() + [fake_subagent_trigger] + + bound_per_turn: list[list[str]] = [] + + class _Model(FakeToolCallingModel): + def bind_tools(self, tools_arg, **kwargs): # type: ignore[override] + bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg]) + return self + + model = _Model( + responses=[ + # Turn 1: do both in one batch — promote AND trigger the + # subagent-style rebuild. LangGraph executes them in order in the + # same agent step. + AIMessage( + content="", + tool_calls=[ + { + "name": "tool_search", + "args": {"query": "select:fake_mcp_search"}, + "id": "call_search_1", + "type": "tool_call", + }, + { + "name": "fake_subagent_trigger", + "args": {"prompt": "go"}, + "id": "call_trigger_1", + "type": "tool_call", + }, + ], + ), + # Turn 2: try to invoke the promoted tool. The model gets this + # turn only if turn 1's bind_tools recorded what the filter sent. + AIMessage( + content="", + tool_calls=[ + { + "name": "fake_mcp_search", + "args": {"query": "hello"}, + "id": "call_mcp_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="all done"), + ] + ) + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt="bug-2884-subagent-repro", + ) + graph.invoke({"messages": [HumanMessage(content="use the search tool")]}) + + # Turn 1 sanity: deferred tool not visible yet + assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0] + + # The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the + # re-entrant get_available_tools call that happened in turn 1's tool batch. + assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}" + turn2 = set(bound_per_turn[1]) + assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}" diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 27808b0e8..3d62f0497 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -1,6 +1,6 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue from deerflow.config.memory_config import MemoryConfig @@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None: assert elapsed < 0.1 assert finished.is_set() is False assert finished.wait(1.0) is True + + +def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a") + queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b") + + assert queue.pending_count == 2 + assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"] + + +def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add( + thread_id="thread-1", + messages=["first"], + agent_name="agent-a", + correction_detected=True, + ) + queue.add( + thread_id="thread-1", + messages=["second"], + agent_name="agent-a", + correction_detected=False, + ) + + assert queue.pending_count == 1 + assert queue._queue[0].agent_name == "agent-a" + assert queue._queue[0].messages == ["second"] + assert queue._queue[0].correction_detected is True + + +def test_process_queue_updates_different_agents_in_same_thread_separately() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a") + queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b") + + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + + with ( + patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater), + patch("deerflow.agents.memory.queue.time.sleep"), + ): + queue.flush() + + assert mock_updater.update_memory.call_count == 2 + mock_updater.update_memory.assert_has_calls( + [ + call( + messages=["agent-a"], + thread_id="thread-1", + agent_name="agent-a", + correction_detected=False, + reinforcement_detected=False, + user_id=None, + ), + call( + messages=["agent-b"], + thread_id="thread-1", + agent_name="agent-b", + correction_detected=False, + reinforcement_detected=False, + user_id=None, + ), + ] + ) diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py index cf068e095..ce5d41210 100644 --- a/backend/tests/test_memory_queue_user_isolation.py +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue +from deerflow.config.memory_config import MemoryConfig def test_conversation_context_has_user_id(): @@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none(): def test_queue_add_stores_user_id(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") assert len(q._queue) == 1 assert q._queue[0].user_id == "alice" @@ -26,7 +27,7 @@ def test_queue_add_stores_user_id(): def test_queue_process_passes_user_id_to_updater(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") mock_updater = MagicMock() @@ -37,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater(): mock_updater.update_memory.assert_called_once() call_kwargs = mock_updater.update_memory.call_args.kwargs assert call_kwargs["user_id"] == "alice" + + +def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent(): + q = MemoryUpdateQueue() + + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): + q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice") + q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob") + + assert q.pending_count == 2 + assert [context.user_id for context in q._queue] == ["alice", "bob"] + assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]] + + +def test_queue_still_coalesces_updates_for_same_user_thread_and_agent(): + q = MemoryUpdateQueue() + + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): + q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice") + q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice") + + assert q.pending_count == 1 + assert q._queue[0].messages == ["second"] + assert q._queue[0].user_id == "alice" + assert q._queue[0].agent_name == "researcher" + + +def test_add_nowait_keeps_different_users_separate(): + q = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), + patch.object(q, "_schedule_timer"), + ): + q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice") + q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob") + + assert q.pending_count == 2 + assert [context.user_id for context in q._queue] == ["alice", "bob"] diff --git a/backend/tests/test_mindie_provider.py b/backend/tests/test_mindie_provider.py index 78bc0d972..cfbffbb07 100644 --- a/backend/tests/test_mindie_provider.py +++ b/backend/tests/test_mindie_provider.py @@ -454,7 +454,6 @@ class TestAStream: @pytest.mark.asyncio async def test_with_tools_emits_tool_call_chunk(self): - tool_calls = [{"name": "fn", "args": {}, "id": "c1"}] with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None): mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls) diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index d2c78ccf0..17b796af7 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -268,6 +268,39 @@ class TestEdgeCases: class TestDbRunEventStore: """Tests for DbRunEventStore with temp SQLite.""" + @pytest.mark.anyio + async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self): + from sqlalchemy.dialects import postgresql + + from deerflow.runtime.events.store.db import DbRunEventStore + + class FakeSession: + def __init__(self): + self.dialect = postgresql.dialect() + self.execute_calls = [] + self.scalar_stmt = None + + def get_bind(self): + return self + + async def execute(self, stmt, params=None): + self.execute_calls.append((stmt, params)) + + async def scalar(self, stmt): + self.scalar_stmt = stmt + return 41 + + session = FakeSession() + + max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1") + + assert max_seq == 41 + assert session.execute_calls + assert session.execute_calls[0][1] == {"thread_id": "thread-1"} + assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0]) + compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect())) + assert "FOR UPDATE" not in compiled + @pytest.mark.anyio async def test_basic_crud(self, tmp_path): from deerflow.persistence.engine import close_engine, get_session_factory, init_engine diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 6fd534829..5e230e790 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -3,7 +3,10 @@ Uses a temp SQLite DB to test ORM-backed CRUD operations. """ +import re + import pytest +from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository @@ -278,3 +281,48 @@ class TestRunRepository: assert row4["model_name"] is None await _cleanup() + + @pytest.mark.anyio + async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self): + captured = [] + + class FakeResult: + def all(self): + return [] + + class FakeSession: + async def execute(self, stmt): + captured.append(stmt) + return FakeResult() + + class FakeSessionContext: + async def __aenter__(self): + return FakeSession() + + async def __aexit__(self, exc_type, exc, tb): + return None + + repo = RunRepository(lambda: FakeSessionContext()) + + agg = await repo.aggregate_tokens_by_thread("t1") + assert agg == { + "total_tokens": 0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_runs": 0, + "by_model": {}, + "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0}, + } + assert len(captured) == 1 + + stmt = captured[0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect())) + select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1) + model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)" + + select_match = re.search(model_expr_pattern + r" AS model", select_sql) + group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip()) + + assert select_match is not None + assert group_by_match is not None + assert select_match.group(1) == group_by_match.group(1) diff --git a/backend/tests/test_setup_agent_e2e_user_isolation.py b/backend/tests/test_setup_agent_e2e_user_isolation.py new file mode 100644 index 000000000..034d4da84 --- /dev/null +++ b/backend/tests/test_setup_agent_e2e_user_isolation.py @@ -0,0 +1,429 @@ +"""End-to-end verification for issue #2862 (and the regression of #2782). + +Goal: prove — without trusting any single layer's claim — that an authenticated +user creating a custom agent through the real ``setup_agent`` tool, driven by a +real LangGraph ``create_agent`` graph, ends up with files under +``users//agents/`` and **not** under ``users/default/agents/...``. + +We intentionally exercise the full pipeline: + + HTTP body shape (mimics LangGraph SDK wire format) + -> app.gateway.services.start_run config-assembly chain + -> deerflow.runtime.runs.worker._build_runtime_context + -> langchain.agents.create_agent graph + -> ToolNode dispatch + -> setup_agent tool + +The only thing we mock is the LLM (FakeMessagesListChatModel) — every layer +that handles ``user_id`` is the real production code path. If the +``user_id`` propagation is broken anywhere in this chain, these tests will +fail. + +These tests intentionally ``no_auto_user`` so that the ``contextvar`` +fallback would put files into ``default/`` if propagation breaks. +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch +from uuid import UUID + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel +from langchain_core.messages import AIMessage, HumanMessage + +from app.gateway.services import ( + build_run_config, + inject_authenticated_user_context, + merge_run_context_overrides, +) +from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context + +# --------------------------------------------------------------------------- +# Helpers — real production code paths +# --------------------------------------------------------------------------- + + +def _make_request(user_id_str: str | None) -> SimpleNamespace: + """Build a fake FastAPI Request that carries an authenticated user.""" + if user_id_str is None: + user = None + else: + # User.id is UUID in production; honour that + user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") + return SimpleNamespace(state=SimpleNamespace(user=user)) + + +def _assemble_config( + *, + body_config: dict | None, + body_context: dict | None, + request_user_id: str | None, + thread_id: str = "thread-e2e", + assistant_id: str = "lead_agent", +) -> dict: + """Replay the **exact** start_run config-assembly sequence.""" + config = build_run_config(thread_id, body_config, None, assistant_id=assistant_id) + merge_run_context_overrides(config, body_context) + inject_authenticated_user_context(config, _make_request(request_user_id)) + return config + + +def _make_paths_mock(tmp_path: Path): + """Mirror the production paths.user_agent_dir signature.""" + from unittest.mock import MagicMock + + paths = MagicMock() + paths.base_dir = tmp_path + paths.agent_dir = lambda name: tmp_path / "agents" / name + paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name + return paths + + +# --------------------------------------------------------------------------- +# L1-L3: HTTP wire format → start_run → worker._build_runtime_context +# --------------------------------------------------------------------------- + + +class TestConfigAssembly: + """Covers L1-L3: validate that user_id reaches runtime_ctx for every wire shape.""" + + def test_typical_wire_format_user_id_in_runtime_ctx(self): + """Real frontend: body.config={recursion_limit}, body.context={agent_name,...}.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent", "is_bootstrap": True, "mode": "flash"}, + request_user_id="11111111-2222-3333-4444-555555555555", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555" + assert runtime_ctx["agent_name"] == "myagent" + + def test_body_context_none_still_injects_user_id(self): + """If frontend omits body.context entirely, inject must still create it.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context=None, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_body_context_empty_dict_still_injects_user_id(self): + """body.context={} (falsy) path: inject must still produce user_id.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={}, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_body_config_already_contains_context_field(self): + """body.config={'context': {...}} (LG 0.6 alt wire): inject still wins.""" + config = _assemble_config( + body_config={"context": {"agent_name": "myagent"}, "recursion_limit": 1000}, + body_context=None, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_client_supplied_user_id_is_overridden(self): + """Spoofed client user_id must be overwritten by inject (auth-trusted source).""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent", "user_id": "spoofed"}, + request_user_id="11111111-2222-3333-4444-555555555555", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555" + + def test_unauthenticated_request_does_not_inject(self): + """If request.state.user is missing (impossible under fail-closed auth, but + verify defensively), inject must not write user_id and runtime_ctx must + therefore lack it — forcing the tool fallback path to reveal itself.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent"}, + request_user_id=None, + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert "user_id" not in runtime_ctx + + +# --------------------------------------------------------------------------- +# L4-L7: Real LangGraph create_agent driving the real setup_agent tool +# --------------------------------------------------------------------------- + + +def _build_real_bootstrap_graph(authenticated_user_id: str): + """Construct a real LangGraph using create_agent + the real setup_agent tool. + + The LLM is faked (FakeMessagesListChatModel) so we don't need an API key. + Everything else — ToolNode dispatch, runtime injection, middleware — is + the real production code path. + """ + from langchain.agents import create_agent + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + # First model turn: emit a tool_call for setup_agent + # Second model turn (after tool result): final answer (terminates the loop) + fake_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": { + "soul": "# My E2E Agent\n\nA SOUL written by the model.", + "description": "End-to-end test agent", + }, + "id": "call_setup_1", + "type": "tool_call", + } + ], + ), + AIMessage(content=f"Done. Agent created for user {authenticated_user_id}."), + ] + ) + + graph = create_agent( + model=fake_model, + tools=[setup_agent], + system_prompt="You are a bootstrap agent. Call setup_agent immediately.", + ) + return graph + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_real_graph_real_setup_agent_writes_to_authenticated_user_dir(tmp_path: Path): + """The smoking-gun test for issue #2862. + + Under no_auto_user (contextvar = empty), if user_id propagation through + runtime.context is broken, setup_agent will fall back to DEFAULT_USER_ID + and write to users/default/agents/... The assertion that this directory + DOES NOT exist is what makes this test load-bearing. + """ + from langgraph.runtime import Runtime + + auth_uid = "abcdef01-2345-6789-abcd-ef0123456789" + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "e2e-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-1", + ) + + # Replay worker.run_agent's runtime construction. This is the key step: + # it is what makes ToolRuntime.context contain user_id when the tool + # actually fires. + runtime_ctx = _build_runtime_context("thread-e2e-1", "run-1", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_real_bootstrap_graph(auth_uid) + + # Patch get_paths only (the file-system rooting); everything else is real + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Drive the real graph. This goes through real ToolNode + real Runtime merge. + final_state = await graph.ainvoke( + {"messages": [HumanMessage(content="Create an agent named e2e-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "e2e-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "e2e-agent" + + # Load-bearing assertions: + assert expected_dir.exists(), f"Agent directory not found at the authenticated user's path. Expected: {expected_dir}. tmp_path tree: {[str(p) for p in tmp_path.rglob('*')]}" + assert (expected_dir / "SOUL.md").read_text() == "# My E2E Agent\n\nA SOUL written by the model." + assert (expected_dir / "config.yaml").exists() + assert not default_dir.exists(), "REGRESSION: agent landed under users/default/. user_id propagation broke somewhere between HTTP layer and ToolRuntime.context." + + # And final state should reflect tool success + last = final_state["messages"][-1] + assert "Done" in (last.content if isinstance(last.content, str) else str(last.content)) + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_inject_failure_falls_back_to_default_proving_test_is_load_bearing(tmp_path: Path): + """Negative control: if inject does NOT happen (no user in request), and + contextvar is empty (no_auto_user), setup_agent must land in default/. + + This proves the positive test is actually load-bearing — i.e. it would + have failed before PR #2784, not passed accidentally. + """ + from langgraph.runtime import Runtime + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "fallback-agent", "is_bootstrap": True}, + request_user_id=None, # no auth — inject is a no-op + thread_id="thread-e2e-2", + ) + + runtime_ctx = _build_runtime_context("thread-e2e-2", "run-2", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_real_bootstrap_graph("does-not-matter") + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + await graph.ainvoke( + {"messages": [HumanMessage(content="Create fallback-agent")]}, + config=config, + ) + + default_dir = tmp_path / "users" / "default" / "agents" / "fallback-agent" + assert default_dir.exists(), "Negative control failed: even without inject + contextvar, agent did not land in default/. The test infrastructure may not be reproducing the bug condition." + + +# --------------------------------------------------------------------------- +# L5: Sub-graph runtime propagation (the task tool case) +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_subgraph_invocation_preserves_user_id_in_runtime(tmp_path: Path): + """When a parent graph invokes a child graph (the pattern used by + subagents), parent_runtime.merge() must keep user_id intact. + + We construct a child graph that contains setup_agent and call it from + a parent graph's tool. If LangGraph re-creates the Runtime and drops + user_id at the sub-graph boundary, this fails. + """ + from langchain.agents import create_agent + from langgraph.runtime import Runtime + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + auth_uid = "deadbeef-0000-1111-2222-333344445555" + + # Inner graph: same as the bootstrap flow + inner_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": {"soul": "# Inner", "description": "subgraph"}, + "id": "call_inner_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="inner done"), + ] + ) + inner_graph = create_agent( + model=inner_model, + tools=[setup_agent], + system_prompt="inner", + ) + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "subgraph-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-3", + ) + runtime_ctx = _build_runtime_context("thread-e2e-3", "run-3", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Direct sub-graph invoke (mimics what a subagent invocation looks like + # — distinct ainvoke call, but parent config carries the same runtime). + await inner_graph.ainvoke( + {"messages": [HumanMessage(content="Create subgraph-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "subgraph-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "subgraph-agent" + assert expected_dir.exists() + assert not default_dir.exists() + + +# --------------------------------------------------------------------------- +# L6: Sync tool path through ContextThreadPoolExecutor +# --------------------------------------------------------------------------- + + +def test_sync_tool_dispatch_through_thread_pool_uses_runtime_context(tmp_path: Path): + """setup_agent is a sync function. When dispatched through ToolNode's + ContextThreadPoolExecutor, runtime.context must still carry user_id — + not via thread-local copy_context (which only carries contextvars), but + because it was passed in as the ToolRuntime constructor argument. + """ + from langchain.agents import create_agent + from langgraph.runtime import Runtime + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + auth_uid = "11112222-3333-4444-5555-666677778888" + + fake_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": {"soul": "# Sync", "description": "sync path"}, + "id": "call_sync_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="sync done"), + ] + ) + graph = create_agent(model=fake_model, tools=[setup_agent], system_prompt="sync") + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "sync-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-4", + ) + runtime_ctx = _build_runtime_context("thread-e2e-4", "run-4", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Use SYNC invoke to hit the ContextThreadPoolExecutor path + graph.invoke( + {"messages": [HumanMessage(content="Create sync-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "sync-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "sync-agent" + assert expected_dir.exists() + assert not default_dir.exists() diff --git a/backend/tests/test_setup_agent_http_e2e_real_server.py b/backend/tests/test_setup_agent_http_e2e_real_server.py new file mode 100644 index 000000000..950d040a0 --- /dev/null +++ b/backend/tests/test_setup_agent_http_e2e_real_server.py @@ -0,0 +1,326 @@ +"""Real HTTP end-to-end verification for issue #2862's setup_agent path. + +This test drives the **entire** FastAPI gateway through ``starlette.testclient.TestClient``: + + starlette.testclient.TestClient (real ASGI stack) + -> AuthMiddleware (real cookie parsing, real JWT decode) + -> /api/v1/auth/register endpoint (real password hash + sqlite write) + -> /api/threads/{id}/runs/stream endpoint (real start_run config-assembly) + -> background asyncio.create_task(run_agent) (real worker, real Runtime) + -> langchain.agents.create_agent graph (real, with fake LLM) + -> ToolNode dispatch (real) + -> setup_agent tool (real file I/O) + +The only mock is the LLM (no API key needed). Every layer that participates +in ``user_id`` propagation — auth, ContextVar, ``inject_authenticated_user_context``, +``worker._build_runtime_context``, ``Runtime.merge`` — is the real production +code path. If the chain is broken at any layer, this test fails. + +This is what "真实验证" looks like for a server that lives behind authentication: +register a user, log in (cookie), POST to /runs/stream, wait for the run to +finish, then read the filesystem. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model + + +def _build_fake_create_chat_model(agent_name: str): + """Return a callable matching the real ``create_chat_model`` signature. + + Whenever the lead agent constructs a chat model during the bootstrap flow, + we hand it a fake that emits a single setup_agent tool_call on its first + turn, then a benign final answer on its second turn. + """ + + def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel: + return build_single_tool_call_model( + tool_name="setup_agent", + tool_args={ + "soul": f"# Real HTTP E2E SOUL for {agent_name}", + "description": "real-http-e2e agent", + }, + tool_call_id="call_real_http_1", + final_text=f"Agent {agent_name} created via real HTTP e2e.", + ) + + return fake_create_chat_model + + +@pytest.fixture +def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Stand up an isolated DeerFlow data root + config under tmp_path. + + - Sets ``DEER_FLOW_HOME`` so paths land under tmp_path, not the real + ``.deer-flow`` directory. + - Stages a copy of the project's ``config.yaml`` (or ``config.example.yaml`` + on a fresh CI checkout where ``config.yaml`` is gitignored) and pins + ``DEER_FLOW_CONFIG_PATH`` to it, so lifespan boot doesn't depend on the + developer's local config layout. + - Sets a placeholder OPENAI_API_KEY because the config has + ``$OPENAI_API_KEY`` that gets resolved at parse time; the LLM itself is + mocked, so any non-empty value works. + """ + home = tmp_path / "deer-flow-home" + home.mkdir() + monkeypatch.setenv("DEER_FLOW_HOME", str(home)) + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used-because-llm-is-mocked") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + # Hermetic config: do not depend on whether the dev machine has a real + # ``config.yaml`` at the repo root. CI's ``actions/checkout`` only ships + # ``config.example.yaml`` (and its ``models:`` list is commented out, so + # AppConfig validation would reject it). Write a minimal, self-sufficient + # config to tmp_path and pin ``DEER_FLOW_CONFIG_PATH`` to it. + staged_config = tmp_path / "config.yaml" + staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config)) + + return home + + +# Minimal config that satisfies AppConfig + LeadAgent's _resolve_model_name. +# The model `use` path must resolve to a real class for config parsing to +# succeed; the test patches ``create_chat_model`` on the lead agent module, +# so the model is never actually instantiated. SandboxConfig.use is required +# at schema level; LocalSandboxProvider is the only sandbox that runs without +# Docker. +_MINIMAL_CONFIG_YAML = """\ +log_level: info +models: + - name: fake-test-model + display_name: Fake Test Model + use: langchain_openai:ChatOpenAI + model: gpt-4o-mini + api_key: $OPENAI_API_KEY + base_url: $OPENAI_API_BASE +sandbox: + use: deerflow.sandbox.local:LocalSandboxProvider +agents_api: + enabled: true +database: + backend: sqlite +""" + + +def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None: + """Reset every process-wide cache that would survive across tests. + + This fixture stands up a full FastAPI app + sqlite DB + LangGraph runtime + inside ``tmp_path``. To get true per-test isolation we have to invalidate + a handful of module-level caches that production normally never resets, + so they pick up our test-only ``DEER_FLOW_HOME`` and sqlite path: + + - ``deerflow.config.app_config`` caches the parsed ``config.yaml``. + - ``deerflow.config.paths`` caches the ``Paths`` singleton derived from + ``DEER_FLOW_HOME`` at first access. + - ``deerflow.persistence.engine`` caches the SQLAlchemy engine and + session factory after the first call to ``init_engine_from_config``. + + ``raising=False`` keeps the fixture resilient if upstream renames or + drops one of these attributes — the test will simply skip that reset + instead of failing with a confusing AttributeError, and the next test + to call ``get_app_config()``/``get_paths()`` will surface the real + incompatibility loudly. + """ + from deerflow.config import app_config as app_config_module + from deerflow.config import paths as paths_module + from deerflow.persistence import engine as engine_module + + for module, attr in ( + (app_config_module, "_app_config"), + (app_config_module, "_app_config_path"), + (app_config_module, "_app_config_mtime"), + (paths_module, "_paths_singleton"), + (engine_module, "_engine"), + (engine_module, "_session_factory"), + ): + monkeypatch.setattr(module, attr, None, raising=False) + + +@pytest.fixture +def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch): + """Build a fresh FastAPI app inside a clean DEER_FLOW_HOME. + + Each test gets its own sqlite DB and checkpoint store under ``tmp_path``, + with no cross-test contamination. + """ + _reset_process_singletons(monkeypatch) + + # Re-resolve the config from the test-only DEER_FLOW_HOME and pin its + # sqlite path into tmp_path so the lifespan-time engine init lands there. + from deerflow.config import app_config as app_config_module + + cfg = app_config_module.get_app_config() + cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db") + + from app.gateway.app import create_app + + return create_app() + + +def _drain_stream(response, *, timeout: float = 30.0, max_bytes: int = 4 * 1024 * 1024) -> str: + """Consume an SSE response body until the run terminates and return the text. + + Bounded to keep the test fail-fast: + - Stops as soon as an ``event: end`` SSE frame is observed (the gateway + sends this when the background run finishes — see ``services.format_sse`` + and ``StreamBridge.publish_end``). + - Stops at ``timeout`` seconds wall-clock so a stuck run / runaway heartbeat + loop surfaces a real failure instead of hanging pytest. + - Stops at ``max_bytes`` so a runaway producer can't OOM the test process. + """ + import time as _time + + deadline = _time.monotonic() + timeout + body = b"" + for chunk in response.iter_bytes(): + body += chunk + if b"event: end" in body: + break + if len(body) >= max_bytes: + break + if _time.monotonic() >= deadline: + break + return body.decode("utf-8", errors="replace") + + +def _wait_for_file(path: Path, *, timeout: float = 10.0) -> bool: + """Block until *path* exists or *timeout* elapses. + + The run completes inside ``asyncio.create_task`` after start_run returns, + so the test must wait for the background task to flush its writes. + """ + import time as _time + + deadline = _time.monotonic() + timeout + while _time.monotonic() < deadline: + if path.exists(): + return True + _time.sleep(0.05) + return False + + +@pytest.mark.no_auto_user +def test_real_http_create_agent_lands_in_authenticated_user_dir( + isolated_app: Any, + isolated_deer_flow_home: Path, + monkeypatch: pytest.MonkeyPatch, +): + """The full real-server contract test. + + 1. Register a real user via POST /api/v1/auth/register (also auto-logs in) + 2. POST to /api/threads/{tid}/runs/stream with the **exact** body shape the + frontend (LangGraph SDK) sends during the bootstrap flow. + 3. Wait for the background run to finish. + 4. Assert SOUL.md exists under users//agents//. + 5. Assert NOTHING exists under users/default/agents//. + """ + # ``deerflow.agents.lead_agent.agent`` imports ``create_chat_model`` with + # ``from deerflow.models import create_chat_model`` at module load time, + # rebinding the symbol into its own namespace. So the only patch that + # intercepts the call is the bound name on ``lead_agent.agent`` — patching + # ``deerflow.models.create_chat_model`` would be too late. + agent_name = "real-http-agent" + + from starlette.testclient import TestClient + + with ( + patch( + "deerflow.agents.lead_agent.agent.create_chat_model", + new=_build_fake_create_chat_model(agent_name), + ), + TestClient(isolated_app) as client, + ): + # --- 1. Register & auto-login --- + register = client.post( + "/api/v1/auth/register", + json={"email": "e2e-user@example.com", "password": "very-strong-password-123"}, + ) + assert register.status_code == 201, register.text + registered = register.json() + auth_uid = registered["id"] + # The endpoint sets both access_token (auth) and csrf_token (CSRF Double + # Submit Cookie) cookies; the TestClient cookie jar propagates them. + assert client.cookies.get("access_token"), "register endpoint must set session cookie" + csrf_token = client.cookies.get("csrf_token") + assert csrf_token, "register endpoint must set csrf_token cookie" + + # --- 2. Create a thread (require_existing=True on /runs/stream means + # we must call POST /api/threads first; the React frontend does the + # same via the LangGraph SDK's threads.create) --- + import uuid as _uuid + + thread_id = str(_uuid.uuid4()) + created = client.post( + "/api/threads", + json={"thread_id": thread_id, "metadata": {}}, + headers={"X-CSRF-Token": csrf_token}, + ) + assert created.status_code == 200, created.text + + # --- 3. POST /runs/stream with the bootstrap wire format --- + # This is the EXACT shape the React frontend sends after PR #2784: + # thread.submit(input, {config, context}) -> + # POST /api/threads/{id}/runs/stream body = + # {assistant_id, input, config, context} + body = { + "assistant_id": "lead_agent", + "input": { + "messages": [ + { + "role": "user", + "content": (f"The new custom agent name is {agent_name}. Help me design its SOUL.md before saving it."), + } + ] + }, + "config": {"recursion_limit": 50}, + "context": { + "agent_name": agent_name, + "is_bootstrap": True, + "mode": "flash", + "thinking_enabled": False, + "is_plan_mode": False, + "subagent_enabled": False, + }, + "stream_mode": ["values"], + } + # The /stream endpoint returns SSE; we drain it so the server-side + # background task (run_agent) gets to completion before we look at disk. + with client.stream( + "POST", + f"/api/threads/{thread_id}/runs/stream", + json=body, + headers={"X-CSRF-Token": csrf_token}, + ) as resp: + assert resp.status_code == 200, resp.read().decode() + transcript = _drain_stream(resp) + + # Sanity: the stream should have produced at least one event + assert "event:" in transcript, f"no SSE events in response: {transcript[:500]!r}" + + # --- 4. Verify filesystem outcome --- + expected_dir = isolated_deer_flow_home / "users" / auth_uid / "agents" / agent_name + default_dir = isolated_deer_flow_home / "users" / "default" / "agents" / agent_name + + # The setup_agent tool runs inside the background asyncio task spawned + # by start_run; SSE-drain typically waits for it, but we add a bounded + # poll to be robust against scheduler jitter. + assert _wait_for_file(expected_dir / "SOUL.md", timeout=15.0), ( + "SOUL.md did not appear under users//agents/. " + f"Expected: {expected_dir / 'SOUL.md'}. " + f"tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}. " + f"SSE transcript tail: {transcript[-1000:]!r}" + ) + + soul_text = (expected_dir / "SOUL.md").read_text() + assert agent_name in soul_text, f"unexpected SOUL content: {soul_text!r}" + + # The smoking-gun assertion: the agent must NOT have landed in default/ + assert not default_dir.exists(), f"REGRESSION: agent landed under users/default/{agent_name} instead of the authenticated user. Default-dir contents: {list(default_dir.rglob('*')) if default_dir.exists() else 'n/a'}" diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py index b147acaf6..58f372488 100644 --- a/backend/tests/test_summarization_middleware.py +++ b/backend/tests/test_summarization_middleware.py @@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage: ) -def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace: +def _runtime( + thread_id: str | None = "thread-1", + agent_name: str | None = None, + user_id: str | None = None, +) -> SimpleNamespace: context = {} if thread_id is not None: context["thread_id"] = thread_id if agent_name is not None: context["agent_name"] = agent_name + if user_id is not None: + context["user_id"] = user_id return SimpleNamespace(context=context) @@ -693,3 +699,22 @@ def test_before_model_summary_message_has_hide_from_ui() -> None: summary_msg = emitted[1] assert summary_msg.name == "summary" assert summary_msg.additional_kwargs.get("hide_from_ui") is True + + +def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + SummarizationEvent( + messages_to_summarize=tuple(_messages()[:2]), + preserved_messages=(), + thread_id="main", + agent_name="researcher", + runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"), + ) + ) + + queue.add_nowait.assert_called_once() + assert queue.add_nowait.call_args.kwargs["user_id"] == "alice" diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 0591c0e8d..658968d65 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -59,12 +59,15 @@ def _make_result( ai_messages: list[dict] | None = None, result: str | None = None, error: str | None = None, + token_usage_records: list[dict] | None = None, ) -> SimpleNamespace: return SimpleNamespace( status=status, ai_messages=ai_messages or [], result=result, error=error, + token_usage_records=token_usage_records or [], + usage_reported=False, ) @@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch): assert len(report_calls) == 1 assert report_calls[0][1] is cancel_result assert cleanup_calls == ["tc-cancel-report"] + + +@pytest.mark.parametrize( + "status, expected_type", + [ + (FakeSubagentStatus.COMPLETED, "task_completed"), + (FakeSubagentStatus.FAILED, "task_failed"), + (FakeSubagentStatus.CANCELLED, "task_cancelled"), + (FakeSubagentStatus.TIMED_OUT, "task_timed_out"), + ], +) +def test_terminal_events_include_usage(monkeypatch, status, expected_type): + """Terminal task events include a usage summary from token_usage_records.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + records = [ + {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280}, + ] + result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-usage", + ) + + terminal_events = [e for e in events if e["type"] == expected_type] + assert len(terminal_events) == 1 + assert terminal_events[0]["usage"] == { + "input_tokens": 300, + "output_tokens": 130, + "total_tokens": 430, + } + + +def test_terminal_event_usage_none_when_no_records(monkeypatch): + """Terminal event has usage=None when token_usage_records is empty.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[]) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-no-records", + ) + + completed = [e for e in events if e["type"] == "task_completed"] + assert len(completed) == 1 + assert completed[0]["usage"] is None + + +def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch): + monkeypatch.setattr( + task_tool_module, + "get_app_config", + MagicMock(side_effect=FileNotFoundError("missing config")), + ) + + assert task_tool_module._token_usage_cache_enabled(None) is False + + +def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False)) + runtime = _make_runtime(app_config=app_config) + records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}] + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records) + + task_tool_module._subagent_usage_cache.clear() + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-disabled-cache", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None + + +def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True)) + runtime = _make_runtime(app_config=app_config) + + task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2} + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed"))) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + with pytest.raises(RuntimeError, match="poll failed"): + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-error", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-error") is None diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 3a6532567..1cef3752b 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -1,28 +1,25 @@ """Tests for ThreadMetaRepository (SQLAlchemy-backed).""" +import logging + import pytest -from deerflow.persistence.thread_meta import ThreadMetaRepository +from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository -async def _make_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine +@pytest.fixture +async def repo(tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return ThreadMetaRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - + yield ThreadMetaRepository(get_session_factory()) await close_engine() class TestThreadMetaRepository: @pytest.mark.anyio - async def test_create_and_get(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_and_get(self, repo): record = await repo.create("t1") assert record["thread_id"] == "t1" assert record["status"] == "idle" @@ -31,148 +28,523 @@ class TestThreadMetaRepository: fetched = await repo.get("t1") assert fetched is not None assert fetched["thread_id"] == "t1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_assistant_id(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_assistant_id(self, repo): record = await repo.create("t1", assistant_id="agent1") assert record["assistant_id"] == "agent1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_owner_and_display_name(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_owner_and_display_name(self, repo): record = await repo.create("t1", user_id="user1", display_name="My Thread") assert record["user_id"] == "user1" assert record["display_name"] == "My Thread" - await _cleanup() @pytest.mark.anyio - async def test_create_with_metadata(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_metadata(self, repo): record = await repo.create("t1", metadata={"key": "value"}) assert record["metadata"] == {"key": "value"} - await _cleanup() @pytest.mark.anyio - async def test_get_nonexistent(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_get_nonexistent(self, repo): assert await repo.get("nonexistent") is None - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_record_allows(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_record_allows(self, repo): assert await repo.check_access("unknown", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_matches(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_matches(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_mismatch(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_mismatch(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2") is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_owner_allows_all(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_owner_allows_all(self, repo): # Explicit user_id=None to bypass the new AUTO default that # would otherwise pick up the test user from the autouse fixture. await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_missing_row_denied(self, tmp_path): + async def test_check_access_strict_missing_row_denied(self, repo): """require_existing=True flips the missing-row case to *denied*. Closes the delete-idempotence cross-user gap: after a thread is deleted, the row is gone, and the permissive default would let any caller "claim" it as untracked. The strict mode demands a row. """ - repo = await _make_repo(tmp_path) assert await repo.check_access("never-existed", "user1", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_match_allowed(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_match_allowed(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_mismatch_denied(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): + async def test_check_access_strict_null_owner_still_allowed(self, repo): """Even in strict mode, a row with NULL user_id stays shared. The strict flag tightens the *missing row* case, not the *shared row* case — legacy pre-auth rows that survived a clean migration without an owner are still everyone's. """ - repo = await _make_repo(tmp_path) await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_update_status(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_status(self, repo): await repo.create("t1") await repo.update_status("t1", "busy") record = await repo.get("t1") assert record["status"] == "busy" - await _cleanup() @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete(self, repo): await repo.create("t1") await repo.delete("t1") assert await repo.get("t1") is None - await _cleanup() @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete_nonexistent_is_noop(self, repo): await repo.delete("nonexistent") # should not raise - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_merges(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_merges(self, repo): await repo.create("t1", metadata={"a": 1, "b": 2}) await repo.update_metadata("t1", {"b": 99, "c": 3}) record = await repo.get("t1") # Existing key preserved, overlapping key overwritten, new key added assert record["metadata"] == {"a": 1, "b": 99, "c": 3} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_on_empty(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_on_empty(self, repo): await repo.create("t1") await repo.update_metadata("t1", {"k": "v"}) record = await repo.get("t1") assert record["metadata"] == {"k": "v"} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_nonexistent_is_noop(self, repo): await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise - await _cleanup() + + # --- search with metadata filter (SQL push-down) --- + + @pytest.mark.anyio + async def test_search_metadata_filter_string(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + await repo.create("t3", metadata={"env": "prod", "region": "us"}) + + results = await repo.search(metadata={"env": "prod"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_numeric(self, repo): + await repo.create("t1", metadata={"priority": 1}) + await repo.create("t2", metadata={"priority": 2}) + await repo.create("t3", metadata={"priority": 1, "extra": "x"}) + + results = await repo.search(metadata={"priority": 1}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_multiple_keys(self, repo): + await repo.create("t1", metadata={"env": "prod", "region": "us"}) + await repo.create("t2", metadata={"env": "prod", "region": "eu"}) + await repo.create("t3", metadata={"env": "staging", "region": "us"}) + + results = await repo.search(metadata={"env": "prod", "region": "us"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_metadata_no_match(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "dev"}) + assert results == [] + + @pytest.mark.anyio + async def test_search_metadata_pagination_correct(self, repo): + """Regression: SQL push-down makes limit/offset exact even when most rows don't match.""" + for i in range(30): + meta = {"target": "yes"} if i % 3 == 0 else {"target": "no"} + await repo.create(f"t{i:03d}", metadata=meta) + + # Total matching rows: i in {0,3,6,9,12,15,18,21,24,27} = 10 rows + all_matches = await repo.search(metadata={"target": "yes"}, limit=100) + assert len(all_matches) == 10 + + # Paginate: first page + page1 = await repo.search(metadata={"target": "yes"}, limit=3, offset=0) + assert len(page1) == 3 + + # Paginate: second page + page2 = await repo.search(metadata={"target": "yes"}, limit=3, offset=3) + assert len(page2) == 3 + + # No overlap between pages + page1_ids = {r["thread_id"] for r in page1} + page2_ids = {r["thread_id"] for r in page2} + assert page1_ids.isdisjoint(page2_ids) + + # Last page + page_last = await repo.search(metadata={"target": "yes"}, limit=3, offset=9) + assert len(page_last) == 1 + + @pytest.mark.anyio + async def test_search_metadata_with_status_filter(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "prod"}) + await repo.update_status("t1", "busy") + + results = await repo.search(metadata={"env": "prod"}, status="busy") + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_without_metadata_still_works(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2") + + results = await repo.search(limit=10) + assert len(results) == 2 + + @pytest.mark.anyio + async def test_search_metadata_missing_key_no_match(self, repo): + """Rows without the requested metadata key should not match.""" + await repo.create("t1", metadata={"other": "val"}) + await repo.create("t2", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "prod"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t2" + + @pytest.mark.anyio + async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog): + """When ALL metadata keys are unsafe, raises InvalidMetadataFilterError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info: + await repo.search(metadata={"bad;key": "x"}) + assert any("bad;key" in r.message for r in caplog.records) + # Subclass of ValueError for backward compatibility + assert isinstance(exc_info.value, ValueError) + + @pytest.mark.anyio + async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog): + """Valid keys filter rows; only the invalid key is warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + results = await repo.search(metadata={"env": "prod", "bad;key": "x"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + assert any("bad;key" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_filter_boolean(self, repo): + """True matches only boolean true, not integer 1.""" + await repo.create("t1", metadata={"active": True}) + await repo.create("t2", metadata={"active": False}) + await repo.create("t3", metadata={"active": True, "extra": "x"}) + await repo.create("t4", metadata={"active": 1}) + + results = await repo.search(metadata={"active": True}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_none(self, repo): + """Only rows with explicit JSON null match; missing key does not.""" + await repo.create("t1", metadata={"tag": None}) + await repo.create("t2", metadata={"tag": "present"}) + await repo.create("t3", metadata={"other": "val"}) + + results = await repo.search(metadata={"tag": None}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + + @pytest.mark.anyio + async def test_search_metadata_non_string_key_skipped(self, repo, caplog): + """Non-string keys raise ValueError from isinstance check; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={1: "x"}) + assert any("1" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog): + """Unsupported value types (list, dict) raise TypeError; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"env": ["prod", "staging"]}) + + @pytest.mark.anyio + async def test_search_metadata_dotted_key_raises(self, repo, caplog): + """Dotted keys are rejected; when ALL keys are dotted, raises ValueError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"a.b": "anything"}) + assert any("a.b" in r.message for r in caplog.records) + + # --- dialect-aware type-safe filtering edge cases --- + + @pytest.mark.anyio + async def test_search_metadata_bool_vs_int_distinction(self, repo): + """True must not match 1; False must not match 0.""" + await repo.create("bool_true", metadata={"flag": True}) + await repo.create("bool_false", metadata={"flag": False}) + await repo.create("int_one", metadata={"flag": 1}) + await repo.create("int_zero", metadata={"flag": 0}) + + true_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": True})} + assert true_hits == {"bool_true"} + + false_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": False})} + assert false_hits == {"bool_false"} + + @pytest.mark.anyio + async def test_search_metadata_int_does_not_match_bool(self, repo): + """Integer 1 must not match boolean True.""" + await repo.create("bool_true", metadata={"val": True}) + await repo.create("int_one", metadata={"val": 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"val": 1})} + assert hits == {"int_one"} + + @pytest.mark.anyio + async def test_search_metadata_none_excludes_missing_key(self, repo): + """Filtering by None matches explicit JSON null only, not missing key or empty {}.""" + await repo.create("explicit_null", metadata={"k": None}) + await repo.create("missing_key", metadata={"other": "x"}) + await repo.create("empty_obj", metadata={}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"k": None})} + assert hits == {"explicit_null"} + + @pytest.mark.anyio + async def test_search_metadata_float_value(self, repo): + await repo.create("t1", metadata={"score": 3.14}) + await repo.create("t2", metadata={"score": 2.71}) + await repo.create("t3", metadata={"score": 3.14}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"score": 3.14})} + assert hits == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_mixed_types_same_key(self, repo): + """Each type query only matches its own type, even when the key is shared.""" + await repo.create("str_row", metadata={"x": "hello"}) + await repo.create("int_row", metadata={"x": 42}) + await repo.create("bool_row", metadata={"x": True}) + await repo.create("null_row", metadata={"x": None}) + + assert {r["thread_id"] for r in await repo.search(metadata={"x": "hello"})} == {"str_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": 42})} == {"int_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": True})} == {"bool_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": None})} == {"null_row"} + + @pytest.mark.anyio + async def test_search_metadata_large_int_precision(self, repo): + """Integers beyond float precision (> 2**53) must match exactly.""" + large = 2**53 + 1 + await repo.create("t1", metadata={"id": large}) + await repo.create("t2", metadata={"id": large - 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"id": large})} + assert hits == {"t1"} + + +class TestJsonMatchCompilation: + """Verify compiled SQL for both SQLite and PostgreSQL dialects.""" + + def test_json_match_compiles_sqlite(self): + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + cases = [ + (None, "json_type(t.data, '$.\"k\"') = 'null'"), + (True, "json_type(t.data, '$.\"k\"') = 'true'"), + (False, "json_type(t.data, '$.\"k\"') = 'false'"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: uses INTEGER cast for precision, type-check narrows to 'integer' only + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "= 'integer'" in sql + assert "INTEGER" in sql + assert "CAST" in sql + + # float: uses REAL cast, type-check spans 'integer' and 'real' + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "IN ('integer', 'real')" in sql + assert "REAL" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "'text'" in sql + + def test_json_match_compiles_pg(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + dialect = postgresql.dialect() + + cases = [ + (None, "json_typeof(t.data -> 'k') = 'null'"), + (True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"), + (False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: CASE guard prevents CAST error when 'number' also matches floats + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "BIGINT" in sql + assert "CASE WHEN" in sql + assert "'^-?[0-9]+$'" in sql + + # float: uses DOUBLE PRECISION cast + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "DOUBLE PRECISION" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'string'" in sql + + def test_json_match_rejects_unsafe_key(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_key in ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, bad_key, "x") + + # Non-string keys must also raise ValueError (not TypeError from re.match) + for non_str_key in [42, None, ("k",)]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, non_str_key, "x") + + def test_json_match_rejects_unsupported_value_type(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_value in [[], {}, object()]: + with pytest.raises(TypeError, match="JsonMatch value must be"): + json_match(t.c.data, "k", bad_value) + + def test_json_match_unsupported_dialect_raises(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import mysql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + expr = json_match(t.c.data, "k", "v") + + with pytest.raises(NotImplementedError, match="mysql"): + str(expr.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True})) + + def test_json_match_rejects_out_of_range_int(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + # boundary values must be accepted + json_match(t.c.data, "k", 2**63 - 1) + json_match(t.c.data, "k", -(2**63)) + + # one beyond each boundary must be rejected + for out_of_range in [2**63, -(2**63) - 1, 10**30]: + with pytest.raises(TypeError, match="out of signed 64-bit range"): + json_match(t.c.data, "k", out_of_range) + + def test_compiler_raises_on_escaped_key(self): + """Compiler raises ValueError even when __init__ validation is bypassed.""" + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + elem = json_match(t.c.data, "k", "v") + elem.key = "bad.key" # bypass __init__ to simulate -O stripping assert + + with pytest.raises(ValueError, match="Key escaped validation"): + str(elem.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + + with pytest.raises(ValueError, match="Key escaped validation"): + str(elem.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index daf0c0b13..9e37f3c86 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore from app.gateway.routers import threads from deerflow.config.paths import Paths +from deerflow.persistence.thread_meta import InvalidMetadataFilterError from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore _ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None assert entries, "expected at least one history entry" for entry in entries: assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry + + +# ── Metadata filter validation at API boundary ──────────────────────────────── + + +def test_search_threads_rejects_invalid_key_at_api_boundary() -> None: + """Keys that don't match [A-Za-z0-9_-]+ are rejected by the Pydantic + validator on ThreadSearchRequest.metadata — 422 from both backends. + """ + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"bad;key": "x"}}) + + assert response.status_code == 422 + + +def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None: + """Value types outside (None, bool, int, float, str) are rejected.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": ["a", "b"]}}) + + assert response.status_code == 422 + + +def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None: + """If the backend still raises InvalidMetadataFilterError (defense in + depth), the handler surfaces it as HTTP 400. + """ + app, _store, _checkpointer = _build_thread_app() + thread_store = app.state.thread_store + + async def _raise(**kwargs): + raise InvalidMetadataFilterError("rejected") + + with TestClient(app) as client: + with patch.object(thread_store, "search", side_effect=_raise): + response = client.post("/api/threads/search", json={"metadata": {"valid_key": "x"}}) + + assert response.status_code == 400 + assert "rejected" in response.json()["detail"] + + +def test_search_threads_succeeds_with_valid_metadata() -> None: + """Sanity check: valid metadata passes through without error.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": "prod"}}) + + assert response.status_code == 200 diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/test_title_middleware_core_logic.py index 5395f816e..3fdf4d3f9 100644 --- a/backend/tests/test_title_middleware_core_logic.py +++ b/backend/tests/test_title_middleware_core_logic.py @@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic: assert middleware._should_generate_title(state) is False def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch): - _set_test_title_config(max_chars=12) + _set_test_title_config(max_chars=12, model_name=None) middleware = TitleMiddleware() model = MagicMock() model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题")) diff --git a/backend/tests/test_todo_middleware.py b/backend/tests/test_todo_middleware.py index efeee9eb0..934e730f2 100644 --- a/backend/tests/test_todo_middleware.py +++ b/backend/tests/test_todo_middleware.py @@ -1,14 +1,19 @@ """Tests for TodoMiddleware context-loss detection.""" import asyncio -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import AsyncMock, MagicMock +from langchain.agents import create_agent +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel from langchain_core.messages import AIMessage, HumanMessage +from pydantic import PrivateAttr from deerflow.agents.middlewares.todo_middleware import ( TodoMiddleware, _completion_reminder_count, _format_todos, + _has_tool_call_intent_or_error, _reminder_in_messages, _todos_in_messages, ) @@ -22,9 +27,35 @@ def _reminder_msg(): return HumanMessage(name="todo_reminder", content="reminder") +class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel): + _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list) + + @property + def seen_messages(self) -> list[list[Any]]: + return self._seen_messages + + def bind_tools(self, tools, *, tool_choice=None, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self._seen_messages.append(list(messages)) + return super()._generate( + messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + + def _make_runtime(): runtime = MagicMock() - runtime.context = {"thread_id": "test-thread"} + runtime.context = {"thread_id": "test-thread", "run_id": "test-run"} + return runtime + + +def _make_runtime_for(thread_id: str, run_id: str): + runtime = _make_runtime() + runtime.context = {"thread_id": thread_id, "run_id": run_id} return runtime @@ -161,10 +192,62 @@ def _completion_reminder_msg(): return HumanMessage(name="todo_completion_reminder", content="finish your todos") +def _todo_completion_reminders(messages): + reminders = [] + for message in messages: + if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder": + reminders.append(message) + return reminders + + def _ai_no_tool_calls(): return AIMessage(content="I'm done!") +def _ai_with_invalid_tool_calls(): + return AIMessage( + content="", + tool_calls=[], + invalid_tool_calls=[ + { + "type": "invalid_tool_call", + "id": "write_file:36", + "name": "write_file", + "args": "{invalid", + "error": "Failed to parse tool arguments", + } + ], + ) + + +def _ai_with_raw_provider_tool_calls(): + return AIMessage( + content="", + tool_calls=[], + invalid_tool_calls=[], + additional_kwargs={ + "tool_calls": [ + { + "id": "raw-tool-call", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"report.md"}'}, + } + ] + }, + ) + + +def _ai_with_legacy_function_call(): + return AIMessage( + content="", + additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}}, + ) + + +def _ai_with_tool_finish_reason(): + return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"}) + + def _incomplete_todos(): return [ {"status": "completed", "content": "Step 1"}, @@ -194,6 +277,36 @@ class TestCompletionReminderCount: assert _completion_reminder_count(msgs) == 1 +class TestToolCallIntentOrError: + def test_false_for_plain_final_answer(self): + assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False + + def test_true_for_structured_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True + + def test_true_for_invalid_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True + + def test_true_for_raw_provider_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True + + def test_true_for_legacy_function_call(self): + assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True + + def test_true_for_tool_finish_reason(self): + assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True + + def test_langchain_ai_message_tool_fields_are_explicitly_handled(self): + # Sentinel for LangChain compatibility: if future AIMessage versions add + # new top-level tool/function-call fields, this test should fail. When + # it does, update `_has_tool_call_intent_or_error()` so the completion + # reminder guard explicitly decides whether each new field means "not a + # clean final answer"; the helper has a matching comment pointing back + # to this sentinel. + tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())} + assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"} + + class TestAfterModel: def test_returns_none_when_agent_still_using_tools(self): mw = TodoMiddleware() @@ -235,68 +348,299 @@ class TestAfterModel: } assert mw.after_model(state, _make_runtime()) is None - def test_injects_reminder_and_jumps_to_model_when_incomplete(self): + def test_queues_reminder_and_jumps_to_model_when_incomplete(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [HumanMessage(content="hi"), _ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) + result = mw.after_model(state, runtime) assert result is not None assert result["jump_to"] == "model" - assert len(result["messages"]) == 1 - reminder = result["messages"][0] + assert "messages" not in result + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + request.override.assert_called_once() + reminder = request.override.call_args.kwargs["messages"][-1] assert isinstance(reminder, HumanMessage) assert reminder.name == "todo_completion_reminder" + assert reminder.additional_kwargs["hide_from_ui"] is True assert "Step 2" in reminder.content assert "Step 3" in reminder.content + handler.assert_called_once_with("patched-request") def test_reminder_lists_only_incomplete_items(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [_ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) - content = result["messages"][0].content + result = mw.after_model(state, runtime) + assert result is not None + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + mw.wrap_model_call(request, MagicMock(return_value="response")) + content = request.override.call_args.kwargs["messages"][-1].content assert "Step 1" not in content # completed — should not appear assert "Step 2" in content assert "Step 3" in content def test_allows_exit_after_max_reminders(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [ - _completion_reminder_msg(), - _completion_reminder_msg(), _ai_no_tool_calls(), ], "todos": _incomplete_todos(), } + assert mw.after_model(state, runtime) is not None + assert mw.after_model(state, runtime) is not None + assert mw.after_model(state, runtime) is None + + def test_still_sends_reminder_before_cap(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [ + _ai_no_tool_calls(), + ], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, runtime) is not None + result = mw.after_model(state, runtime) + assert result is not None + assert result["jump_to"] == "model" + + def test_does_not_trigger_for_invalid_tool_calls(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_invalid_tool_calls()], + "todos": _incomplete_todos(), + } assert mw.after_model(state, _make_runtime()) is None - def test_still_sends_reminder_before_cap(self): + def test_does_not_trigger_for_raw_provider_tool_calls(self): mw = TodoMiddleware() state = { - "messages": [ - _completion_reminder_msg(), # 1 reminder so far - _ai_no_tool_calls(), - ], + "messages": [_ai_with_raw_provider_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) - assert result is not None - assert result["jump_to"] == "model" + assert mw.after_model(state, _make_runtime()) is None + + def test_does_not_trigger_for_legacy_function_call(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_legacy_function_call()], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, _make_runtime()) is None + + def test_does_not_trigger_for_tool_finish_reason(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_tool_finish_reason()], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, _make_runtime()) is None class TestAafterModel: def test_delegates_to_sync(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [_ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = asyncio.run(mw.aafter_model(state, _make_runtime())) + result = asyncio.run(mw.aafter_model(state, runtime)) assert result is not None assert result["jump_to"] == "model" - assert result["messages"][0].name == "todo_completion_reminder" + assert "messages" not in result + + +class TestWrapModelCall: + def test_no_pending_reminder_passthrough(self): + mw = TodoMiddleware() + request = MagicMock() + request.runtime = _make_runtime() + request.messages = [HumanMessage(content="hi")] + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + request.override.assert_not_called() + handler.assert_called_once_with(request) + + def test_pending_reminder_is_injected_once(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [_ai_no_tool_calls()], + "todos": _incomplete_todos(), + } + mw.after_model(state, runtime) + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + injected_messages = request.override.call_args.kwargs["messages"] + assert injected_messages[-1].name == "todo_completion_reminder" + + request.override.reset_mock() + handler.reset_mock() + handler.return_value = "second-response" + assert mw.wrap_model_call(request, handler) == "second-response" + request.override.assert_not_called() + handler.assert_called_once_with(request) + + +class TestTodoMiddlewareAgentGraphIntegration: + def test_completion_reminder_is_transient_in_real_agent_graph(self): + mw = TodoMiddleware() + model = _CapturingFakeMessagesListChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "write_todos", + "id": "todos-1", + "args": { + "todos": [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "pending"}, + ] + }, + } + ], + ), + AIMessage(content="premature final 1"), + AIMessage(content="premature final 2"), + AIMessage(content="premature final 3"), + ], + ) + graph = create_agent(model=model, tools=[], middleware=[mw]) + + result = graph.invoke( + {"messages": [("user", "finish all todos")]}, + context={"thread_id": "integration-thread", "run_id": "integration-run"}, + ) + + assert len(model.seen_messages) == 4 + reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages] + assert reminders_by_call[0] == [] + assert reminders_by_call[1] == [] + assert len(reminders_by_call[2]) == 1 + assert len(reminders_by_call[3]) == 1 + assert "Step 1" not in reminders_by_call[2][0].content + assert "Step 2" in reminders_by_call[2][0].content + + persisted_reminders = _todo_completion_reminders(result["messages"]) + assert persisted_reminders == [] + assert result["messages"][-1].content == "premature final 3" + assert result["todos"] == [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "pending"}, + ] + assert mw._pending_completion_reminders == {} + assert mw._completion_reminder_counts == {} + + +class TestRunScopedReminderCleanup: + def test_before_agent_clears_stale_count_without_pending_reminder(self): + mw = TodoMiddleware() + stale_runtime = _make_runtime() + stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"} + current_runtime = _make_runtime() + current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"} + other_thread_runtime = _make_runtime() + other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"} + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, stale_runtime) is not None + assert mw.after_model(state, other_thread_runtime) is not None + + # Simulate a model call that drained the pending message, followed by an + # abnormal run end where after_agent did not clear the reminder count. + assert mw._drain_completion_reminders(stale_runtime) + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1 + + mw.before_agent({}, current_runtime) + + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1 + + def test_size_guard_prunes_oldest_count_only_reminder_state(self): + mw = TodoMiddleware() + mw._MAX_COMPLETION_REMINDER_KEYS = 2 + first_runtime = _make_runtime_for("thread-a", "run-a") + second_runtime = _make_runtime_for("thread-b", "run-b") + third_runtime = _make_runtime_for("thread-c", "run-c") + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, first_runtime) is not None + + # Simulate the normal model request path: pending reminder is consumed, + # but the run count remains until after_agent() or stale cleanup. + assert mw._drain_completion_reminders(first_runtime) + assert mw._completion_reminder_count_for_runtime(first_runtime) == 1 + + assert mw.after_model(state, second_runtime) is not None + assert mw.after_model(state, third_runtime) is not None + + assert mw._completion_reminder_count_for_runtime(first_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(second_runtime) == 1 + assert mw._completion_reminder_count_for_runtime(third_runtime) == 1 + assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order + + def test_size_guard_prunes_pending_and_count_state_together(self): + mw = TodoMiddleware() + mw._MAX_COMPLETION_REMINDER_KEYS = 1 + stale_runtime = _make_runtime_for("thread-a", "run-a") + current_runtime = _make_runtime_for("thread-b", "run-b") + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, stale_runtime) is not None + assert mw.after_model(state, current_runtime) is not None + + assert mw._drain_completion_reminders(stale_runtime) == [] + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(current_runtime) == 1 + + +class TestAwrapModelCall: + def test_async_pending_reminder_is_injected(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [_ai_no_tool_calls()], + "todos": _incomplete_todos(), + } + mw.after_model(state, runtime) + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = AsyncMock(return_value="response") + + result = asyncio.run(mw.awrap_model_call(request, handler)) + assert result == "response" + injected_messages = request.override.call_args.kwargs["messages"] + assert injected_messages[-1].name == "todo_completion_reminder" + handler.assert_awaited_once_with("patched-request") diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py index b24ff7b16..9686455c0 100644 --- a/backend/tests/test_token_usage_middleware.py +++ b/backend/tests/test_token_usage_middleware.py @@ -1,9 +1,10 @@ """Tests for TokenUsageMiddleware attribution annotations.""" +import importlib import logging from unittest.mock import MagicMock -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from deerflow.agents.middlewares.token_usage_middleware import ( TOKEN_USAGE_ATTRIBUTION_KEY, @@ -232,3 +233,49 @@ class TestTokenUsageMiddleware: "tool_call_id": "write_todos:remove", } ] + + def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch): + middleware = TokenUsageMiddleware() + first_dispatch = AIMessage( + content="", + tool_calls=[{"id": "task:first", "name": "task", "args": {}}], + ) + second_dispatch = AIMessage( + content="", + tool_calls=[ + {"id": "task:second-a", "name": "task", "args": {}}, + {"id": "task:second-b", "name": "task", "args": {}}, + ], + ) + messages = [ + first_dispatch, + ToolMessage(content="first", tool_call_id="task:first"), + second_dispatch, + ToolMessage(content="second-a", tool_call_id="task:second-a"), + ToolMessage(content="second-b", tool_call_id="task:second-b"), + AIMessage(content="done"), + ] + cached_usage = { + "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27}, + } + + task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool") + monkeypatch.setattr( + task_tool_module, + "pop_cached_subagent_usage", + lambda tool_call_id: cached_usage.pop(tool_call_id, None), + ) + + result = middleware.after_model({"messages": messages}, _make_runtime()) + + assert result is not None + usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)] + assert len(usage_updates) == 1 + updated = usage_updates[0] + assert updated.tool_calls == second_dispatch.tool_calls + assert updated.usage_metadata == { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + } diff --git a/backend/tests/test_tool_deduplication.py b/backend/tests/test_tool_deduplication.py index ed9efffaf..f018fc57d 100644 --- a/backend/tests/test_tool_deduplication.py +++ b/backend/tests/test_tool_deduplication.py @@ -65,8 +65,7 @@ def _make_minimal_config(tools): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg): +def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): """Config-loaded async-only tools can still be invoked by sync clients.""" async def async_tool_impl(x: int) -> str: @@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg): +def test_no_duplicates_returned(mock_bash, mock_cfg): """get_available_tools() never returns two tools with the same name.""" mock_cfg.return_value = _make_minimal_config([]) @@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg): +def test_first_occurrence_wins(mock_bash, mock_cfg): """When duplicates exist, the first occurrence is kept.""" mock_cfg.return_value = _make_minimal_config([]) @@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog): +def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog): """A warning is logged for every skipped duplicate.""" import logging diff --git a/backend/tests/test_update_agent_e2e_user_isolation.py b/backend/tests/test_update_agent_e2e_user_isolation.py new file mode 100644 index 000000000..7fa725352 --- /dev/null +++ b/backend/tests/test_update_agent_e2e_user_isolation.py @@ -0,0 +1,253 @@ +"""End-to-end verification for update_agent's user_id resolution. + +PR #2784 hardened setup_agent to prefer runtime.context["user_id"] over the +contextvar. update_agent had the same latent gap: it unconditionally called +get_effective_user_id() at module level, so any scenario where the contextvar +was unavailable while runtime.context carried user_id (a background task +scheduled outside the request task, a worker pool that doesn't copy_context, +checkpoint resume on a different task) would silently route writes to +users/default/agents/... + +These tests are load-bearing under @no_auto_user (contextvar empty): + +- The negative-control test confirms the fixture actually puts the tool in + the regime where the contextvar fallback would land in users/default/. + Without that, the positive test would be vacuously satisfied. +- The positive test verifies update_agent honours runtime.context["user_id"] + injected by inject_authenticated_user_context in the gateway. Before the + fix in this PR, this test failed; now it passes. +""" + +from __future__ import annotations + +from contextlib import ExitStack +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +import yaml +from _agent_e2e_helpers import build_single_tool_call_model +from langchain_core.messages import HumanMessage + +from app.gateway.services import ( + build_run_config, + inject_authenticated_user_context, + merge_run_context_overrides, +) +from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context + + +def _make_request(user_id_str: str | None) -> SimpleNamespace: + user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") if user_id_str else None + return SimpleNamespace(state=SimpleNamespace(user=user)) + + +def _assemble_config(*, body_context: dict | None, request_user_id: str | None, thread_id: str) -> dict: + config = build_run_config(thread_id, {"recursion_limit": 50}, None, assistant_id="lead_agent") + merge_run_context_overrides(config, body_context) + inject_authenticated_user_context(config, _make_request(request_user_id)) + return config + + +def _seed_existing_agent(tmp_path: Path, user_id: str, agent_name: str, soul: str = "# Original"): + """Pre-create an agent on disk for update_agent to overwrite.""" + agent_dir = tmp_path / "users" / user_id / "agents" / agent_name + agent_dir.mkdir(parents=True, exist_ok=True) + (agent_dir / "config.yaml").write_text( + yaml.dump({"name": agent_name, "description": "old"}, allow_unicode=True), + encoding="utf-8", + ) + (agent_dir / "SOUL.md").write_text(soul, encoding="utf-8") + return agent_dir + + +def _make_paths_mock(tmp_path: Path): + paths = MagicMock() + paths.base_dir = tmp_path + paths.agent_dir = lambda name: tmp_path / "agents" / name + paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name + return paths + + +def _patch_update_agent_dependencies(tmp_path: Path): + """update_agent reads load_agent_config + get_app_config — stub them + minimally so the tool can run without a real config file or LLM.""" + fake_model_cfg = SimpleNamespace(name="fake-model") + fake_app_cfg = MagicMock() + fake_app_cfg.get_model_config = lambda name: fake_model_cfg if name == "fake-model" else None + + return [ + patch( + "deerflow.tools.builtins.update_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ), + patch( + "deerflow.tools.builtins.update_agent_tool.get_app_config", + return_value=fake_app_cfg, + ), + # load_agent_config (used by update_agent to read existing config) also + # reads paths via its own module-level get_paths reference. Patch it too + # or the tool returns "Agent does not exist" before touching disk. + patch( + "deerflow.config.agents_config.get_paths", + return_value=_make_paths_mock(tmp_path), + ), + ] + + +def _build_update_graph(*, soul_payload: str): + from langchain.agents import create_agent + + from deerflow.tools.builtins.update_agent_tool import update_agent + + fake_model = build_single_tool_call_model( + tool_name="update_agent", + tool_args={"soul": soul_payload, "description": "refined"}, + tool_call_id="call_update_1", + final_text="updated", + ) + return create_agent(model=fake_model, tools=[update_agent], system_prompt="updater") + + +# --------------------------------------------------------------------------- +# Negative control — proves the test environment puts update_agent in the +# regime where the contextvar fallback would land in default/. +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +def test_update_agent_falls_back_to_default_when_no_inject_and_no_contextvar(tmp_path: Path): + """No request.state.user, no contextvar — update_agent must look in + users/default/agents/. We seed the file there so the tool succeeds and + we know which directory it actually consulted.""" + from langgraph.runtime import Runtime + + _seed_existing_agent(tmp_path, "default", "fallback-target") + + config = _assemble_config( + body_context={"agent_name": "fallback-target"}, + request_user_id=None, # no auth, inject is no-op + thread_id="thread-update-1", + ) + runtime_ctx = _build_runtime_context("thread-update-1", "run-1", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# Fallback Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + graph.invoke( + {"messages": [HumanMessage(content="update fallback-target")]}, + config=config, + ) + + soul = (tmp_path / "users" / "default" / "agents" / "fallback-target" / "SOUL.md").read_text() + assert soul == "# Fallback Updated", "Sanity: tool should have written under default/" + + +# --------------------------------------------------------------------------- +# Regression guard — passes on this branch, would fail on main before the fix. +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +def test_update_agent_should_use_runtime_context_user_id_when_contextvar_missing(tmp_path: Path): + """update_agent prefers the authenticated user_id carried in + runtime.context (placed there by inject_authenticated_user_context) + over the contextvar — same contract as setup_agent (PR #2784). + + Before this PR's fix, update_agent unconditionally called + get_effective_user_id() and landed in default/ whenever the contextvar + was unavailable. This test pins the corrected behaviour. + """ + from langgraph.runtime import Runtime + + auth_uid = "abcdef01-2345-6789-abcd-ef0123456789" + + # Seed the agent in BOTH locations so we can prove which one was opened. + auth_dir = _seed_existing_agent(tmp_path, auth_uid, "shared-name", soul="# Auth Original") + default_dir = _seed_existing_agent(tmp_path, "default", "shared-name", soul="# Default Original") + + config = _assemble_config( + body_context={"agent_name": "shared-name"}, + request_user_id=auth_uid, + thread_id="thread-update-2", + ) + runtime_ctx = _build_runtime_context("thread-update-2", "run-2", config.get("context"), None) + assert runtime_ctx["user_id"] == auth_uid, "Pre-condition: inject must have placed user_id into runtime_ctx" + + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# Auth Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + graph.invoke( + {"messages": [HumanMessage(content="update shared-name")]}, + config=config, + ) + + auth_soul = (auth_dir / "SOUL.md").read_text() + default_soul = (default_dir / "SOUL.md").read_text() + + assert auth_soul == "# Auth Updated", f"REGRESSION: update_agent ignored runtime.context['user_id']={auth_uid!r} and routed the write to users/default/ instead. auth_soul={auth_soul!r}, default_soul={default_soul!r}" + assert default_soul == "# Default Original", "REGRESSION: update_agent corrupted the shared default-user agent. It should have written under the authenticated user's path." + + +# --------------------------------------------------------------------------- +# Positive — when contextvar IS the auth user (the normal HTTP case), things +# already work. Pin it as a regression guard so future refactors don't +# accidentally break the contextvar path in pursuit of the runtime-context fix. +# --------------------------------------------------------------------------- + + +def test_update_agent_uses_contextvar_when_present(tmp_path: Path, monkeypatch): + """The normal HTTP case: contextvar is set by auth_middleware. This must + keep working regardless of how runtime.context is populated.""" + from types import SimpleNamespace as _SN + + from deerflow.runtime.user_context import reset_current_user, set_current_user + + auth_uid = "11112222-3333-4444-5555-666677778888" + user = _SN(id=auth_uid, email="ctxvar@local") + + _seed_existing_agent(tmp_path, auth_uid, "ctxvar-agent", soul="# Original") + + from langgraph.runtime import Runtime + + config = _assemble_config( + body_context={"agent_name": "ctxvar-agent"}, + request_user_id=auth_uid, + thread_id="thread-update-3", + ) + runtime_ctx = _build_runtime_context("thread-update-3", "run-3", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# CtxVar Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + token = set_current_user(user) + try: + final = graph.invoke( + {"messages": [HumanMessage(content="update ctxvar-agent")]}, + config=config, + ) + finally: + reset_current_user(token) + + # surface the tool's reply for debug if it errored + tool_replies = [m.content for m in final["messages"] if getattr(m, "type", "") == "tool"] + soul = (tmp_path / "users" / auth_uid / "agents" / "ctxvar-agent" / "SOUL.md").read_text() + assert soul == "# CtxVar Updated", f"tool replies: {tool_replies}" diff --git a/backend/uv.lock b/backend/uv.lock index e144fb07e..9cc2030fa 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -763,6 +763,9 @@ dependencies = [ ] [package.optional-dependencies] +discord = [ + { name = "discord-py" }, +] postgres = [ { name = "deerflow-harness", extra = ["postgres"] }, ] @@ -781,6 +784,7 @@ requires-dist = [ { name = "deerflow-harness", editable = "packages/harness" }, { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" }, { name = "dingtalk-stream", specifier = ">=0.24.3" }, + { name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" }, { name = "email-validator", specifier = ">=2.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "httpx", specifier = ">=0.28.0" }, @@ -795,7 +799,7 @@ requires-dist = [ { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" }, { name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" }, ] -provides-extras = ["postgres"] +provides-extras = ["postgres", "discord"] [package.metadata.requires-dev] dev = [ @@ -923,6 +927,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" }, ] +[[package]] +name = "discord-py" +version = "2.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" }, +] + [[package]] name = "distro" version = "1.9.0" @@ -2005,7 +2022,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.7.36" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -2018,9 +2035,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" }, + { url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" }, ] [package.optional-dependencies] diff --git a/config.example.yaml b/config.example.yaml index 9a8d07bf4..7396f6cfb 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1029,6 +1029,14 @@ run_events: # client_secret: $DINGTALK_CLIENT_SECRET # allowed_users: [] # empty = allow all # card_template_id: "" # Optional: AI Card template ID for streaming updates +# +# discord: +# enabled: false +# bot_token: $DISCORD_BOT_TOKEN +# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID +# mention_only: false # If true, only respond when the bot is mentioned +# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention) +# thread_mode: false # If true, group a channel conversation into a thread # ============================================================================ # Guardrails Configuration diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf index 45be0ab97..18481adb3 100644 --- a/docker/nginx/nginx.conf +++ b/docker/nginx/nginx.conf @@ -28,6 +28,10 @@ http { set $gateway_upstream gateway:8001; set $frontend_upstream frontend:3000; + # Default proxy settings for all locations (streaming/SSE support) + proxy_buffering off; + proxy_cache off; + # Keep the unified nginx endpoint same-origin by default. When split # frontend/backend or port-forwarded deployments need browser CORS, # configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and @@ -49,8 +53,6 @@ http { proxy_set_header Connection ''; # SSE/Streaming support - proxy_buffering off; - proxy_cache off; proxy_set_header X-Accel-Buffering no; # Timeouts for long-running requests @@ -70,6 +72,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Memory endpoint @@ -80,6 +83,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: MCP configuration endpoint @@ -90,6 +94,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Skills configuration endpoint @@ -100,6 +105,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Agents endpoint @@ -110,6 +116,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Uploads endpoint @@ -124,6 +131,8 @@ http { # Large file upload support client_max_body_size 100M; proxy_request_buffering off; + + # Disable response buffering to avoid permission errors } # Custom API: Other endpoints under /api/threads @@ -134,6 +143,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # API Documentation: Swagger UI @@ -144,6 +154,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # API Documentation: ReDoc @@ -154,6 +165,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # API Documentation: OpenAPI Schema @@ -164,6 +176,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Health check endpoint (gateway) @@ -174,6 +187,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # ── Provisioner API (sandbox management) ──────────────────────── @@ -187,6 +201,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*). @@ -198,6 +213,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + # Disable buffering to avoid permission errors when nginx + # runs as a non-root user (e.g. local development). } # All other requests go to frontend @@ -220,4 +238,4 @@ http { proxy_read_timeout 600s; } } -} +} \ No newline at end of file diff --git a/docker/nginx/nginx.local.conf b/docker/nginx/nginx.local.conf index 68ca1f1ac..035406862 100644 --- a/docker/nginx/nginx.local.conf +++ b/docker/nginx/nginx.local.conf @@ -70,6 +70,11 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + # Disable buffering to avoid permission errors when nginx + # runs as a non-root user (e.g. local development). + proxy_buffering off; + proxy_cache off; } # Custom API: Memory endpoint @@ -80,6 +85,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: MCP configuration endpoint @@ -90,6 +98,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: Skills configuration endpoint @@ -100,6 +111,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: Agents endpoint @@ -110,6 +124,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: Uploads endpoint @@ -124,6 +141,10 @@ http { # Large file upload support client_max_body_size 100M; proxy_request_buffering off; + + # Disable response buffering to avoid permission errors + proxy_buffering off; + proxy_cache off; } # Custom API: Other endpoints under /api/threads @@ -134,6 +155,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # API Documentation: Swagger UI @@ -144,6 +168,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # API Documentation: ReDoc @@ -154,6 +181,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # API Documentation: OpenAPI Schema @@ -164,6 +194,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Health check endpoint (gateway) @@ -174,6 +207,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Catch-all for any /api/* prefix not matched by a more specific block above. @@ -193,6 +229,11 @@ http { # Auth endpoints set HttpOnly cookies — make sure nginx doesn't # strip the Set-Cookie header from upstream responses. proxy_pass_header Set-Cookie; + + # Disable buffering to avoid permission errors when nginx + # runs as a non-root user (e.g. local development). + proxy_buffering off; + proxy_cache off; } # All other requests go to frontend diff --git a/frontend/README.md b/frontend/README.md index 6db881301..4ad70fb1f 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -82,10 +82,10 @@ pnpm start Key environment variables (see `.env.example` for full list): ```bash -# Backend API URLs (optional, uses nginx proxy by default) +# Backend API URL (optional, uses local Next.js/nginx proxy by default) NEXT_PUBLIC_BACKEND_BASE_URL="http://localhost:8001" -# LangGraph API URLs (optional, uses nginx proxy by default) -NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:2024" +# LangGraph-compatible API URL (optional, uses local Next.js/nginx proxy by default) +NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:8001/api" ``` ## Project Structure diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index 8627762b0..c16af882a 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -66,6 +66,7 @@ export default function AgentChatPage() { thread, pendingUsageMessages, sendMessage, + isUploading, isHistoryLoading, hasMoreHistory, loadMoreHistory, @@ -106,7 +107,11 @@ export default function AgentChatPage() { const handleSubmit = useCallback( (message: PromptInputMessage) => { - void sendMessage(threadId, message, { agent_name }); + const sendPromise = sendMessage(threadId, message, { agent_name }); + if (message.files.length > 0) { + return sendPromise; + } + void sendPromise; }, [sendMessage, threadId, agent_name], ); @@ -243,7 +248,10 @@ export default function AgentChatPage() { ) } - disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"} + disabled={ + env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" || + isUploading + } onContextChange={(context) => setSettings("context", context)} onSubmit={handleSubmit} onStop={handleStop} diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index ed7d91c68..6f865ade8 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -109,7 +109,11 @@ export default function ChatPage() { const handleSubmit = useCallback( (message: PromptInputMessage) => { - void sendMessage(threadId, message); + const sendPromise = sendMessage(threadId, message); + if (message.files.length > 0) { + return sendPromise; + } + void sendPromise; }, [sendMessage, threadId], ); diff --git a/frontend/src/components/ai-elements/prompt-input.tsx b/frontend/src/components/ai-elements/prompt-input.tsx index 52a909cdd..4609c43d3 100644 --- a/frontend/src/components/ai-elements/prompt-input.tsx +++ b/frontend/src/components/ai-elements/prompt-input.tsx @@ -499,6 +499,10 @@ export const PromptInput = ({ // Keep a ref to files for cleanup on unmount (avoids stale closure) const filesRef = useRef(files); filesRef.current = files; + const providerTextRef = useRef(""); + if (usingProvider) { + providerTextRef.current = controller.textInput.value; + } const openFileDialogLocal = useCallback(() => { inputRef.current?.click(); @@ -768,6 +772,24 @@ export const PromptInput = ({ } // Convert blob URLs to data URLs asynchronously + const submittedFileIds = files.map((file) => file.id); + const clearSubmittedState = () => { + const currentFileIds = new Set(filesRef.current.map((file) => file.id)); + const submittedFileIdsStillPresent = submittedFileIds.filter((id) => + currentFileIds.has(id), + ); + if (submittedFileIdsStillPresent.length === filesRef.current.length) { + clear(); + } else { + for (const id of submittedFileIdsStillPresent) { + remove(id); + } + } + if (usingProvider && providerTextRef.current === text) { + controller.textInput.clear(); + } + }; + Promise.all( files.map(async ({ id, ...item }) => { if (item.file instanceof File) { @@ -793,20 +815,14 @@ export const PromptInput = ({ if (result instanceof Promise) { result .then(() => { - clear(); - if (usingProvider) { - controller.textInput.clear(); - } + clearSubmittedState(); }) .catch(() => { // Don't clear on error - user may want to retry }); } else { // Sync function completed without throwing, clear attachments - clear(); - if (usingProvider) { - controller.textInput.clear(); - } + clearSubmittedState(); } } catch { // Don't clear on error - user may want to retry diff --git a/frontend/src/components/workspace/input-box.tsx b/frontend/src/components/workspace/input-box.tsx index 9a33d41e6..6344a26d2 100644 --- a/frontend/src/components/workspace/input-box.tsx +++ b/frontend/src/components/workspace/input-box.tsx @@ -110,6 +110,7 @@ export function InputBox({ threadId, initialValue, onContextChange, + onFollowupsVisibilityChange, onSubmit, onStop, ...props @@ -142,7 +143,8 @@ export function InputBox({ reasoning_effort?: "minimal" | "low" | "medium" | "high"; }, ) => void; - onSubmit?: (message: PromptInputMessage) => void; + onFollowupsVisibilityChange?: (visible: boolean) => void; + onSubmit?: (message: PromptInputMessage) => void | Promise; onStop?: () => void; }) { const { t } = useI18n(); @@ -251,12 +253,12 @@ export function InputBox({ ); const handleSubmit = useCallback( - async (message: PromptInputMessage) => { + (message: PromptInputMessage) => { if (status === "streaming") { onStop?.(); return; } - if (!message.text) { + if (!message.text.trim() && message.files.length === 0) { return; } setFollowups([]); @@ -274,11 +276,14 @@ export function InputBox({ selectedModel?.supports_thinking ?? false, ), }); - setTimeout(() => onSubmit?.(message), 0); - return; + return new Promise((resolve, reject) => { + setTimeout(() => { + Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject); + }, 0); + }); } - onSubmit?.(message); + return onSubmit?.(message); }, [ context, @@ -348,6 +353,14 @@ export function InputBox({ !followupsHidden && (followupsLoading || followups.length > 0); + useEffect(() => { + onFollowupsVisibilityChange?.(showFollowups); + }, [onFollowupsVisibilityChange, showFollowups]); + + useEffect(() => { + return () => onFollowupsVisibilityChange?.(false); + }, [onFollowupsVisibilityChange]); + useEffect(() => { messagesRef.current = thread.messages; }, [thread.messages]); diff --git a/frontend/src/components/workspace/messages/message-token-usage.tsx b/frontend/src/components/workspace/messages/message-token-usage.tsx index cc8d0debb..84f8a8057 100644 --- a/frontend/src/components/workspace/messages/message-token-usage.tsx +++ b/frontend/src/components/workspace/messages/message-token-usage.tsx @@ -12,13 +12,11 @@ function TokenUsageSummary({ inputTokens, outputTokens, totalTokens, - unavailable = false, }: { className?: string; inputTokens?: number; outputTokens?: number; totalTokens?: number; - unavailable?: boolean; }) { const { t } = useI18n(); @@ -33,21 +31,15 @@ function TokenUsageSummary({ {t.tokenUsage.label} - {!unavailable ? ( - <> - - {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} - - - {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} - - - {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} - - - ) : ( - {t.tokenUsage.unavailableShort} - )} + + {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} + + + {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} + + + {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} + ); } @@ -55,7 +47,7 @@ function TokenUsageSummary({ export function MessageTokenUsageList({ className, enabled = false, - isLoading = false, + isLoading: _isLoading = false, messages, }: { className?: string; @@ -63,7 +55,7 @@ export function MessageTokenUsageList({ isLoading?: boolean; messages: Message[]; }) { - if (!enabled || isLoading) { + if (!enabled) { return null; } @@ -75,13 +67,16 @@ export function MessageTokenUsageList({ const usage = accumulateUsage(aiMessages); + if (!usage) { + return null; + } + return ( ); } diff --git a/frontend/src/content/en/application/agents-and-threads.mdx b/frontend/src/content/en/application/agents-and-threads.mdx index bbf3cfc7e..0a281a33e 100644 --- a/frontend/src/content/en/application/agents-and-threads.mdx +++ b/frontend/src/content/en/application/agents-and-threads.mdx @@ -111,10 +111,9 @@ checkpointer: ``` - The LangGraph Server manages its own state separately. The - checkpointer setting in config.yaml applies to the - embedded DeerFlowClient (used in direct Python integrations), not - to the LangGraph Server deployment used by DeerFlow App. + The Gateway embedded runtime uses the checkpointer setting in + config.yaml. The same setting is also used by + DeerFlowClient in direct Python integrations. ### Thread data storage diff --git a/frontend/src/content/en/application/deployment-guide.mdx b/frontend/src/content/en/application/deployment-guide.mdx index 04b3599c0..52b59cf01 100644 --- a/frontend/src/content/en/application/deployment-guide.mdx +++ b/frontend/src/content/en/application/deployment-guide.mdx @@ -23,8 +23,7 @@ Services started: | Service | Port | Description | | ----------- | ---- | ------------------------ | -| LangGraph | 2024 | DeerFlow Harness runtime | -| Gateway API | 8001 | FastAPI backend | +| Gateway API | 8001 | FastAPI backend + embedded agent runtime | | Frontend | 3000 | Next.js UI | | nginx | 2026 | Unified reverse proxy | @@ -36,13 +35,12 @@ Access the app at **http://localhost:2026**. make stop ``` -Stops all four services. Safe to run even if a service is not running. +Stops all services. Safe to run even if a service is not running. ``` -logs/langgraph.log # Agent runtime logs -logs/gateway.log # API gateway logs +logs/gateway.log # API gateway and agent runtime logs logs/frontend.log # Next.js dev server logs logs/nginx.log # nginx access/error logs ``` @@ -50,7 +48,7 @@ logs/nginx.log # nginx access/error logs Tail a log in real time: ```bash -tail -f logs/langgraph.log +tail -f logs/gateway.log ``` @@ -74,7 +72,7 @@ export DEER_FLOW_ROOT=/path/to/deer-flow docker compose -f docker/docker-compose-dev.yaml up --build ``` -Services: nginx, frontend, gateway, langgraph, and optionally provisioner (for K8s-managed sandboxes). +Services: nginx, frontend, gateway, and optionally provisioner (for K8s-managed sandboxes). Access the app at **http://localhost:2026**. @@ -99,7 +97,7 @@ The `docker-compose*.yaml` files include an `env_file: ../.env` directive that l ### Data persistence -Thread data is stored in `backend/.deer-flow/threads/`. In Docker deployments, this directory is bind-mounted into the langgraph container. +Thread data is stored in `backend/.deer-flow/threads/`. In Docker deployments, this directory is bind-mounted into the gateway container. To avoid data loss when containers are recreated: @@ -161,14 +159,7 @@ When `USERDATA_PVC_NAME` is set, the provisioner automatically uses subPath (`th ### nginx configuration -nginx routes all traffic. Key environment variables that control routing: - -| Variable | Default | Description | -| -------------------- | ---------------- | --------------------------------------- | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | LangGraph service address | -| `LANGGRAPH_REWRITE` | `/` | URL rewrite prefix for LangGraph routes | - -These are set in the Docker Compose environment and processed by `envsubst` at container startup. +nginx routes all traffic to the frontend or Gateway. `/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes, so no separate LangGraph upstream is required. ### Authentication @@ -186,8 +177,7 @@ openssl rand -base64 32 | Service | Minimum | Recommended | | ------------------------------- | ---------------- | ---------------- | -| LangGraph (agent runtime) | 2 vCPU, 4 GB RAM | 4 vCPU, 8 GB RAM | -| Gateway | 0.5 vCPU, 512 MB | 1 vCPU, 1 GB | +| Gateway + agent runtime | 2 vCPU, 4 GB RAM | 4 vCPU, 8 GB RAM | | Frontend | 0.5 vCPU, 512 MB | 1 vCPU, 1 GB | | Sandbox container (per session) | 1 vCPU, 1 GB | 2 vCPU, 2 GB | @@ -199,9 +189,6 @@ After starting, verify the deployment: # Check Gateway health curl http://localhost:8001/health -# Check LangGraph health -curl http://localhost:2024/ok - # List configured models (through nginx) curl http://localhost:2026/api/models ``` diff --git a/frontend/src/content/en/application/index.mdx b/frontend/src/content/en/application/index.mdx index 2cb15a911..b45a6cbf0 100644 --- a/frontend/src/content/en/application/index.mdx +++ b/frontend/src/content/en/application/index.mdx @@ -25,11 +25,11 @@ DeerFlow App is the reference implementation of what a production DeerFlow exper | **Streaming responses** | Real-time token streaming with thinking steps and tool call visibility | | **Artifact viewer** | In-browser preview and download of files and outputs produced by the agent | | **Extensions UI** | Enable/disable MCP servers and skills without editing config files | -| **Gateway API** | FastAPI-based REST API that bridges the frontend and the LangGraph runtime | +| **Gateway API** | FastAPI-based REST API with the embedded LangGraph-compatible agent runtime | ## Architecture -The DeerFlow App runs as four services behind a single nginx reverse proxy: +The DeerFlow App runs behind a single nginx reverse proxy: ``` ┌──────────────────┐ @@ -42,19 +42,11 @@ The DeerFlow App runs as four services behind a single nginx reverse proxy: │ Frontend :3000 │ │ Gateway API :8001 │ │ (Next.js) │ │ (FastAPI) │ └──────────────────┘ └──────────────────────┘ - │ - ┌─────────┘ - ▼ - ┌──────────────────────┐ - │ LangGraph :2024 │ - │ (DeerFlow Harness) │ - └──────────────────────┘ ``` -- **nginx**: routes requests — `/api/*` to the Gateway, LangGraph streaming endpoints to LangGraph directly, and everything else to the frontend. -- **Frontend** (Next.js + React): the browser UI. Communicates with both the Gateway and LangGraph. -- **Gateway** (FastAPI): handles API operations — model listing, agent CRUD, memory, extensions management, file uploads. -- **LangGraph**: the DeerFlow Harness runtime. Manages thread state, agent execution, and streaming. +- **nginx**: routes requests — `/api/*` and `/api/langgraph/*` to Gateway, and everything else to the frontend. +- **Frontend** (Next.js + React): the browser UI. Communicates with Gateway. +- **Gateway** (FastAPI): handles API operations and the embedded LangGraph-compatible runtime for thread state, agent execution, and streaming. ## Technology stack @@ -64,7 +56,7 @@ The DeerFlow App runs as four services behind a single nginx reverse proxy: | Gateway | FastAPI, Python 3.12, uvicorn | | Agent runtime | LangGraph, LangChain, DeerFlow Harness | | Reverse proxy | nginx | -| State persistence | LangGraph Server (default) + optional SQLite/PostgreSQL checkpointer | +| State persistence | Gateway runtime + optional SQLite/PostgreSQL checkpointer | diff --git a/frontend/src/content/en/application/operations-and-troubleshooting.mdx b/frontend/src/content/en/application/operations-and-troubleshooting.mdx index 8b21cf4b4..0f8d7e44c 100644 --- a/frontend/src/content/en/application/operations-and-troubleshooting.mdx +++ b/frontend/src/content/en/application/operations-and-troubleshooting.mdx @@ -15,15 +15,13 @@ All services write logs to the `logs/` directory when started with `make dev`: | File | Service | | -------------------- | ------------------------------------ | -| `logs/langgraph.log` | LangGraph / DeerFlow Harness runtime | -| `logs/gateway.log` | FastAPI Gateway API | +| `logs/gateway.log` | FastAPI Gateway API and agent runtime | | `logs/frontend.log` | Next.js frontend dev server | | `logs/nginx.log` | nginx reverse proxy | Tail logs in real time: ```bash -tail -f logs/langgraph.log tail -f logs/gateway.log ``` @@ -41,9 +39,6 @@ Verify each service is responding: # Gateway health curl http://localhost:8001/health -# LangGraph health -curl http://localhost:2024/ok - # Through nginx (verifies full proxy chain) curl http://localhost:2026/api/models ``` @@ -66,7 +61,7 @@ grep config_version config.yaml ### The app loads but the agent doesn't respond -1. Check `logs/langgraph.log` for startup errors. +1. Check `logs/gateway.log` for startup errors. 2. Verify your model is correctly configured in `config.yaml` with a valid API key. 3. Confirm the API key environment variable is set in the shell that ran `make dev`. 4. Test the model endpoint directly with `curl` to rule out network issues. @@ -126,7 +121,7 @@ Connection refused: http://provisioner:8002 If MCP tools appear in `extensions_config.json` but are not available in the agent: -1. Check `logs/langgraph.log` for MCP initialization errors. +1. Check `logs/gateway.log` for MCP initialization errors. 2. Verify the MCP server command is installed (`npx`, `uvx`, or the relevant binary). 3. Test the server command manually to confirm it starts without errors. 4. Set `log_level: debug` to see detailed MCP loading output. @@ -137,7 +132,7 @@ If MCP tools appear in `extensions_config.json` but are not available in the age - Verify `memory.enabled: true` in `config.yaml`. - Check that the storage path is writable: `ls -la backend/.deer-flow/`. -- Look for memory update errors in `logs/langgraph.log` (search for "memory"). +- Look for memory update errors in `logs/gateway.log` (search for "memory"). ## Data backup diff --git a/frontend/src/content/en/application/quick-start.mdx b/frontend/src/content/en/application/quick-start.mdx index 5ecfb3a26..c3baa0764 100644 --- a/frontend/src/content/en/application/quick-start.mdx +++ b/frontend/src/content/en/application/quick-start.mdx @@ -1,6 +1,6 @@ --- title: Quick Start -description: This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. All four services (LangGraph, Gateway, Frontend, nginx) start together and are accessible through a single URL. +description: This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. Gateway, Frontend, and nginx start together and are accessible through a single URL. --- import { Callout, Cards, Steps } from "nextra/components"; @@ -12,7 +12,7 @@ import { Callout, Cards, Steps } from "nextra/components"; Python 3.12+, Node.js 22+, and at least one LLM API key. -This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. All four services (LangGraph, Gateway, Frontend, nginx) start together and are accessible through a single URL. +This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. Gateway, Frontend, and nginx start together and are accessible through a single URL. ## Prerequisites @@ -88,8 +88,7 @@ make dev This starts: -- LangGraph server on port `2024` -- Gateway API on port `8001` +- Gateway API and embedded agent runtime on port `8001` - Frontend on port `3000` - nginx reverse proxy on port `2026` @@ -113,15 +112,13 @@ Log files: | Service | Log file | | --------- | -------------------- | -| LangGraph | `logs/langgraph.log` | | Gateway | `logs/gateway.log` | | Frontend | `logs/frontend.log` | | nginx | `logs/nginx.log` | If something is not working, check the log files first. Most startup errors - (missing API keys, config parsing failures) appear in `logs/langgraph.log` or - `logs/gateway.log`. + (missing API keys, config parsing failures) appear in `logs/gateway.log`. diff --git a/frontend/src/content/en/harness/skills.mdx b/frontend/src/content/en/harness/skills.mdx index 09f8b0d43..78247c40b 100644 --- a/frontend/src/content/en/harness/skills.mdx +++ b/frontend/src/content/en/harness/skills.mdx @@ -68,7 +68,7 @@ DeerFlow ships with the following public skills: ### Discovery and loading -`load_skills()` in `skills/loader.py` scans both `public/` and `custom/` directories under the configured skills path. It re-reads `ExtensionsConfig.from_file()` on every call, which means enabling or disabling a skill through the Gateway API takes effect immediately in the running LangGraph server without a restart. +`load_skills()` in `skills/loader.py` scans both `public/` and `custom/` directories under the configured skills path. It re-reads `ExtensionsConfig.from_file()` on every call, which means enabling or disabling a skill through the Gateway API takes effect immediately in the running agent runtime without a restart. ### Parsing diff --git a/frontend/src/content/zh/application/configuration.mdx b/frontend/src/content/zh/application/configuration.mdx index 639eeaec5..0094323e7 100644 --- a/frontend/src/content/zh/application/configuration.mdx +++ b/frontend/src/content/zh/application/configuration.mdx @@ -215,7 +215,6 @@ BETTER_AUTH_SECRET=local-dev-secret-at-least-32-chars | `DEER_FLOW_CONFIG_PATH` | 自动发现 | `config.yaml` 的绝对路径 | | `LOG_LEVEL` | `info` | 日志详细程度(`debug`/`info`/`warning`/`error`) | | `DEER_FLOW_ROOT` | 仓库根目录 | 用于 Docker 中的技能和线程挂载 | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | nginx 代理的 LangGraph 地址 | diff --git a/frontend/src/content/zh/application/deployment-guide.mdx b/frontend/src/content/zh/application/deployment-guide.mdx index 59eceece2..635120337 100644 --- a/frontend/src/content/zh/application/deployment-guide.mdx +++ b/frontend/src/content/zh/application/deployment-guide.mdx @@ -23,8 +23,7 @@ make dev | 服务 | 端口 | 描述 | | ----------- | ---- | ----------------------- | -| LangGraph | 2024 | DeerFlow Harness 运行时 | -| Gateway API | 8001 | FastAPI 后端 | +| Gateway API | 8001 | FastAPI 后端 + 嵌入式 Agent 运行时 | | 前端 | 3000 | Next.js 界面 | | nginx | 2026 | 统一反向代理 | @@ -36,13 +35,12 @@ make dev make stop ``` -停止所有四个服务。即使某个服务没有运行也可以安全执行。 +停止所有服务。即使某个服务没有运行也可以安全执行。 ``` -logs/langgraph.log # Agent 运行时日志 -logs/gateway.log # API Gateway 日志 +logs/gateway.log # API Gateway 和 Agent 运行时日志 logs/frontend.log # Next.js 开发服务器日志 logs/nginx.log # nginx 访问/错误日志 ``` @@ -50,7 +48,7 @@ logs/nginx.log # nginx 访问/错误日志 实时追踪日志: ```bash -tail -f logs/langgraph.log +tail -f logs/gateway.log ``` @@ -96,7 +94,7 @@ BETTER_AUTH_SECRET=your-secret-here-min-32-chars ### 数据持久化 -线程数据存储在 `backend/.deer-flow/threads/`。在 Docker 部署中,此目录被绑定挂载到 langgraph 容器中。 +线程数据存储在 `backend/.deer-flow/threads/`。在 Docker 部署中,此目录会绑定挂载到 gateway 容器中。 为避免容器重建时数据丢失: @@ -156,14 +154,7 @@ SKILLS_PVC_NAME=deer-flow-skills-pvc ### nginx 配置 -nginx 路由所有流量,控制路由的关键环境变量: - -| 变量 | 默认值 | 描述 | -| -------------------- | ---------------- | ----------------------------- | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | LangGraph 服务地址 | -| `LANGGRAPH_REWRITE` | `/` | LangGraph 路由的 URL 重写前缀 | - -这些在 Docker Compose 环境中设置,并在容器启动时由 `envsubst` 处理。 +nginx 将流量路由到前端或 Gateway。`/api/langgraph/*` 会被重写到 Gateway 的 LangGraph-compatible `/api/*` 路由,因此不需要单独的 LangGraph upstream。 ### 认证配置 @@ -181,8 +172,7 @@ openssl rand -base64 32 | 服务 | 最低配置 | 推荐配置 | | ------------------------- | ---------------- | ---------------- | -| LangGraph(Agent 运行时) | 2 vCPU、4 GB RAM | 4 vCPU、8 GB RAM | -| Gateway | 0.5 vCPU、512 MB | 1 vCPU、1 GB | +| Gateway + Agent 运行时 | 2 vCPU、4 GB RAM | 4 vCPU、8 GB RAM | | 前端 | 0.5 vCPU、512 MB | 1 vCPU、1 GB | | 沙箱容器(每会话) | 1 vCPU、1 GB | 2 vCPU、2 GB | @@ -194,9 +184,6 @@ openssl rand -base64 32 # 检查 Gateway 健康状态 curl http://localhost:8001/health -# 检查 LangGraph 健康状态 -curl http://localhost:2024/ok - # 通过 nginx 列出配置的模型(验证完整代理链) curl http://localhost:2026/api/models ``` diff --git a/frontend/src/content/zh/application/index.mdx b/frontend/src/content/zh/application/index.mdx index 81e7113e2..c12959b42 100644 --- a/frontend/src/content/zh/application/index.mdx +++ b/frontend/src/content/zh/application/index.mdx @@ -25,11 +25,11 @@ DeerFlow 应用是 DeerFlow 生产体验的参考实现。它将 Harness 运行 | **流式响应** | 实时 token 流式传输,带思考步骤和工具调用可见性 | | **产出物查看器** | Agent 生成文件和输出的浏览器内预览和下载 | | **扩展界面** | 无需编辑配置文件即可启用/禁用 MCP 服务器和技能 | -| **Gateway API** | 桥接前端和 LangGraph 运行时的基于 FastAPI 的 REST API | +| **Gateway API** | 基于 FastAPI 的 REST API,并内置 LangGraph-compatible Agent 运行时 | ## 架构 -DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理提供: +DeerFlow 应用通过单个 nginx 反向代理提供: ``` ┌──────────────────┐ @@ -42,19 +42,11 @@ DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理 │ 前端 :3000 │ │ Gateway API :8001 │ │ (Next.js) │ │ (FastAPI) │ └──────────────────┘ └──────────────────────┘ - │ - ┌─────────┘ - ▼ - ┌──────────────────────┐ - │ LangGraph :2024 │ - │ (DeerFlow Harness) │ - └──────────────────────┘ ``` -- **nginx**:路由请求——`/api/*` 到 Gateway,LangGraph 流式端点到 LangGraph,其余到前端。 -- **前端**(Next.js + React):浏览器界面,与 Gateway 和 LangGraph 通信。 -- **Gateway**(FastAPI):处理 API 操作——模型列表、Agent CRUD、记忆、扩展管理、文件上传。 -- **LangGraph**:DeerFlow Harness 运行时,管理线程状态、Agent 执行和流式传输。 +- **nginx**:路由请求——`/api/*` 和 `/api/langgraph/*` 到 Gateway,其余到前端。 +- **前端**(Next.js + React):浏览器界面,与 Gateway 通信。 +- **Gateway**(FastAPI):处理 API 操作,并通过内置 LangGraph-compatible 运行时管理线程状态、Agent 执行和流式传输。 ## 技术栈 @@ -64,7 +56,7 @@ DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理 | Gateway | FastAPI、Python 3.12、uvicorn | | Agent 运行时 | LangGraph、LangChain、DeerFlow Harness | | 反向代理 | nginx | -| 状态持久化 | LangGraph Server(默认)+ 可选 SQLite/PostgreSQL 检查点 | +| 状态持久化 | Gateway 运行时 + 可选 SQLite/PostgreSQL 检查点 | diff --git a/frontend/src/content/zh/application/operations-and-troubleshooting.mdx b/frontend/src/content/zh/application/operations-and-troubleshooting.mdx index c047bbd5c..8dc4c6551 100644 --- a/frontend/src/content/zh/application/operations-and-troubleshooting.mdx +++ b/frontend/src/content/zh/application/operations-and-troubleshooting.mdx @@ -15,16 +15,14 @@ DeerFlow 应用在 `logs/` 目录中写入每个服务的日志: | 文件 | 内容 | | -------------------- | -------------------------------------- | -| `logs/langgraph.log` | Agent 运行时、工具调用、LangGraph 错误 | -| `logs/gateway.log` | API 请求/响应、Gateway 错误 | +| `logs/gateway.log` | API 请求/响应、Agent 运行时和 Gateway 错误 | | `logs/frontend.log` | Next.js 服务器日志 | | `logs/nginx.log` | 代理访问和错误日志 | **实时追踪日志**: ```bash -tail -f logs/langgraph.log # 查看 Agent 活动 -tail -f logs/gateway.log # 查看 API 请求 +tail -f logs/gateway.log # 查看 API 请求和 Agent 活动 ``` **调整日志级别**: @@ -42,9 +40,6 @@ DeerFlow 暴露健康检查端点: # Gateway 健康状态 curl http://localhost:8001/health -# LangGraph 健康状态 -curl http://localhost:2024/ok - # 通过 nginx 完整代理链验证 curl http://localhost:2026/api/models ``` @@ -68,8 +63,8 @@ make config-upgrade **诊断**: ```bash -# 检查 LangGraph 日志中的模型错误 -grep -i "error\|apikey\|unauthorized" logs/langgraph.log | tail -20 +# 检查 Gateway 日志中的模型错误 +grep -i "error\|apikey\|unauthorized" logs/gateway.log | tail -20 ``` **解决**: @@ -118,13 +113,13 @@ SKIP_ENV_VALIDATION=1 pnpm build ### MCP 服务器连接失败 -**症状**:MCP 工具未出现,`logs/langgraph.log` 中有超时错误。 +**症状**:MCP 工具未出现,`logs/gateway.log` 中有超时错误。 **诊断**: ```bash # 检查 MCP 相关错误 -grep -i "mcp\|timeout" logs/langgraph.log | tail -20 +grep -i "mcp\|timeout" logs/gateway.log | tail -20 ``` **解决**: diff --git a/frontend/src/content/zh/application/quick-start.mdx b/frontend/src/content/zh/application/quick-start.mdx index 5ccf117ad..b5ab052fc 100644 --- a/frontend/src/content/zh/application/quick-start.mdx +++ b/frontend/src/content/zh/application/quick-start.mdx @@ -1,6 +1,6 @@ --- title: 快速上手 -description: 本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。所有四个服务(LangGraph、Gateway、前端、nginx)一起启动,通过单个 URL 访问。 +description: 本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。Gateway、前端和 nginx 会一起启动,通过单个 URL 访问。 --- import { Callout, Cards, Steps } from "nextra/components"; @@ -12,7 +12,7 @@ import { Callout, Cards, Steps } from "nextra/components"; 3.12+、Node.js 22+ 的机器,以及至少一个 LLM API Key。 -本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。所有四个服务(LangGraph、Gateway、前端、nginx)一起启动,通过单个 URL 访问。 +本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。Gateway、前端和 nginx 会一起启动,通过单个 URL 访问。 ## 前置条件 @@ -88,8 +88,7 @@ make dev 这会启动: -- LangGraph 服务,端口 `2024` -- Gateway API,端口 `8001` +- Gateway API 和嵌入式 Agent 运行时,端口 `8001` - 前端,端口 `3000` - nginx 反向代理,端口 `2026` @@ -113,15 +112,13 @@ make stop | 服务 | 日志文件 | | --------- | -------------------- | -| LangGraph | `logs/langgraph.log` | | Gateway | `logs/gateway.log` | | 前端 | `logs/frontend.log` | | nginx | `logs/nginx.log` | 如果有问题,先检查日志文件。大多数启动错误(缺失 API - Key、配置解析失败)会出现在 logs/langgraph.log 或{" "} - logs/gateway.log 中。 + Key、配置解析失败)会出现在 logs/gateway.log 中。 diff --git a/frontend/src/core/messages/usage.ts b/frontend/src/core/messages/usage.ts index 4679dffa5..01e3a59e1 100644 --- a/frontend/src/core/messages/usage.ts +++ b/frontend/src/core/messages/usage.ts @@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null { return hasUsage ? cumulative : null; } -function hasNonZeroUsage( +export function hasNonZeroUsage( usage: TokenUsage | null | undefined, ): usage is TokenUsage { return ( @@ -75,7 +75,7 @@ function hasNonZeroUsage( ); } -function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { +export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { return { inputTokens: base.inputTokens + delta.inputTokens, outputTokens: base.outputTokens + delta.outputTokens, diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index e20daa1b6..3f1fef9ad 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -26,6 +26,13 @@ export type MessageGroup = | AssistantClarificationGroup | AssistantSubagentGroup; +const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([ + "summary", + "loop_warning", + "todo_reminder", + "todo_completion_reminder", +]); + export function getMessageGroups(messages: Message[]): MessageGroup[] { if (messages.length === 0) { return []; @@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] { continue; } - if (message.name === "todo_reminder") { - continue; - } - if (message.type === "human") { groups.push({ id: message.id, type: "human", messages: [message] }); continue; @@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) { export function isHiddenFromUIMessage(message: Message) { return ( message.additional_kwargs?.hide_from_ui === true || - message.name === "summary" || - message.name === "loop_warning" + (typeof message.name === "string" && + HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name)) ); } diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 0ac790eb2..fba3edd0c 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -45,15 +45,60 @@ type SendMessageOptions = { additionalKwargs?: Record; }; -function mergeMessages( +function isNonEmptyString(value: string | undefined): value is string { + return typeof value === "string" && value.length > 0; +} + +function messageIdentity(message: Message): string | undefined { + if ( + "tool_call_id" in message && + typeof message.tool_call_id === "string" && + message.tool_call_id.length > 0 + ) { + return `tool:${message.tool_call_id}`; + } + if (typeof message.id === "string" && message.id.length > 0) { + return `message:${message.id}`; + } + return undefined; +} + +function dedupeMessagesByIdentity(messages: Message[]): Message[] { + const lastIndexByIdentity = new Map(); + + messages.forEach((message, index) => { + const identity = messageIdentity(message); + if (identity) { + lastIndexByIdentity.set(identity, index); + } + }); + + return messages.filter((message, index) => { + const identity = messageIdentity(message); + return !identity || lastIndexByIdentity.get(identity) === index; + }); +} + +function findLatestUnloadedRunIndex( + runs: Run[], + loadedRunIds: ReadonlySet, +): number { + for (let i = runs.length - 1; i >= 0; i--) { + const run = runs[i]; + if (run && !loadedRunIds.has(run.run_id)) { + return i; + } + } + return -1; +} + +export function mergeMessages( historyMessages: Message[], threadMessages: Message[], optimisticMessages: Message[], ): Message[] { const threadMessageIds = new Set( - threadMessages - .map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id)) - .filter(Boolean), + threadMessages.map(messageIdentity).filter(isNonEmptyString), ); // The overlap is a contiguous suffix of historyMessages (newest history == oldest thread). @@ -65,28 +110,19 @@ function mergeMessages( if (!msg) { continue; } - if ( - (msg?.id && threadMessageIds.has(msg.id)) || - ("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id)) - ) { + const identity = messageIdentity(msg); + if (identity && threadMessageIds.has(identity)) { cutoff = i; } else { break; } } - return [ + return dedupeMessagesByIdentity([ ...historyMessages.slice(0, cutoff), ...threadMessages, ...optimisticMessages, - ]; -} - -function messageIdentity(message: Message): string | undefined { - if ("tool_call_id" in message) { - return message.tool_call_id; - } - return message.id; + ]); } function getMessagesAfterBaseline( @@ -296,7 +332,11 @@ export function useThreadStream({ onError(error) { setOptimisticMessages([]); toast.error(getStreamErrorMessage(error)); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ queryKey: threadTokenUsageQueryKey(threadIdRef.current), @@ -305,7 +345,11 @@ export function useThreadStream({ }, onFinish(state) { listeners.current.onFinish?.(state.values); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ @@ -339,7 +383,11 @@ export function useThreadStream({ useEffect(() => { startedRef.current = false; sendInFlightRef.current = false; - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); prevHumanMsgCountRef.current = latestMessageCountsRef.current.humanMessageCount; }, [threadId]); @@ -615,48 +663,105 @@ export function useThreadHistory(threadId: string) { const runsRef = useRef(runs.data ?? []); const indexRef = useRef(-1); const loadingRef = useRef(false); + const pendingLoadRef = useRef(false); + const loadingRunIdRef = useRef(null); + const loadedRunIdsRef = useRef>(new Set()); const [loading, setLoading] = useState(false); const [messages, setMessages] = useState([]); - loadingRef.current = loading; const loadMessages = useCallback(async () => { + if (loadingRef.current) { + const pendingRunIndex = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + const pendingRun = runsRef.current[pendingRunIndex]; + if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) { + pendingLoadRef.current = true; + } + return; + } if (runsRef.current.length === 0) { return; } - const run = runsRef.current[indexRef.current]; - if (!run || loadingRef.current) { - return; - } + + loadingRef.current = true; + setLoading(true); + try { - setLoading(true); - const result: { data: RunMessage[]; hasMore: boolean } = await fetch( - `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`, - { - method: "GET", - headers: { - "Content-Type": "application/json", + do { + pendingLoadRef.current = false; + + const nextRunIndex = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + indexRef.current = nextRunIndex; + + const run = runsRef.current[nextRunIndex]; + if (!run) { + indexRef.current = -1; + return; + } + + const requestThreadId = threadIdRef.current; + loadingRunIdRef.current = run.run_id; + const result: { data: RunMessage[]; hasMore: boolean } = await fetch( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`, + { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + credentials: "include", }, - credentials: "include", - }, - ).then((res) => { - return res.json(); - }); - const _messages = result.data - .filter((m) => !m.metadata.caller?.startsWith("middleware:")) - .map((m) => m.content); - setMessages((prev) => [..._messages, ...prev]); - indexRef.current -= 1; + ).then((res) => { + return res.json(); + }); + const _messages = result.data + .filter((m) => !m.metadata.caller?.startsWith("middleware:")) + .map((m) => m.content); + if (threadIdRef.current !== requestThreadId) { + return; + } + setMessages((prev) => + dedupeMessagesByIdentity([..._messages, ...prev]), + ); + loadedRunIdsRef.current.add(run.run_id); + indexRef.current = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + } while (pendingLoadRef.current); } catch (err) { console.error(err); } finally { + loadingRef.current = false; + loadingRunIdRef.current = null; setLoading(false); } }, []); useEffect(() => { + const threadChanged = threadIdRef.current !== threadId; threadIdRef.current = threadId; + + if (threadChanged) { + runsRef.current = []; + indexRef.current = -1; + pendingLoadRef.current = false; + loadingRunIdRef.current = null; + loadedRunIdsRef.current = new Set(); + loadingRef.current = false; + setLoading(false); + setMessages([]); + } + if (runs.data && runs.data.length > 0) { runsRef.current = runs.data ?? []; - indexRef.current = runs.data.length - 1; + indexRef.current = findLatestUnloadedRunIndex( + runs.data, + loadedRunIdsRef.current, + ); } loadMessages().catch(() => { toast.error("Failed to load thread history."); @@ -665,7 +770,7 @@ export function useThreadHistory(threadId: string) { const appendMessages = useCallback((_messages: Message[]) => { setMessages((prev) => { - return [...prev, ..._messages]; + return dedupeMessagesByIdentity([...prev, ..._messages]); }); }, []); const hasMore = indexRef.current >= 0 || !runs.data; diff --git a/frontend/tests/e2e/chat.spec.ts b/frontend/tests/e2e/chat.spec.ts index 490305de9..e608793df 100644 --- a/frontend/tests/e2e/chat.spec.ts +++ b/frontend/tests/e2e/chat.spec.ts @@ -48,4 +48,66 @@ test.describe("Chat workspace", () => { timeout: 10_000, }); }); + + test("keeps attachments visible while upload submit is pending", async ({ + page, + }) => { + let releaseUpload!: () => void; + const uploadCanFinish = new Promise((resolve) => { + releaseUpload = resolve; + }); + let uploadStarted!: () => void; + const uploadStartedPromise = new Promise((resolve) => { + uploadStarted = resolve; + }); + + await page.route("**/api/threads/*/uploads", async (route) => { + uploadStarted(); + await uploadCanFinish; + return route.fulfill({ + status: 200, + contentType: "application/json", + body: JSON.stringify({ + success: true, + message: "Uploaded", + files: [ + { + filename: "report.docx", + size: 12, + path: "report.docx", + virtual_path: "/mnt/user-data/uploads/report.docx", + artifact_url: "/api/threads/test/uploads/report.docx", + extension: ".docx", + }, + ], + }), + }); + }); + + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + const promptForm = page.locator("form").filter({ has: textarea }); + + await page.getByLabel("Upload files").setInputFiles({ + name: "report.docx", + mimeType: + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + buffer: Buffer.from("fake docx"), + }); + await expect(promptForm.getByText("report.docx")).toBeVisible(); + + await textarea.fill("Summarize this document"); + await textarea.press("Enter"); + + await uploadStartedPromise; + await expect(promptForm.getByText("report.docx")).toBeVisible(); + + releaseUpload(); + await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({ + timeout: 10_000, + }); + await expect(promptForm.getByText("report.docx")).toBeHidden(); + }); }); diff --git a/frontend/tests/unit/core/messages/utils.test.ts b/frontend/tests/unit/core/messages/utils.test.ts index 24d014c7e..cbc245583 100644 --- a/frontend/tests/unit/core/messages/utils.test.ts +++ b/frontend/tests/unit/core/messages/utils.test.ts @@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => { ), ).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]); }); + +test("hides internal todo reminder messages from message groups", () => { + const messages = [ + { + id: "human-1", + type: "human", + content: "Audit the middleware", + }, + { + id: "todo-reminder-1", + type: "human", + name: "todo_completion_reminder", + content: "finish todos", + }, + { + id: "todo-reminder-2", + type: "human", + name: "todo_reminder", + content: "remember todos", + }, + { + id: "ai-1", + type: "ai", + content: "Done", + }, + ] as Message[]; + + const groups = getMessageGroups(messages); + + expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]); + expect( + groups.flatMap((group) => group.messages).map((message) => message.id), + ).toEqual(["human-1", "ai-1"]); +}); diff --git a/frontend/tests/unit/core/threads/message-merge.test.ts b/frontend/tests/unit/core/threads/message-merge.test.ts new file mode 100644 index 000000000..9b29aebc9 --- /dev/null +++ b/frontend/tests/unit/core/threads/message-merge.test.ts @@ -0,0 +1,64 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { expect, test } from "vitest"; + +import { mergeMessages } from "@/core/threads/hooks"; + +test("mergeMessages removes duplicate messages already present in history", () => { + const human = { + id: "human-1", + type: "human", + content: "Design an agent", + } as Message; + const ai = { + id: "ai-1", + type: "ai", + content: "Let's design it.", + } as Message; + + expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]); +}); + +test("mergeMessages lets live thread messages replace overlapping history", () => { + const oldHuman = { + id: "human-1", + type: "human", + content: "old", + } as Message; + const liveHuman = { + id: "human-1", + type: "human", + content: "live", + } as Message; + const oldAi = { + id: "ai-1", + type: "ai", + content: "old", + } as Message; + const liveAi = { + id: "ai-1", + type: "ai", + content: "live", + } as Message; + + expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([ + liveHuman, + liveAi, + ]); +}); + +test("mergeMessages deduplicates tool messages by tool_call_id", () => { + const oldTool = { + id: "tool-message-old", + type: "tool", + tool_call_id: "call-1", + content: "old", + } as Message; + const liveTool = { + id: "tool-message-live", + type: "tool", + tool_call_id: "call-1", + content: "live", + } as Message; + + expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]); +}); diff --git a/scripts/detect_uv_extras.py b/scripts/detect_uv_extras.py index 91a9bd0ad..e6f4e8a24 100755 --- a/scripts/detect_uv_extras.py +++ b/scripts/detect_uv_extras.py @@ -72,6 +72,7 @@ def find_config_file() -> Path | None: _SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$") +_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$") _KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$") @@ -141,6 +142,84 @@ def section_value(lines: list[str], section: str, key: str) -> str | None: return None +def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None: + """Return the value of a nested YAML key like ``channels.discord.enabled``. + + Handles two levels of nesting: + channels: + discord: + enabled: true + """ + parts = section_path.split(".") + if len(parts) != 2: + return None + parent_section, child_section = parts + + inside_parent = False + inside_child = False + parent_indent: int | None = None + child_indent: int | None = None + + for raw in lines: + line = _strip_comment(raw) + if not line.strip(): + continue + + stripped = line.lstrip() + indent = len(line) - len(stripped) + + # Top-level section match + sect_match = _SECTION_RE.match(line) + if sect_match: + if indent == 0: + inside_parent = sect_match.group(1) == parent_section + inside_child = False + parent_indent = None + child_indent = None + continue + + if not inside_parent: + continue + + # Track parent indent from first child + if parent_indent is None and indent > 0: + parent_indent = indent + + # If indent goes back to 0, we left the parent section + if indent == 0: + inside_parent = False + inside_child = False + continue + + # Check if we're at the parent's child level (subsection) + if parent_indent is not None and indent == parent_indent: + # This could be a subsection or a direct key of parent + sub_match = _INDENTED_SECTION_RE.match(line) + if sub_match and sub_match.group(1) == child_section: + inside_child = True + child_indent = None + continue + else: + inside_child = False + continue + + if not inside_child: + continue + + # We're inside the subsection — track child indent + if child_indent is None and indent > (parent_indent or 0): + child_indent = indent + + if child_indent is not None and indent != child_indent: + continue + + key_match = _KEY_RE.match(line) + if key_match and key_match.group(1) == key: + return _unquote(key_match.group(2).strip()) + + return None + + def detect_from_config(path: Path) -> list[str]: try: text = path.read_text(encoding="utf-8", errors="replace") @@ -152,6 +231,8 @@ def detect_from_config(path: Path) -> list[str]: extras.add("postgres") if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres": extras.add("postgres") + if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true": + extras.add("discord") return sorted(extras) diff --git a/skills/public/claude-to-deerflow/SKILL.md b/skills/public/claude-to-deerflow/SKILL.md index d191f5c75..969a292c1 100644 --- a/skills/public/claude-to-deerflow/SKILL.md +++ b/skills/public/claude-to-deerflow/SKILL.md @@ -14,8 +14,8 @@ DeerFlow exposes two API surfaces behind an Nginx reverse proxy: | Service | Direct Port | Via Proxy | Purpose | |----------------|-------------|----------------------------------|----------------------------------| -| Gateway API | 8001 | `$DEERFLOW_GATEWAY_URL` | REST endpoints (models, skills, memory, uploads) | -| LangGraph API | 2024 | `$DEERFLOW_LANGGRAPH_URL` | Agent threads, runs, streaming | +| Gateway API | 8001 | `$DEERFLOW_GATEWAY_URL` | REST endpoints and embedded agent runtime | +| LangGraph-compatible API | 8001 | `$DEERFLOW_LANGGRAPH_URL` | Agent threads, runs, streaming | ## Environment Variables