From 84f88b6610e5c6384735e703809bc8b35e33dacb Mon Sep 17 00:00:00 2001 From: Eilen Shin <136898293+Eilen6316@users.noreply.github.com> Date: Tue, 12 May 2026 16:19:21 +0800 Subject: [PATCH 01/12] docs: align runtime docs with gateway mode (#2868) Co-authored-by: Willem Jiang --- CONTRIBUTING.md | 8 +-- README_fr.md | 6 +- README_ja.md | 6 +- README_zh.md | 6 +- backend/CONTRIBUTING.md | 5 +- backend/README.md | 57 ++++++++----------- backend/docs/API.md | 31 ++++++---- backend/docs/ARCHITECTURE.md | 49 ++++++++-------- frontend/README.md | 6 +- .../en/application/agents-and-threads.mdx | 7 +-- .../en/application/deployment-guide.mdx | 29 +++------- frontend/src/content/en/application/index.mdx | 20 ++----- .../operations-and-troubleshooting.mdx | 13 ++--- .../content/en/application/quick-start.mdx | 11 ++-- frontend/src/content/en/harness/skills.mdx | 2 +- .../content/zh/application/configuration.mdx | 1 - .../zh/application/deployment-guide.mdx | 27 +++------ frontend/src/content/zh/application/index.mdx | 20 ++----- .../operations-and-troubleshooting.mdx | 17 ++---- .../content/zh/application/quick-start.mdx | 11 ++-- skills/public/claude-to-deerflow/SKILL.md | 4 +- 21 files changed, 135 insertions(+), 201 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 51b834b4f..ceebba99c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -185,9 +185,9 @@ If you need to start services individually: 1. **Start backend service**: ```bash - # Terminal 1: Start Gateway API and embedded LangGraph-compatible runtime (port 8001) + # Terminal 1: Start Gateway API + embedded agent runtime (port 8001) cd backend - make gateway + make dev # Terminal 2: Start Frontend (port 3000) cd frontend @@ -207,7 +207,7 @@ If you need to start services individually: The nginx configuration provides: - Unified entry point on port 2026 -- Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx +- Rewrites `/api/langgraph/*` to Gateway's LangGraph-compatible API (8001) - Routes other `/api/*` endpoints to Gateway API (8001) - Routes non-API requests to Frontend (3000) - Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist @@ -231,7 +231,7 @@ deer-flow/ ├── backend/ # Backend application │ ├── src/ │ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001) -│ │ ├── agents/ # LangGraph agent definitions +│ │ ├── agents/ # LangGraph agent runtime used by Gateway │ │ ├── mcp/ # Model Context Protocol integration │ │ ├── skills/ # Skills system │ │ └── sandbox/ # Sandbox execution diff --git a/README_fr.md b/README_fr.md index 3b8dc3d41..f144d8bc5 100644 --- a/README_fr.md +++ b/README_fr.md @@ -228,7 +228,7 @@ make down # Stop and remove containers ``` > [!NOTE] -> Le serveur d'agents LangGraph fonctionne actuellement via `langgraph dev` (le serveur CLI open source). +> Le runtime d'agent s'exécute actuellement dans la Gateway. nginx réécrit `/api/langgraph/*` vers l'API compatible LangGraph servie par la Gateway. Accès : http://localhost:2026 @@ -296,8 +296,8 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca ```yaml channels: - # LangGraph Server URL (default: http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL (default: http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/README_ja.md b/README_ja.md index d2ba81750..2bf060799 100644 --- a/README_ja.md +++ b/README_ja.md @@ -181,7 +181,7 @@ make down # コンテナを停止して削除 ``` > [!NOTE] -> LangGraphエージェントサーバーは現在`langgraph dev`(オープンソースCLIサーバー)経由で実行されます。 +> Agentランタイムは現在Gateway内で実行されます。`/api/langgraph/*`はnginxによってGatewayのLangGraph-compatible APIへ書き換えられます。 アクセス: http://localhost:2026 @@ -249,8 +249,8 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート ```yaml channels: - # LangGraphサーバーURL(デフォルト: http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL(デフォルト: http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL(デフォルト: http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/README_zh.md b/README_zh.md index d5317082e..ec67b95d6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -184,7 +184,7 @@ make down # 停止并移除容器 ``` > [!NOTE] -> 当前 LangGraph agent server 通过开源 CLI 服务 `langgraph dev` 运行。 +> 当前 Agent 运行时嵌入在 Gateway 中运行,`/api/langgraph/*` 会由 nginx 重写到 Gateway 的 LangGraph-compatible API。 访问地址:http://localhost:2026 @@ -254,8 +254,8 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应 ```yaml channels: - # LangGraph Server URL(默认:http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL(默认:http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL(默认:http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/backend/CONTRIBUTING.md b/backend/CONTRIBUTING.md index 322710e74..f7ef58447 100644 --- a/backend/CONTRIBUTING.md +++ b/backend/CONTRIBUTING.md @@ -56,11 +56,8 @@ export OPENAI_API_KEY="your-api-key" ### Run the Development Server ```bash -# Terminal 1: LangGraph server +# Gateway API + embedded agent runtime make dev - -# Terminal 2: Gateway API -make gateway ``` ## Project Structure diff --git a/backend/README.md b/backend/README.md index 9b4d26fb1..18d89c2be 100644 --- a/backend/README.md +++ b/backend/README.md @@ -11,34 +11,26 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent │ Nginx (Port 2026) │ │ Unified reverse proxy │ └───────┬──────────────────┬───────────┘ - │ │ - /api/langgraph/* │ │ /api/* (other) - ▼ ▼ - ┌──────────────────────────────────────────────┐ - │ Gateway API (8001) │ - │ FastAPI REST + LangGraph-compatible runtime │ - │ │ - │ Models, MCP, Skills, Memory, Uploads, │ - │ Artifacts, Threads, Runs, Streaming │ - │ │ - │ ┌────────────────┐ │ - │ │ Lead Agent │ │ - │ │ ┌──────────┐ │ │ - │ │ │Middleware│ │ │ - │ │ │ Chain │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │ Tools │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │Subagents │ │ │ - │ │ └──────────┘ │ │ - │ └────────────────┘ │ - └──────────────────────────────────────────────┘ + │ + /api/langgraph/* │ /api/* (other) + rewritten to /api/* │ + ▼ + ┌────────────────────────────────────────┐ + │ Gateway API (8001) │ + │ FastAPI REST + agent runtime │ + │ │ + │ Models, MCP, Skills, Memory, Uploads, │ + │ Artifacts, Threads, Runs, Streaming │ + │ │ + │ ┌────────────────────────────────────┐ │ + │ │ Lead Agent │ │ + │ │ Middleware Chain, Tools, Subagents │ │ + │ └────────────────────────────────────┘ │ + └────────────────────────────────────────┘ ``` **Request Routing** (via Nginx): -- `/api/langgraph/*` → Gateway API - LangGraph-compatible agent interactions, threads, runs, and streaming translated to native `/api/*` routers +- `/api/langgraph/*` → Gateway LangGraph-compatible API - agent interactions, threads, streaming - `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup - `/` (non-API) → Frontend - Next.js web interface @@ -196,7 +188,7 @@ export OPENAI_API_KEY="your-api-key-here" **Full Application** (from project root): ```bash -make dev # Starts LangGraph + Gateway + Frontend + Nginx +make dev # Starts Gateway + Frontend + Nginx ``` Access at: http://localhost:2026 @@ -204,14 +196,11 @@ Access at: http://localhost:2026 **Backend Only** (from backend directory): ```bash -# Terminal 1: LangGraph server +# Gateway API + embedded agent runtime make dev - -# Terminal 2: Gateway API -make gateway ``` -Direct access: LangGraph at http://localhost:2024, Gateway at http://localhost:8001 +Direct access: Gateway at http://localhost:8001 --- @@ -247,7 +236,7 @@ backend/ │ └── utils/ # Utilities ├── docs/ # Documentation ├── tests/ # Test suite -├── langgraph.json # LangGraph server configuration +├── langgraph.json # LangGraph graph registry for tooling/Studio compatibility ├── pyproject.toml # Python dependencies ├── Makefile # Development commands └── Dockerfile # Container build @@ -365,8 +354,8 @@ If a provider is explicitly enabled but required credentials are missing, or the ```bash make install # Install dependencies -make dev # Run LangGraph server (port 2024) -make gateway # Run Gateway API (port 8001) +make dev # Run Gateway API + embedded agent runtime (port 8001) +make gateway # Run Gateway API without reload (port 8001) make lint # Run linter (ruff) make format # Format code (ruff) ``` diff --git a/backend/docs/API.md b/backend/docs/API.md index 293c1ebd1..d0b06ef0b 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -561,12 +561,13 @@ location /api/ { --- -## WebSocket Support +## Streaming Support -The LangGraph server supports WebSocket connections for real-time streaming. Connect to: +Gateway's LangGraph-compatible API streams run events with Server-Sent Events (SSE): -``` -ws://localhost:2026/api/langgraph/threads/{thread_id}/runs/stream +```http +POST /api/langgraph/threads/{thread_id}/runs/stream +Accept: text/event-stream ``` --- @@ -602,13 +603,21 @@ const response = await fetch('/api/models'); const data = await response.json(); console.log(data.models); -// Using EventSource for streaming -const eventSource = new EventSource( - `/api/langgraph/threads/${threadId}/runs/stream` -); -eventSource.onmessage = (event) => { - console.log(JSON.parse(event.data)); -}; +// Create a run and stream SSE events +const streamResponse = await fetch(`/api/langgraph/threads/${threadId}/runs/stream`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ + input: { messages: [{ role: "user", content: "Hello" }] }, + stream_mode: ["values", "messages-tuple", "custom"], + }), +}); + +const reader = streamResponse.body?.getReader(); +// Decode and parse SSE frames from reader in your client code. ``` ### cURL Examples diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index e6fdbe217..f1557a6fb 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -20,24 +20,22 @@ This document provides a comprehensive overview of the DeerFlow backend architec │ └────────────────────────────────────────────────────────────────────┘ │ └─────────────────────────────────┬────────────────────────────────────────┘ │ - ┌───────────────────────┼───────────────────────┐ - │ │ │ - ▼ ▼ ▼ -┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ -│ Embedded Runtime │ │ Gateway API │ │ Frontend │ -│ (inside Gateway) │ │ (Port 8001) │ │ (Port 3000) │ -│ │ │ │ │ │ -│ - Agent Runtime │ │ - Models API │ │ - Next.js App │ -│ - Thread Mgmt │ │ - MCP Config │ │ - React UI │ -│ - SSE Streaming │ │ - Skills Mgmt │ │ - Chat Interface │ -│ - Checkpointing │ │ - File Uploads │ │ │ -│ │ │ - Thread Cleanup │ │ │ -│ │ │ - Artifacts │ │ │ -└─────────────────────┘ └─────────────────────┘ └─────────────────────┘ - │ │ - │ ┌─────────────────┘ - │ │ - ▼ ▼ + ┌───────────────────────┴───────────────────────┐ + │ │ + ▼ ▼ +┌─────────────────────────────────────────────┐ ┌─────────────────────┐ +│ Gateway API │ │ Frontend │ +│ (Port 8001) │ │ (Port 3000) │ +│ │ │ │ +│ - LangGraph-compatible runs/threads API │ │ - Next.js App │ +│ - Embedded Agent Runtime │ │ - React UI │ +│ - SSE Streaming │ │ - Chat Interface │ +│ - Checkpointing │ │ │ +│ - Models, MCP, Skills, Uploads, Artifacts │ │ │ +│ - Thread Cleanup │ │ │ +└─────────────────────────────────────────────┘ └─────────────────────┘ + │ + ▼ ┌──────────────────────────────────────────────────────────────────────────┐ │ Shared Configuration │ │ ┌─────────────────────────┐ ┌────────────────────────────────────────┐ │ @@ -52,9 +50,9 @@ This document provides a comprehensive overview of the DeerFlow backend architec ## Component Details -### Embedded LangGraph Runtime +### Gateway Embedded Agent Runtime -The LangGraph-compatible runtime runs inside the Gateway process and is built on LangGraph for robust multi-agent workflow orchestration. +The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for robust multi-agent workflow orchestration. Nginx rewrites `/api/langgraph/*` to Gateway's native `/api/*` routes, so the public API remains compatible with LangGraph SDK clients without running a separate LangGraph server. **Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent` @@ -65,7 +63,7 @@ The LangGraph-compatible runtime runs inside the Gateway process and is built on - Tool execution orchestration - SSE streaming for real-time responses -**Configuration**: `langgraph.json` +**Graph registry**: `langgraph.json` remains available for tooling and Studio compatibility. ```json { @@ -84,6 +82,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl **Routers**: - `models.py` - `/api/models` - Model listing and details +- `thread_runs.py` / `runs.py` - `/api/threads/{id}/runs`, `/api/runs/*` - LangGraph-compatible runs and streaming - `mcp.py` - `/api/mcp` - MCP server configuration - `skills.py` - `/api/skills` - Skills management - `uploads.py` - `/api/threads/{id}/uploads` - File upload @@ -91,7 +90,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl - `artifacts.py` - `/api/threads/{id}/artifacts` - Artifact serving - `suggestions.py` - `/api/threads/{id}/suggestions` - Follow-up suggestion generation -The web conversation delete flow is now split across both backend surfaces: LangGraph handles `DELETE /api/langgraph/threads/{thread_id}` for thread state, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`. +The web conversation delete flow first deletes Gateway-managed thread state through the LangGraph-compatible route, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`. ### Agent Architecture @@ -354,9 +353,9 @@ SKILL.md Format: {"input": {"messages": [{"role": "user", "content": "Hello"}]}} 2. Nginx → Gateway API (8001) - Routes `/api/langgraph/*` to the Gateway's LangGraph-compatible runtime + `/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes -3. Embedded LangGraph runtime +3. Gateway embedded runtime a. Load/create thread state b. Execute middleware chain: - ThreadDataMiddleware: Set up paths @@ -412,7 +411,7 @@ SKILL.md Format: ### Thread Cleanup Flow ``` -1. Client deletes conversation via LangGraph +1. Client deletes conversation via the LangGraph-compatible Gateway route DELETE /api/langgraph/threads/{thread_id} 2. Web UI follows up with Gateway cleanup diff --git a/frontend/README.md b/frontend/README.md index 6db881301..4ad70fb1f 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -82,10 +82,10 @@ pnpm start Key environment variables (see `.env.example` for full list): ```bash -# Backend API URLs (optional, uses nginx proxy by default) +# Backend API URL (optional, uses local Next.js/nginx proxy by default) NEXT_PUBLIC_BACKEND_BASE_URL="http://localhost:8001" -# LangGraph API URLs (optional, uses nginx proxy by default) -NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:2024" +# LangGraph-compatible API URL (optional, uses local Next.js/nginx proxy by default) +NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:8001/api" ``` ## Project Structure diff --git a/frontend/src/content/en/application/agents-and-threads.mdx b/frontend/src/content/en/application/agents-and-threads.mdx index bbf3cfc7e..0a281a33e 100644 --- a/frontend/src/content/en/application/agents-and-threads.mdx +++ b/frontend/src/content/en/application/agents-and-threads.mdx @@ -111,10 +111,9 @@ checkpointer: ``` - The LangGraph Server manages its own state separately. The - checkpointer setting in config.yaml applies to the - embedded DeerFlowClient (used in direct Python integrations), not - to the LangGraph Server deployment used by DeerFlow App. + The Gateway embedded runtime uses the checkpointer setting in + config.yaml. The same setting is also used by + DeerFlowClient in direct Python integrations. ### Thread data storage diff --git a/frontend/src/content/en/application/deployment-guide.mdx b/frontend/src/content/en/application/deployment-guide.mdx index 04b3599c0..52b59cf01 100644 --- a/frontend/src/content/en/application/deployment-guide.mdx +++ b/frontend/src/content/en/application/deployment-guide.mdx @@ -23,8 +23,7 @@ Services started: | Service | Port | Description | | ----------- | ---- | ------------------------ | -| LangGraph | 2024 | DeerFlow Harness runtime | -| Gateway API | 8001 | FastAPI backend | +| Gateway API | 8001 | FastAPI backend + embedded agent runtime | | Frontend | 3000 | Next.js UI | | nginx | 2026 | Unified reverse proxy | @@ -36,13 +35,12 @@ Access the app at **http://localhost:2026**. make stop ``` -Stops all four services. Safe to run even if a service is not running. +Stops all services. Safe to run even if a service is not running. ``` -logs/langgraph.log # Agent runtime logs -logs/gateway.log # API gateway logs +logs/gateway.log # API gateway and agent runtime logs logs/frontend.log # Next.js dev server logs logs/nginx.log # nginx access/error logs ``` @@ -50,7 +48,7 @@ logs/nginx.log # nginx access/error logs Tail a log in real time: ```bash -tail -f logs/langgraph.log +tail -f logs/gateway.log ``` @@ -74,7 +72,7 @@ export DEER_FLOW_ROOT=/path/to/deer-flow docker compose -f docker/docker-compose-dev.yaml up --build ``` -Services: nginx, frontend, gateway, langgraph, and optionally provisioner (for K8s-managed sandboxes). +Services: nginx, frontend, gateway, and optionally provisioner (for K8s-managed sandboxes). Access the app at **http://localhost:2026**. @@ -99,7 +97,7 @@ The `docker-compose*.yaml` files include an `env_file: ../.env` directive that l ### Data persistence -Thread data is stored in `backend/.deer-flow/threads/`. In Docker deployments, this directory is bind-mounted into the langgraph container. +Thread data is stored in `backend/.deer-flow/threads/`. In Docker deployments, this directory is bind-mounted into the gateway container. To avoid data loss when containers are recreated: @@ -161,14 +159,7 @@ When `USERDATA_PVC_NAME` is set, the provisioner automatically uses subPath (`th ### nginx configuration -nginx routes all traffic. Key environment variables that control routing: - -| Variable | Default | Description | -| -------------------- | ---------------- | --------------------------------------- | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | LangGraph service address | -| `LANGGRAPH_REWRITE` | `/` | URL rewrite prefix for LangGraph routes | - -These are set in the Docker Compose environment and processed by `envsubst` at container startup. +nginx routes all traffic to the frontend or Gateway. `/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes, so no separate LangGraph upstream is required. ### Authentication @@ -186,8 +177,7 @@ openssl rand -base64 32 | Service | Minimum | Recommended | | ------------------------------- | ---------------- | ---------------- | -| LangGraph (agent runtime) | 2 vCPU, 4 GB RAM | 4 vCPU, 8 GB RAM | -| Gateway | 0.5 vCPU, 512 MB | 1 vCPU, 1 GB | +| Gateway + agent runtime | 2 vCPU, 4 GB RAM | 4 vCPU, 8 GB RAM | | Frontend | 0.5 vCPU, 512 MB | 1 vCPU, 1 GB | | Sandbox container (per session) | 1 vCPU, 1 GB | 2 vCPU, 2 GB | @@ -199,9 +189,6 @@ After starting, verify the deployment: # Check Gateway health curl http://localhost:8001/health -# Check LangGraph health -curl http://localhost:2024/ok - # List configured models (through nginx) curl http://localhost:2026/api/models ``` diff --git a/frontend/src/content/en/application/index.mdx b/frontend/src/content/en/application/index.mdx index 2cb15a911..b45a6cbf0 100644 --- a/frontend/src/content/en/application/index.mdx +++ b/frontend/src/content/en/application/index.mdx @@ -25,11 +25,11 @@ DeerFlow App is the reference implementation of what a production DeerFlow exper | **Streaming responses** | Real-time token streaming with thinking steps and tool call visibility | | **Artifact viewer** | In-browser preview and download of files and outputs produced by the agent | | **Extensions UI** | Enable/disable MCP servers and skills without editing config files | -| **Gateway API** | FastAPI-based REST API that bridges the frontend and the LangGraph runtime | +| **Gateway API** | FastAPI-based REST API with the embedded LangGraph-compatible agent runtime | ## Architecture -The DeerFlow App runs as four services behind a single nginx reverse proxy: +The DeerFlow App runs behind a single nginx reverse proxy: ``` ┌──────────────────┐ @@ -42,19 +42,11 @@ The DeerFlow App runs as four services behind a single nginx reverse proxy: │ Frontend :3000 │ │ Gateway API :8001 │ │ (Next.js) │ │ (FastAPI) │ └──────────────────┘ └──────────────────────┘ - │ - ┌─────────┘ - ▼ - ┌──────────────────────┐ - │ LangGraph :2024 │ - │ (DeerFlow Harness) │ - └──────────────────────┘ ``` -- **nginx**: routes requests — `/api/*` to the Gateway, LangGraph streaming endpoints to LangGraph directly, and everything else to the frontend. -- **Frontend** (Next.js + React): the browser UI. Communicates with both the Gateway and LangGraph. -- **Gateway** (FastAPI): handles API operations — model listing, agent CRUD, memory, extensions management, file uploads. -- **LangGraph**: the DeerFlow Harness runtime. Manages thread state, agent execution, and streaming. +- **nginx**: routes requests — `/api/*` and `/api/langgraph/*` to Gateway, and everything else to the frontend. +- **Frontend** (Next.js + React): the browser UI. Communicates with Gateway. +- **Gateway** (FastAPI): handles API operations and the embedded LangGraph-compatible runtime for thread state, agent execution, and streaming. ## Technology stack @@ -64,7 +56,7 @@ The DeerFlow App runs as four services behind a single nginx reverse proxy: | Gateway | FastAPI, Python 3.12, uvicorn | | Agent runtime | LangGraph, LangChain, DeerFlow Harness | | Reverse proxy | nginx | -| State persistence | LangGraph Server (default) + optional SQLite/PostgreSQL checkpointer | +| State persistence | Gateway runtime + optional SQLite/PostgreSQL checkpointer | diff --git a/frontend/src/content/en/application/operations-and-troubleshooting.mdx b/frontend/src/content/en/application/operations-and-troubleshooting.mdx index 8b21cf4b4..0f8d7e44c 100644 --- a/frontend/src/content/en/application/operations-and-troubleshooting.mdx +++ b/frontend/src/content/en/application/operations-and-troubleshooting.mdx @@ -15,15 +15,13 @@ All services write logs to the `logs/` directory when started with `make dev`: | File | Service | | -------------------- | ------------------------------------ | -| `logs/langgraph.log` | LangGraph / DeerFlow Harness runtime | -| `logs/gateway.log` | FastAPI Gateway API | +| `logs/gateway.log` | FastAPI Gateway API and agent runtime | | `logs/frontend.log` | Next.js frontend dev server | | `logs/nginx.log` | nginx reverse proxy | Tail logs in real time: ```bash -tail -f logs/langgraph.log tail -f logs/gateway.log ``` @@ -41,9 +39,6 @@ Verify each service is responding: # Gateway health curl http://localhost:8001/health -# LangGraph health -curl http://localhost:2024/ok - # Through nginx (verifies full proxy chain) curl http://localhost:2026/api/models ``` @@ -66,7 +61,7 @@ grep config_version config.yaml ### The app loads but the agent doesn't respond -1. Check `logs/langgraph.log` for startup errors. +1. Check `logs/gateway.log` for startup errors. 2. Verify your model is correctly configured in `config.yaml` with a valid API key. 3. Confirm the API key environment variable is set in the shell that ran `make dev`. 4. Test the model endpoint directly with `curl` to rule out network issues. @@ -126,7 +121,7 @@ Connection refused: http://provisioner:8002 If MCP tools appear in `extensions_config.json` but are not available in the agent: -1. Check `logs/langgraph.log` for MCP initialization errors. +1. Check `logs/gateway.log` for MCP initialization errors. 2. Verify the MCP server command is installed (`npx`, `uvx`, or the relevant binary). 3. Test the server command manually to confirm it starts without errors. 4. Set `log_level: debug` to see detailed MCP loading output. @@ -137,7 +132,7 @@ If MCP tools appear in `extensions_config.json` but are not available in the age - Verify `memory.enabled: true` in `config.yaml`. - Check that the storage path is writable: `ls -la backend/.deer-flow/`. -- Look for memory update errors in `logs/langgraph.log` (search for "memory"). +- Look for memory update errors in `logs/gateway.log` (search for "memory"). ## Data backup diff --git a/frontend/src/content/en/application/quick-start.mdx b/frontend/src/content/en/application/quick-start.mdx index 5ecfb3a26..c3baa0764 100644 --- a/frontend/src/content/en/application/quick-start.mdx +++ b/frontend/src/content/en/application/quick-start.mdx @@ -1,6 +1,6 @@ --- title: Quick Start -description: This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. All four services (LangGraph, Gateway, Frontend, nginx) start together and are accessible through a single URL. +description: This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. Gateway, Frontend, and nginx start together and are accessible through a single URL. --- import { Callout, Cards, Steps } from "nextra/components"; @@ -12,7 +12,7 @@ import { Callout, Cards, Steps } from "nextra/components"; Python 3.12+, Node.js 22+, and at least one LLM API key. -This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. All four services (LangGraph, Gateway, Frontend, nginx) start together and are accessible through a single URL. +This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. Gateway, Frontend, and nginx start together and are accessible through a single URL. ## Prerequisites @@ -88,8 +88,7 @@ make dev This starts: -- LangGraph server on port `2024` -- Gateway API on port `8001` +- Gateway API and embedded agent runtime on port `8001` - Frontend on port `3000` - nginx reverse proxy on port `2026` @@ -113,15 +112,13 @@ Log files: | Service | Log file | | --------- | -------------------- | -| LangGraph | `logs/langgraph.log` | | Gateway | `logs/gateway.log` | | Frontend | `logs/frontend.log` | | nginx | `logs/nginx.log` | If something is not working, check the log files first. Most startup errors - (missing API keys, config parsing failures) appear in `logs/langgraph.log` or - `logs/gateway.log`. + (missing API keys, config parsing failures) appear in `logs/gateway.log`. diff --git a/frontend/src/content/en/harness/skills.mdx b/frontend/src/content/en/harness/skills.mdx index 09f8b0d43..78247c40b 100644 --- a/frontend/src/content/en/harness/skills.mdx +++ b/frontend/src/content/en/harness/skills.mdx @@ -68,7 +68,7 @@ DeerFlow ships with the following public skills: ### Discovery and loading -`load_skills()` in `skills/loader.py` scans both `public/` and `custom/` directories under the configured skills path. It re-reads `ExtensionsConfig.from_file()` on every call, which means enabling or disabling a skill through the Gateway API takes effect immediately in the running LangGraph server without a restart. +`load_skills()` in `skills/loader.py` scans both `public/` and `custom/` directories under the configured skills path. It re-reads `ExtensionsConfig.from_file()` on every call, which means enabling or disabling a skill through the Gateway API takes effect immediately in the running agent runtime without a restart. ### Parsing diff --git a/frontend/src/content/zh/application/configuration.mdx b/frontend/src/content/zh/application/configuration.mdx index 639eeaec5..0094323e7 100644 --- a/frontend/src/content/zh/application/configuration.mdx +++ b/frontend/src/content/zh/application/configuration.mdx @@ -215,7 +215,6 @@ BETTER_AUTH_SECRET=local-dev-secret-at-least-32-chars | `DEER_FLOW_CONFIG_PATH` | 自动发现 | `config.yaml` 的绝对路径 | | `LOG_LEVEL` | `info` | 日志详细程度(`debug`/`info`/`warning`/`error`) | | `DEER_FLOW_ROOT` | 仓库根目录 | 用于 Docker 中的技能和线程挂载 | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | nginx 代理的 LangGraph 地址 | diff --git a/frontend/src/content/zh/application/deployment-guide.mdx b/frontend/src/content/zh/application/deployment-guide.mdx index 59eceece2..635120337 100644 --- a/frontend/src/content/zh/application/deployment-guide.mdx +++ b/frontend/src/content/zh/application/deployment-guide.mdx @@ -23,8 +23,7 @@ make dev | 服务 | 端口 | 描述 | | ----------- | ---- | ----------------------- | -| LangGraph | 2024 | DeerFlow Harness 运行时 | -| Gateway API | 8001 | FastAPI 后端 | +| Gateway API | 8001 | FastAPI 后端 + 嵌入式 Agent 运行时 | | 前端 | 3000 | Next.js 界面 | | nginx | 2026 | 统一反向代理 | @@ -36,13 +35,12 @@ make dev make stop ``` -停止所有四个服务。即使某个服务没有运行也可以安全执行。 +停止所有服务。即使某个服务没有运行也可以安全执行。 ``` -logs/langgraph.log # Agent 运行时日志 -logs/gateway.log # API Gateway 日志 +logs/gateway.log # API Gateway 和 Agent 运行时日志 logs/frontend.log # Next.js 开发服务器日志 logs/nginx.log # nginx 访问/错误日志 ``` @@ -50,7 +48,7 @@ logs/nginx.log # nginx 访问/错误日志 实时追踪日志: ```bash -tail -f logs/langgraph.log +tail -f logs/gateway.log ``` @@ -96,7 +94,7 @@ BETTER_AUTH_SECRET=your-secret-here-min-32-chars ### 数据持久化 -线程数据存储在 `backend/.deer-flow/threads/`。在 Docker 部署中,此目录被绑定挂载到 langgraph 容器中。 +线程数据存储在 `backend/.deer-flow/threads/`。在 Docker 部署中,此目录会绑定挂载到 gateway 容器中。 为避免容器重建时数据丢失: @@ -156,14 +154,7 @@ SKILLS_PVC_NAME=deer-flow-skills-pvc ### nginx 配置 -nginx 路由所有流量,控制路由的关键环境变量: - -| 变量 | 默认值 | 描述 | -| -------------------- | ---------------- | ----------------------------- | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | LangGraph 服务地址 | -| `LANGGRAPH_REWRITE` | `/` | LangGraph 路由的 URL 重写前缀 | - -这些在 Docker Compose 环境中设置,并在容器启动时由 `envsubst` 处理。 +nginx 将流量路由到前端或 Gateway。`/api/langgraph/*` 会被重写到 Gateway 的 LangGraph-compatible `/api/*` 路由,因此不需要单独的 LangGraph upstream。 ### 认证配置 @@ -181,8 +172,7 @@ openssl rand -base64 32 | 服务 | 最低配置 | 推荐配置 | | ------------------------- | ---------------- | ---------------- | -| LangGraph(Agent 运行时) | 2 vCPU、4 GB RAM | 4 vCPU、8 GB RAM | -| Gateway | 0.5 vCPU、512 MB | 1 vCPU、1 GB | +| Gateway + Agent 运行时 | 2 vCPU、4 GB RAM | 4 vCPU、8 GB RAM | | 前端 | 0.5 vCPU、512 MB | 1 vCPU、1 GB | | 沙箱容器(每会话) | 1 vCPU、1 GB | 2 vCPU、2 GB | @@ -194,9 +184,6 @@ openssl rand -base64 32 # 检查 Gateway 健康状态 curl http://localhost:8001/health -# 检查 LangGraph 健康状态 -curl http://localhost:2024/ok - # 通过 nginx 列出配置的模型(验证完整代理链) curl http://localhost:2026/api/models ``` diff --git a/frontend/src/content/zh/application/index.mdx b/frontend/src/content/zh/application/index.mdx index 81e7113e2..c12959b42 100644 --- a/frontend/src/content/zh/application/index.mdx +++ b/frontend/src/content/zh/application/index.mdx @@ -25,11 +25,11 @@ DeerFlow 应用是 DeerFlow 生产体验的参考实现。它将 Harness 运行 | **流式响应** | 实时 token 流式传输,带思考步骤和工具调用可见性 | | **产出物查看器** | Agent 生成文件和输出的浏览器内预览和下载 | | **扩展界面** | 无需编辑配置文件即可启用/禁用 MCP 服务器和技能 | -| **Gateway API** | 桥接前端和 LangGraph 运行时的基于 FastAPI 的 REST API | +| **Gateway API** | 基于 FastAPI 的 REST API,并内置 LangGraph-compatible Agent 运行时 | ## 架构 -DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理提供: +DeerFlow 应用通过单个 nginx 反向代理提供: ``` ┌──────────────────┐ @@ -42,19 +42,11 @@ DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理 │ 前端 :3000 │ │ Gateway API :8001 │ │ (Next.js) │ │ (FastAPI) │ └──────────────────┘ └──────────────────────┘ - │ - ┌─────────┘ - ▼ - ┌──────────────────────┐ - │ LangGraph :2024 │ - │ (DeerFlow Harness) │ - └──────────────────────┘ ``` -- **nginx**:路由请求——`/api/*` 到 Gateway,LangGraph 流式端点到 LangGraph,其余到前端。 -- **前端**(Next.js + React):浏览器界面,与 Gateway 和 LangGraph 通信。 -- **Gateway**(FastAPI):处理 API 操作——模型列表、Agent CRUD、记忆、扩展管理、文件上传。 -- **LangGraph**:DeerFlow Harness 运行时,管理线程状态、Agent 执行和流式传输。 +- **nginx**:路由请求——`/api/*` 和 `/api/langgraph/*` 到 Gateway,其余到前端。 +- **前端**(Next.js + React):浏览器界面,与 Gateway 通信。 +- **Gateway**(FastAPI):处理 API 操作,并通过内置 LangGraph-compatible 运行时管理线程状态、Agent 执行和流式传输。 ## 技术栈 @@ -64,7 +56,7 @@ DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理 | Gateway | FastAPI、Python 3.12、uvicorn | | Agent 运行时 | LangGraph、LangChain、DeerFlow Harness | | 反向代理 | nginx | -| 状态持久化 | LangGraph Server(默认)+ 可选 SQLite/PostgreSQL 检查点 | +| 状态持久化 | Gateway 运行时 + 可选 SQLite/PostgreSQL 检查点 | diff --git a/frontend/src/content/zh/application/operations-and-troubleshooting.mdx b/frontend/src/content/zh/application/operations-and-troubleshooting.mdx index c047bbd5c..8dc4c6551 100644 --- a/frontend/src/content/zh/application/operations-and-troubleshooting.mdx +++ b/frontend/src/content/zh/application/operations-and-troubleshooting.mdx @@ -15,16 +15,14 @@ DeerFlow 应用在 `logs/` 目录中写入每个服务的日志: | 文件 | 内容 | | -------------------- | -------------------------------------- | -| `logs/langgraph.log` | Agent 运行时、工具调用、LangGraph 错误 | -| `logs/gateway.log` | API 请求/响应、Gateway 错误 | +| `logs/gateway.log` | API 请求/响应、Agent 运行时和 Gateway 错误 | | `logs/frontend.log` | Next.js 服务器日志 | | `logs/nginx.log` | 代理访问和错误日志 | **实时追踪日志**: ```bash -tail -f logs/langgraph.log # 查看 Agent 活动 -tail -f logs/gateway.log # 查看 API 请求 +tail -f logs/gateway.log # 查看 API 请求和 Agent 活动 ``` **调整日志级别**: @@ -42,9 +40,6 @@ DeerFlow 暴露健康检查端点: # Gateway 健康状态 curl http://localhost:8001/health -# LangGraph 健康状态 -curl http://localhost:2024/ok - # 通过 nginx 完整代理链验证 curl http://localhost:2026/api/models ``` @@ -68,8 +63,8 @@ make config-upgrade **诊断**: ```bash -# 检查 LangGraph 日志中的模型错误 -grep -i "error\|apikey\|unauthorized" logs/langgraph.log | tail -20 +# 检查 Gateway 日志中的模型错误 +grep -i "error\|apikey\|unauthorized" logs/gateway.log | tail -20 ``` **解决**: @@ -118,13 +113,13 @@ SKIP_ENV_VALIDATION=1 pnpm build ### MCP 服务器连接失败 -**症状**:MCP 工具未出现,`logs/langgraph.log` 中有超时错误。 +**症状**:MCP 工具未出现,`logs/gateway.log` 中有超时错误。 **诊断**: ```bash # 检查 MCP 相关错误 -grep -i "mcp\|timeout" logs/langgraph.log | tail -20 +grep -i "mcp\|timeout" logs/gateway.log | tail -20 ``` **解决**: diff --git a/frontend/src/content/zh/application/quick-start.mdx b/frontend/src/content/zh/application/quick-start.mdx index 5ccf117ad..b5ab052fc 100644 --- a/frontend/src/content/zh/application/quick-start.mdx +++ b/frontend/src/content/zh/application/quick-start.mdx @@ -1,6 +1,6 @@ --- title: 快速上手 -description: 本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。所有四个服务(LangGraph、Gateway、前端、nginx)一起启动,通过单个 URL 访问。 +description: 本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。Gateway、前端和 nginx 会一起启动,通过单个 URL 访问。 --- import { Callout, Cards, Steps } from "nextra/components"; @@ -12,7 +12,7 @@ import { Callout, Cards, Steps } from "nextra/components"; 3.12+、Node.js 22+ 的机器,以及至少一个 LLM API Key。 -本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。所有四个服务(LangGraph、Gateway、前端、nginx)一起启动,通过单个 URL 访问。 +本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。Gateway、前端和 nginx 会一起启动,通过单个 URL 访问。 ## 前置条件 @@ -88,8 +88,7 @@ make dev 这会启动: -- LangGraph 服务,端口 `2024` -- Gateway API,端口 `8001` +- Gateway API 和嵌入式 Agent 运行时,端口 `8001` - 前端,端口 `3000` - nginx 反向代理,端口 `2026` @@ -113,15 +112,13 @@ make stop | 服务 | 日志文件 | | --------- | -------------------- | -| LangGraph | `logs/langgraph.log` | | Gateway | `logs/gateway.log` | | 前端 | `logs/frontend.log` | | nginx | `logs/nginx.log` | 如果有问题,先检查日志文件。大多数启动错误(缺失 API - Key、配置解析失败)会出现在 logs/langgraph.log 或{" "} - logs/gateway.log 中。 + Key、配置解析失败)会出现在 logs/gateway.log 中。 diff --git a/skills/public/claude-to-deerflow/SKILL.md b/skills/public/claude-to-deerflow/SKILL.md index d191f5c75..969a292c1 100644 --- a/skills/public/claude-to-deerflow/SKILL.md +++ b/skills/public/claude-to-deerflow/SKILL.md @@ -14,8 +14,8 @@ DeerFlow exposes two API surfaces behind an Nginx reverse proxy: | Service | Direct Port | Via Proxy | Purpose | |----------------|-------------|----------------------------------|----------------------------------| -| Gateway API | 8001 | `$DEERFLOW_GATEWAY_URL` | REST endpoints (models, skills, memory, uploads) | -| LangGraph API | 2024 | `$DEERFLOW_LANGGRAPH_URL` | Agent threads, runs, streaming | +| Gateway API | 8001 | `$DEERFLOW_GATEWAY_URL` | REST endpoints and embedded agent runtime | +| LangGraph-compatible API | 8001 | `$DEERFLOW_LANGGRAPH_URL` | Agent threads, runs, streaming | ## Environment Variables From f734e14d8b5004a6e5088499f27138f25e917653 Mon Sep 17 00:00:00 2001 From: greatmengqi Date: Tue, 12 May 2026 23:07:11 +0800 Subject: [PATCH 02/12] docs: document auth design and user isolation (#2913) * docs: document auth design and user isolation * docs: align auth docs with current storage and reset behavior --------- Co-authored-by: greatmengqi --- backend/app/gateway/app.py | 4 +- backend/app/gateway/auth/models.py | 2 +- backend/app/gateway/routers/auth.py | 2 +- backend/docs/API.md | 26 ++- backend/docs/AUTH_DESIGN.md | 331 +++++++++++++++++++++++++++ backend/docs/AUTH_TEST_DOCKER_GAP.md | 12 +- backend/docs/AUTH_TEST_PLAN.md | 254 +++++++++++--------- backend/docs/AUTH_UPGRADE.md | 59 +++-- backend/docs/README.md | 2 + 9 files changed, 547 insertions(+), 145 deletions(-) create mode 100644 backend/docs/AUTH_DESIGN.md diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 8848f473e..2c13f571c 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -62,7 +62,7 @@ async def _ensure_admin_user(app: FastAPI) -> None: Subsequent boots (admin already exists): - Runs the one-time "no-auth → with-auth" orphan thread migration for - existing LangGraph thread metadata that has no owner_id. + existing LangGraph thread metadata that has no user_id. No SQL persistence migration is needed: the four user_id columns (threads_meta, runs, run_events, feedback) only come into existence @@ -177,7 +177,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async with langgraph_runtime(app): logger.info("LangGraph runtime initialised") - # Ensure admin user exists (auto-create on first boot) + # Check admin bootstrap state and migrate orphan threads after admin exists. # Must run AFTER langgraph_runtime so app.state.store is available for thread migration await _ensure_admin_user(app) diff --git a/backend/app/gateway/auth/models.py b/backend/app/gateway/auth/models.py index d8f9b954a..25c6476fe 100644 --- a/backend/app/gateway/auth/models.py +++ b/backend/app/gateway/auth/models.py @@ -28,7 +28,7 @@ class User(BaseModel): oauth_id: str | None = Field(None, description="User ID from OAuth provider") # Auth lifecycle - needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes") + needs_setup: bool = Field(default=False, description="True when a reset account must complete setup") token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs") diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py index 3a41e13eb..6192456fb 100644 --- a/backend/app/gateway/routers/auth.py +++ b/backend/app/gateway/routers/auth.py @@ -305,7 +305,7 @@ async def login_local( async def register(request: Request, response: Response, body: RegisterRequest): """Register a new user account (always 'user' role). - Admin is auto-created on first boot. This endpoint creates regular users. + The first admin is created explicitly through /initialize. This endpoint creates regular users. Auto-login by setting the session cookie. """ try: diff --git a/backend/docs/API.md b/backend/docs/API.md index d0b06ef0b..762a135c4 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -535,14 +535,28 @@ All APIs return errors in a consistent format: ## Authentication -Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials. +DeerFlow enforces authentication for all non-public HTTP routes. Public routes are limited to health/docs metadata and these public auth endpoints: -Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers. +- `POST /api/v1/auth/initialize` creates the first admin account when no admin exists. +- `POST /api/v1/auth/login/local` logs in with email/password and sets an HttpOnly `access_token` cookie. +- `POST /api/v1/auth/register` creates a regular `user` account and sets the session cookie. +- `POST /api/v1/auth/logout` clears the session cookie. +- `GET /api/v1/auth/setup-status` reports whether the first admin still needs to be created. -For production deployments, it is recommended to: -1. Use Nginx for basic auth or OAuth integration -2. Deploy behind a VPN or private network -3. Implement custom authentication middleware +The authenticated auth endpoints are: + +- `GET /api/v1/auth/me` returns the current user. +- `POST /api/v1/auth/change-password` changes password, optionally changes email during setup, increments `token_version`, and reissues the cookie. + +Protected state-changing requests also require the CSRF double-submit token: send the `csrf_token` cookie value as the `X-CSRF-Token` header. Login/register/initialize/logout are bootstrap auth endpoints: they are exempt from the double-submit token but still reject hostile browser `Origin` headers. + +User isolation is enforced from the authenticated user context: + +- Thread metadata is scoped by `threads_meta.user_id`; search/read/write/delete APIs only expose the current user's threads. +- Thread files live under `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/` and are exposed inside the sandbox as `/mnt/user-data/`. +- Memory and custom agents are stored under `{base_dir}/users/{user_id}/...`. + +Note: MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers; that is separate from DeerFlow API authentication. --- diff --git a/backend/docs/AUTH_DESIGN.md b/backend/docs/AUTH_DESIGN.md new file mode 100644 index 000000000..9a740871d --- /dev/null +++ b/backend/docs/AUTH_DESIGN.md @@ -0,0 +1,331 @@ +# 用户认证与隔离设计 + +本文档描述 DeerFlow 当前内置认证模块的设计,而不是历史 RFC。它覆盖浏览器登录、API 认证、CSRF、用户隔离、首次初始化、密码重置、内部调用和升级迁移。 + +## 设计目标 + +认证模块的核心目标是把 DeerFlow 从“本地单用户工具”提升为“可多用户部署的 agent runtime”,并让用户身份贯穿 HTTP API、LangGraph-compatible runtime、文件系统、memory、自定义 agent 和反馈数据。 + +设计约束: + +- 默认强制认证:除健康检查、文档和 auth bootstrap 端点外,HTTP 路由都必须有有效 session。 +- 服务端持有所有权:客户端 metadata 不能声明 `user_id` 或 `owner_id`。 +- 隔离默认开启:repository(仓储)、文件路径、memory、agent 配置默认按当前用户解析。 +- 旧数据可升级:无认证版本留下的 thread 可以在 admin 存在后迁移到 admin。 +- 密码不进日志:首次初始化由操作者设置密码;`reset_admin` 只写 0600 凭据文件。 + +非目标: + +- 当前 OAuth 端点只是占位,尚未实现第三方登录。 +- 当前用户角色只有 `admin` 和 `user`,尚未实现细粒度 RBAC。 +- 当前登录限速是进程内字典,多 worker 下不是全局精确限速。 + +## 核心模型 + +```mermaid +graph TB + classDef actor fill:#D8CFC4,stroke:#6E6259,color:#2F2A26; + classDef api fill:#C9D7D2,stroke:#5D706A,color:#21302C; + classDef state fill:#D7D3E8,stroke:#6B6680,color:#29263A; + classDef data fill:#E5D2C4,stroke:#806A5B,color:#30251E; + + Browser["Browser — access_token cookie and csrf_token cookie"]:::actor + AuthMiddleware["AuthMiddleware — strict session gate"]:::api + CSRFMiddleware["CSRFMiddleware — double-submit token and Origin check"]:::api + AuthRoutes["Auth routes — initialize login register logout me change-password"]:::api + UserContext["Current user ContextVar — request-scoped identity"]:::state + Repositories["Repositories — AUTO resolves user_id from context"]:::state + Files["Filesystem — users/{user_id}/threads/{thread_id}/user-data"]:::data + Memory["Memory and agents — users/{user_id}/memory.json and agents"]:::data + + Browser --> AuthMiddleware + Browser --> CSRFMiddleware + AuthMiddleware --> AuthRoutes + AuthMiddleware --> UserContext + UserContext --> Repositories + UserContext --> Files + UserContext --> Memory +``` + +### 用户表 + +用户记录定义在 `app.gateway.auth.models.User`,持久化到 `users` 表。关键字段: + +| 字段 | 语义 | +|---|---| +| `id` | 用户主键,JWT `sub` 使用该值 | +| `email` | 唯一登录名 | +| `password_hash` | bcrypt hash,OAuth 用户可为空 | +| `system_role` | `admin` 或 `user` | +| `needs_setup` | reset 后要求用户完成邮箱 / 密码设置 | +| `token_version` | 改密码或 reset 时递增,用于废弃旧 JWT | + +### 运行时身份 + +认证成功后,`AuthMiddleware` 把用户同时写入: + +- `request.state.user` +- `request.state.auth` +- `deerflow.runtime.user_context` 的 `ContextVar` + +`ContextVar` 是这里的核心边界。上层 Gateway 负责写入身份,下层 persistence / file path 只读取结构化的当前用户,不反向依赖 `app.gateway.auth` 具体类型。 + +可以把 repository 调用的用户参数理解成一个三态 ADT: + +```scala +enum UserScope: + case AutoFromContext + case Explicit(userId: String) + case BypassForMigration +``` + +对应 Python 实现是 `AUTO | str | None`: + +- `AUTO`:从 `ContextVar` 解析当前用户;没有上下文则抛错。 +- `str`:显式指定用户,主要用于测试或管理脚本。 +- `None`:跳过用户过滤,只允许迁移脚本或 admin CLI 使用。 + +## 登录与初始化流程 + +### 首次初始化 + +首次启动时,如果没有 admin,服务不会自动创建账号,只记录日志提示访问 `/setup`。 + +流程: + +1. 用户访问 `/setup`。 +2. 前端调用 `GET /api/v1/auth/setup-status`。 +3. 如果返回 `{"needs_setup": true}`,前端展示创建 admin 表单。 +4. 表单提交 `POST /api/v1/auth/initialize`。 +5. 服务端确认当前没有 admin,创建 `system_role="admin"`、`needs_setup=false` 的用户。 +6. 服务端设置 `access_token` HttpOnly cookie,用户进入 workspace。 + +`/api/v1/auth/initialize` 只在没有 admin 时可用。并发初始化由数据库唯一约束兜底,失败方返回 409。 + +### 普通登录 + +`POST /api/v1/auth/login/local` 使用 `OAuth2PasswordRequestForm`: + +- `username` 是邮箱。 +- `password` 是密码。 +- 成功后签发 JWT,放入 `access_token` HttpOnly cookie。 +- 响应体只返回 `expires_in` 和 `needs_setup`,不返回 token。 + +登录失败会按客户端 IP 计数。IP 解析只在 TCP peer 属于 `AUTH_TRUSTED_PROXIES` 时信任 `X-Real-IP`,不使用 `X-Forwarded-For`。 + +### 注册 + +`POST /api/v1/auth/register` 创建普通 `user`,并自动登录。 + +当前实现允许在没有 admin 时注册普通用户,但 `setup-status` 仍会返回 `needs_setup=true`,因为 admin 仍不存在。这是当前产品策略边界:如果后续要求“必须先初始化 admin 才能注册普通用户”,需要在 `/register` 增加 admin-exists gate。 + +### 改密码与 reset setup + +`POST /api/v1/auth/change-password` 需要当前密码和新密码: + +- 校验当前密码。 +- 更新 bcrypt hash。 +- `token_version += 1`,使旧 JWT 立即失效。 +- 重新签发 cookie。 +- 如果 `needs_setup=true` 且传了 `new_email`,则更新邮箱并清除 `needs_setup`。 + +`python -m app.gateway.auth.reset_admin` 会: + +- 找到 admin 或指定邮箱用户。 +- 生成随机密码。 +- 更新密码 hash。 +- `token_version += 1`。 +- 设置 `needs_setup=true`。 +- 写入 `.deer-flow/admin_initial_credentials.txt`,权限 `0600`。 + +命令行只输出凭据文件路径,不输出明文密码。 + +## HTTP 认证边界 + +`AuthMiddleware` 是 fail-closed(默认拒绝)的全局认证门。 + +公开路径: + +- `/health` +- `/docs` +- `/redoc` +- `/openapi.json` +- `/api/v1/auth/login/local` +- `/api/v1/auth/register` +- `/api/v1/auth/logout` +- `/api/v1/auth/setup-status` +- `/api/v1/auth/initialize` + +其余路径都要求有效 `access_token` cookie。存在 cookie 但 JWT 无效、过期、用户不存在或 `token_version` 不匹配时,直接返回 401,而不是让请求穿透到业务路由。 + +路由级别的 owner check 由 `require_permission(..., owner_check=True)` 完成: + +- 读类请求允许旧的未追踪 legacy thread 兼容读取。 +- 写 / 删除类请求使用 `require_existing=True`,要求 thread row 存在且属于当前用户,避免删除后缺 row 导致其他用户误通过。 + +## CSRF 设计 + +DeerFlow 使用 Double Submit Cookie: + +- 服务端设置 `csrf_token` cookie。 +- 前端 state-changing 请求发送同值 `X-CSRF-Token` header。 +- 服务端用 `secrets.compare_digest` 比较 cookie/header。 + +需要 CSRF 的方法: + +- `POST` +- `PUT` +- `DELETE` +- `PATCH` + +auth bootstrap 端点(login/register/initialize/logout)不要求 double-submit token,因为首次调用时浏览器还没有 token;但这些端点会校验 browser `Origin`,拒绝 hostile Origin,避免 login CSRF / session fixation。 + +## 用户隔离 + +### Thread metadata + +Thread metadata 存在 `threads_meta`,关键隔离字段是 `user_id`。 + +创建 thread 时: + +- 客户端传入的 `metadata.user_id` 和 `metadata.owner_id` 会被剥离。 +- `ThreadMetaRepository.create(..., user_id=AUTO)` 从 `ContextVar` 解析真实用户。 +- `/api/threads/search` 默认只返回当前用户的 thread。 + +读取 / 修改 / 删除时: + +- `get()` 默认按当前用户过滤。 +- `check_access()` 用于路由 owner check。 +- 对其他用户的 thread 返回 404,避免泄露资源存在性。 + +### 文件系统 + +当前线程文件布局: + +```text +{base_dir}/users/{user_id}/threads/{thread_id}/user-data/ +├── workspace/ +├── uploads/ +└── outputs/ +``` + +agent 在 sandbox 内看到统一虚拟路径: + +```text +/mnt/user-data/workspace +/mnt/user-data/uploads +/mnt/user-data/outputs +``` + +`ThreadDataMiddleware` 使用 `get_effective_user_id()` 解析当前用户并生成线程路径。没有认证上下文时会落到 `default` 用户桶,主要用于内部调用、嵌入式 client 或无 HTTP 的本地执行路径。 + +### Memory + +默认 memory 存储: + +```text +{base_dir}/users/{user_id}/memory.json +{base_dir}/users/{user_id}/agents/{agent_name}/memory.json +``` + +有用户上下文时,空或相对 `memory.storage_path` 都使用上述 per-user 默认路径;只有绝对 `memory.storage_path` 会视为显式 opt-out(退出) per-user isolation,所有用户共享该路径。无用户上下文的 legacy 路径仍会把相对 `storage_path` 解析到 `Paths.base_dir` 下。 + +### 自定义 agent + +用户自定义 agent 写入: + +```text +{base_dir}/users/{user_id}/agents/{agent_name}/ +├── config.yaml +├── SOUL.md +└── memory.json +``` + +旧布局 `{base_dir}/agents/{agent_name}/` 只作为只读兼容回退。更新或删除旧共享 agent 会要求先运行迁移脚本。 + +## 内部调用与 IM 渠道 + +IM channel worker 不是浏览器用户,不持有浏览器 cookie。它们通过 Gateway 内部认证: + +- 请求带 `X-DeerFlow-Internal-Token`。 +- 同时带匹配的 CSRF cookie/header。 +- 服务端识别为内部用户,`id="default"`、`system_role="internal"`。 + +这意味着 channel 产生的数据默认进入 `default` 用户桶。这个选择适合“平台级 bot 身份”,但不是“每个 IM 用户单独隔离”。如果后续要做到外部 IM 用户隔离,需要把外部 platform user 映射到 DeerFlow user,并让 channel manager 设置对应的 scoped identity。 + +## LangGraph-compatible 认证 + +Gateway 内嵌 runtime 路径由 `AuthMiddleware` 和 `CSRFMiddleware` 保护。 + +仓库仍保留 `app.gateway.langgraph_auth`,用于 LangGraph Server 直连模式: + +- `@auth.authenticate` 校验 JWT cookie、CSRF、用户存在性和 `token_version`。 +- `@auth.on` 在写入 metadata 时注入 `user_id`,并在读路径返回 `{"user_id": current_user}` 过滤条件。 + +这保证 Gateway 路由和 LangGraph-compatible 直连模式使用同一 JWT 语义。 + +## 升级与迁移 + +从无认证版本升级时,可能存在没有 `user_id` 的历史 thread。 + +当前策略: + +1. 首次启动如果没有 admin,只提示访问 `/setup`,不迁移。 +2. 操作者创建 admin。 +3. 后续启动时,`_ensure_admin_user()` 找到 admin,并把 LangGraph store 中缺少 `metadata.user_id` 的 thread 迁移到 admin。 + +文件系统旧布局迁移由脚本处理: + +```bash +cd backend +PYTHONPATH=. python scripts/migrate_user_isolation.py --dry-run +PYTHONPATH=. python scripts/migrate_user_isolation.py --user-id +``` + +迁移脚本覆盖 legacy `memory.json`、`threads/` 和 `agents/` 到 per-user layout。 + +## 安全不变量 + +必须长期保持的不变量: + +- JWT 只在 HttpOnly cookie 中传输,不出现在响应 JSON。 +- 任何非 public HTTP 路由都不能只靠“cookie 存在”放行,必须严格验证 JWT。 +- `token_version` 不匹配必须拒绝,保证改密码 / reset 后旧 session 失效。 +- 客户端 metadata 中的 `user_id` / `owner_id` 必须剥离。 +- repository 默认 `AUTO` 必须从当前用户上下文解析,不能静默退化成全局查询。 +- 只有迁移脚本和 admin CLI 可以显式传 `user_id=None` 绕过隔离。 +- 本地文件路径必须通过 `Paths` 和 sandbox path validation 解析,不能拼接未校验的用户输入。 +- 捕获认证、迁移、后台任务异常必须记录日志;不能空 catch。 + +## 已知边界 + +| 边界 | 当前行为 | 后续方向 | +|---|---|---| +| 无 admin 时注册普通用户 | 允许注册普通 `user` | 如产品要求先初始化 admin,给 `/register` 加 gate | +| 登录限速 | 进程内 dict,单 worker 精确,多 worker 近似 | Redis / DB-backed rate limiter | +| OAuth | 端点占位,未实现 | 接入 provider 并统一 `token_version` / role 语义 | +| IM 用户隔离 | channel 使用 `default` 内部用户 | 建立外部用户到 DeerFlow user 的映射 | +| 绝对 memory path | 显式共享 memory | UI / docs 明确提示 opt-out 风险 | + +## 相关文件 + +| 文件 | 职责 | +|---|---| +| `app/gateway/auth_middleware.py` | 全局认证门、JWT 严格验证、写入 user context | +| `app/gateway/csrf_middleware.py` | CSRF double-submit 和 auth Origin 校验 | +| `app/gateway/routers/auth.py` | initialize/login/register/logout/me/change-password | +| `app/gateway/auth/jwt.py` | JWT 创建与解析 | +| `app/gateway/auth/reset_admin.py` | 密码 reset CLI | +| `app/gateway/auth/credential_file.py` | 0600 凭据文件写入 | +| `app/gateway/authz.py` | 路由权限与 owner check | +| `deerflow/runtime/user_context.py` | 当前用户 ContextVar 与 `AUTO` sentinel | +| `deerflow/persistence/thread_meta/` | thread metadata owner filter | +| `deerflow/config/paths.py` | per-user filesystem layout | +| `deerflow/agents/middlewares/thread_data_middleware.py` | run 时解析用户线程目录 | +| `deerflow/agents/memory/storage.py` | per-user memory storage | +| `deerflow/config/agents_config.py` | per-user custom agents | +| `app/channels/manager.py` | IM channel 内部认证调用 | +| `scripts/migrate_user_isolation.py` | legacy 数据迁移到 per-user layout | +| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库,包含 users / threads_meta / runs / feedback 等表 | +| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | +| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | diff --git a/backend/docs/AUTH_TEST_DOCKER_GAP.md b/backend/docs/AUTH_TEST_DOCKER_GAP.md index adf4916a3..969aad92c 100644 --- a/backend/docs/AUTH_TEST_DOCKER_GAP.md +++ b/backend/docs/AUTH_TEST_DOCKER_GAP.md @@ -24,11 +24,11 @@ All other test plan sections were executed against either: | Case | Title | What it covers | Why not run | |---|---|---|---| -| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` | +| TC-DOCKER-01 | `deerflow.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` | | TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` | | TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container | -| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` | -| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | +| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` | +| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | | TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` | ## Coverage already provided by non-Docker tests @@ -41,8 +41,8 @@ the test cases that ran on sg_dev or local: | TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between | | TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` | | TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs | -| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP | -| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | +| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies | +| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | | TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change | ## Reproduction steps when Docker becomes available @@ -72,6 +72,6 @@ Then run TC-DOCKER-01..06 from the test plan as written. about *container packaging* details (bind mounts, multi-worker, log collection), not about whether the auth code paths work. - **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect - the post-simplify reality (credentials file → 0600 file, no log leak). + the current reset flow (`reset_admin` → 0600 credentials file, no log leak). The old "grep 'Password:' in docker logs" expectation would have failed silently and given a false sense of coverage. diff --git a/backend/docs/AUTH_TEST_PLAN.md b/backend/docs/AUTH_TEST_PLAN.md index 15b20494a..e5245d60b 100644 --- a/backend/docs/AUTH_TEST_PLAN.md +++ b/backend/docs/AUTH_TEST_PLAN.md @@ -19,7 +19,7 @@ ```bash # 清除已有数据 -rm -f backend/.deer-flow/users.db +rm -f backend/.deer-flow/data/deerflow.db # 选择模式启动 make dev # 标准模式 @@ -28,10 +28,11 @@ make dev-pro # Gateway 模式 ``` **验证点:** -- [ ] 控制台输出 admin 邮箱和随机密码 -- [ ] 密码格式为 `secrets.token_urlsafe(16)` 的 22 字符字符串 -- [ ] 邮箱为 `admin@deerflow.dev` -- [ ] 提示 `Change it after login: Settings -> Account` +- [ ] 控制台不输出 admin 邮箱或明文密码 +- [ ] 控制台提示 `First boot detected — no admin account exists.` +- [ ] 控制台提示访问 `/setup` 完成 admin 创建 +- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": true}` +- [ ] 前端访问 `/login` 会跳转 `/setup` ### 1.2 非首次启动 @@ -42,7 +43,8 @@ make dev **验证点:** - [ ] 控制台不输出密码 -- [ ] 如果 admin 仍 `needs_setup=True`,控制台有 warning 提示 +- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": false}` +- [ ] 已登录用户如果 `needs_setup=True`,访问 workspace 会被引导到 `/setup` 完成改邮箱 / 改密码流程 ### 1.3 环境变量配置 @@ -76,19 +78,22 @@ make dev curl -s $BASE/api/v1/auth/setup-status | jq . ``` -**预期:** 返回 `{"needs_setup": false}`(admin 在启动时已自动创建,`count_users() > 0`)。仅在启动完成前的极短窗口内可能返回 `true`。 +**预期:** +- 干净数据库且尚未初始化 admin:返回 `{"needs_setup": true}` +- 已存在 admin:返回 `{"needs_setup": false}` -#### TC-API-02: Admin 首次登录 +#### TC-API-02: 首次初始化 Admin ```bash -curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<控制台密码>" \ +curl -s -X POST $BASE/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ -c cookies.txt | jq . ``` **预期:** -- 状态码 200 -- Body: `{"expires_in": 604800, "needs_setup": true}` +- 状态码 201 +- Body: `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}` - `cookies.txt` 包含 `access_token`(HttpOnly)和 `csrf_token`(非 HttpOnly) #### TC-API-03: 获取当前用户 @@ -97,9 +102,9 @@ curl -s -X POST $BASE/api/v1/auth/login/local \ curl -s $BASE/api/v1/auth/me -b cookies.txt | jq . ``` -**预期:** `{"id": "...", "email": "admin@deerflow.dev", "system_role": "admin", "needs_setup": true}` +**预期:** `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}` -#### TC-API-04: Setup 流程(改邮箱 + 改密码) +#### TC-API-04: 改密码流程 ```bash CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}') @@ -107,13 +112,36 @@ curl -s -X POST $BASE/api/v1/auth/change-password \ -b cookies.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"current_password":"<控制台密码>","new_password":"NewPass123!","new_email":"admin@example.com"}' | jq . + -d '{"current_password":"AdminPass1!","new_password":"NewPass123!"}' | jq . ``` **预期:** - 状态码 200 - `{"message": "Password changed successfully"}` -- 再调 `/auth/me` 邮箱变为 `admin@example.com`,`needs_setup` 变为 `false` +- 再调 `/auth/me` 仍为 `admin@example.com`,`needs_setup` 仍为 `false` + +#### TC-API-04a: reset_admin 后的 Setup 流程(改邮箱 + 改密码) + +```bash +cd backend +python -m app.gateway.auth.reset_admin --email admin@example.com +# 从 .deer-flow/admin_initial_credentials.txt 读取 reset 后密码 + +curl -s -X POST $BASE/api/v1/auth/login/local \ + -d "username=admin@example.com&password=<凭据文件密码>" \ + -c cookies.txt | jq . + +CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}') +curl -s -X POST $BASE/api/v1/auth/change-password \ + -b cookies.txt \ + -H "Content-Type: application/json" \ + -H "X-CSRF-Token: $CSRF" \ + -d '{"current_password":"<凭据文件密码>","new_password":"AdminPass2!","new_email":"admin2@example.com"}' | jq . +``` + +**预期:** +- 登录返回 `{"expires_in": 604800, "needs_setup": true}` +- `change-password` 后 `/auth/me` 邮箱变为 `admin2@example.com`,`needs_setup` 变为 `false` #### TC-API-05: 普通用户注册 @@ -493,7 +521,7 @@ curl -s -X POST $BASE/api/v1/auth/register \ ```bash # 检查数据库 -sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMIT 3;" +sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM users LIMIT 3;" ``` **预期:** `password_hash` 以 `$2b$` 开头(bcrypt 格式) @@ -506,24 +534,25 @@ sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMI ### 4.1 首次登录流程 -#### TC-UI-01: 访问首页跳转登录 +#### TC-UI-01: 无 admin 时访问 workspace 跳转 setup 1. 打开 `http://localhost:2026/workspace` -2. **预期:** 自动跳转到 `/login` +2. **预期:** 自动跳转到 `/setup` -#### TC-UI-02: Login 页面 +#### TC-UI-02: Setup 页面创建 admin -1. 输入 admin 邮箱和控制台密码 -2. 点击 Login -3. **预期:** 跳转到 `/setup`(因为 `needs_setup=true`) - -#### TC-UI-03: Setup 页面 - -1. 输入新邮箱、控制台密码(current)、新密码、确认密码 -2. 点击 Complete Setup +1. 输入 admin 邮箱、密码、确认密码 +2. 点击 Create Admin Account 3. **预期:** 跳转到 `/workspace` 4. 刷新页面不跳回 `/setup` +#### TC-UI-03: 已初始化后 Login 页面 + +1. 退出登录后访问 `/login` +2. 输入 admin 邮箱和密码 +3. 点击 Login +4. **预期:** 跳转到 `/workspace` + #### TC-UI-04: Setup 密码不匹配 1. 新密码和确认密码不一致 @@ -602,7 +631,7 @@ sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMI #### TC-UI-15: reset_admin 后重新登录 1. 执行 `cd backend && python -m app.gateway.auth.reset_admin` -2. 使用新密码登录 +2. 从 `.deer-flow/admin_initial_credentials.txt` 读取新密码并登录 3. **预期:** 跳转到 `/setup` 页面(`needs_setup` 被重置为 true) 4. 旧 session 已失效 @@ -645,18 +674,28 @@ make install make dev ``` -#### TC-UPG-01: 首次启动创建 admin +#### TC-UPG-01: 首次启动等待 admin 初始化 **预期:** -- [ ] 控制台输出 admin 邮箱(`admin@deerflow.dev`)和随机密码 +- [ ] 控制台不输出 admin 邮箱或随机密码 +- [ ] 访问 `/setup` 可创建第一个 admin - [ ] 无报错,正常启动 #### TC-UPG-02: 旧 Thread 迁移到 admin ```bash +# 创建第一个 admin +curl -s -X POST http://localhost:2026/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ + -c cookies.txt + +# 重启一次:启动迁移只在已有 admin 的启动路径执行 +make stop && make dev + # 登录 admin curl -s -X POST http://localhost:2026/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<控制台密码>" \ + -d "username=admin@example.com&password=AdminPass1!" \ -c cookies.txt # 查看 thread 列表 @@ -670,8 +709,8 @@ curl -s -X POST http://localhost:2026/api/threads/search \ **预期:** - [ ] 返回的 thread 数量 ≥ 旧版创建的数量 -- [ ] 控制台日志有 `Migrated N orphaned thread(s) to admin` -- [ ] 每个 thread 的 `metadata.owner_id` 都已被设为 admin 的 ID +- [ ] 控制台日志有 `Migrated N orphan LangGraph thread(s) to admin` +- [ ] 旧 thread 只对 admin 可见 #### TC-UPG-03: 旧 Thread 内容完整 @@ -683,7 +722,7 @@ curl -s http://localhost:2026/api/threads/ \ **预期:** - [ ] `metadata.title` 保留原值(如 `old-thread-1`) -- [ ] `metadata.owner_id` 已填充 +- [ ] 响应不回显服务端保留的 `user_id` / `owner_id` #### TC-UPG-04: 新用户看不到旧 Thread @@ -706,18 +745,19 @@ curl -s -X POST http://localhost:2026/api/threads/search \ ### 5.3 数据库 Schema 兼容 -#### TC-UPG-05: 无 users.db 时自动创建 +#### TC-UPG-05: 无 deerflow.db 时创建 schema 但不创建默认用户 ```bash -ls -la backend/.deer-flow/users.db +ls -la backend/.deer-flow/data/deerflow.db +sqlite3 backend/.deer-flow/data/deerflow.db "SELECT COUNT(*) FROM users;" ``` -**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列 +**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列;未调用 `/initialize` 前用户数为 0 -#### TC-UPG-06: users.db WAL 模式 +#### TC-UPG-06: deerflow.db WAL 模式 ```bash -sqlite3 backend/.deer-flow/users.db "PRAGMA journal_mode;" +sqlite3 backend/.deer-flow/data/deerflow.db "PRAGMA journal_mode;" ``` **预期:** 返回 `wal` @@ -768,9 +808,9 @@ make dev ``` **预期:** -- [ ] 服务正常启动(忽略 `users.db`,无 auth 相关代码不报错) +- [ ] 服务正常启动(忽略 `deerflow.db`,无 auth 相关代码不报错) - [ ] 旧对话数据仍然可访问 -- [ ] `users.db` 文件残留但不影响运行 +- [ ] `deerflow.db` 文件残留但不影响运行 #### TC-UPG-12: 再次升级到 auth 分支 @@ -781,51 +821,47 @@ make dev ``` **预期:** -- [ ] 识别已有 `users.db`,不重新创建 admin -- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `users.db`) +- [ ] 识别已有 `deerflow.db`,不重新创建 admin +- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `deerflow.db`) -### 5.7 休眠 Admin(初始密码未使用/未更改) +### 5.7 Admin 初始化与 reset_admin -> 首次启动生成 admin + 随机密码,但运维未登录、未改密码。 -> 密码只在首次启动的控制台闪过一次,后续启动不再显示。 +> 首次启动不生成默认 admin,也不在日志输出密码。忘记密码时走 `reset_admin`,新密码写入 0600 凭据文件。 -#### TC-UPG-13: 重启后自动重置密码并打印 +#### TC-UPG-13: 未初始化 admin 时重启不创建默认账号 ```bash -# 首次启动,记录密码 -rm -f backend/.deer-flow/users.db +rm -f backend/.deer-flow/data/deerflow.db make dev -# 控制台输出密码 P0,不登录 make stop -# 隔了几天,再次启动 make dev -# 控制台输出新密码 P1 +curl -s $BASE/api/v1/auth/setup-status | jq . ``` **预期:** -- [ ] 控制台输出 `Admin account setup incomplete — password reset` -- [ ] 输出新密码 P1(P0 已失效) -- [ ] 用 P1 可以登录,P0 不可以 -- [ ] 登录后 `needs_setup=true`,跳转 `/setup` -- [ ] `token_version` 递增(旧 session 如有也失效) +- [ ] 控制台不输出密码 +- [ ] `setup-status` 仍为 `{"needs_setup": true}` +- [ ] 访问 `/setup` 仍可创建第一个 admin -#### TC-UPG-14: 密码丢失 — 无需 CLI,重启即可 +#### TC-UPG-14: 密码丢失 — reset_admin 写入凭据文件 ```bash -# 忘记了控制台密码 → 直接重启服务 -make stop && make dev -# 控制台自动输出新密码 +python -m app.gateway.auth.reset_admin --email admin@example.com +ls -la backend/.deer-flow/admin_initial_credentials.txt +cat backend/.deer-flow/admin_initial_credentials.txt ``` **预期:** -- [ ] 无需 `reset_admin`,重启服务即可拿到新密码 -- [ ] `reset_admin` CLI 仍然可用作手动备选方案 +- [ ] 命令行只输出凭据文件路径,不输出明文密码 +- [ ] 凭据文件权限为 `0600` +- [ ] 凭据文件包含 email + password 行 +- [ ] 该用户下次登录返回 `needs_setup=true` -#### TC-UPG-15: 休眠 admin 期间普通用户注册 +#### TC-UPG-15: 未初始化 admin 期间普通用户注册策略边界 ```bash -# admin 存在但从未登录,普通用户先注册 +# admin 尚不存在,普通用户尝试注册 curl -s -X POST $BASE/api/v1/auth/register \ -H "Content-Type: application/json" \ -d '{"email":"earlybird@example.com","password":"EarlyPass1!"}' \ @@ -833,11 +869,11 @@ curl -s -X POST $BASE/api/v1/auth/register \ ``` **预期:** -- [ ] 注册成功(201),角色为 `user` -- [ ] 无法提权为 admin -- [ ] 普通用户的数据与 admin 隔离 +- [ ] 当前代码允许注册普通用户并自动登录(201,角色为 `user`) +- [ ] 但 `setup-status` 仍为 `{"needs_setup": true}`,因为 admin 仍不存在 +- [ ] 这是一个产品策略边界:若要求“必须先有 admin”,需要在 `/register` 增加 admin-exists gate -#### TC-UPG-16: 休眠 admin 不影响后续操作 +#### TC-UPG-16: 普通用户数据与后续 admin 隔离 ```bash # 普通用户正常创建 thread、发消息 @@ -849,14 +885,13 @@ curl -s -X POST $BASE/api/threads \ -d '{"metadata":{}}' | jq .thread_id ``` -**预期:** 正常创建,不受休眠 admin 影响 +**预期:** 普通用户正常创建 thread;后续 admin 创建后,搜索不到该普通用户 thread -#### TC-UPG-17: 休眠 admin 最终完成 Setup +#### TC-UPG-17: reset_admin 后完成 Setup ```bash -# 运维终于登录 curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=" \ + -d "username=admin@example.com&password=<凭据文件密码>" \ -c admin.txt | jq .needs_setup # 预期: true @@ -866,7 +901,7 @@ curl -s -X POST $BASE/api/v1/auth/change-password \ -b admin.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"current_password":"<密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \ + -d '{"current_password":"<凭据文件密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \ -c admin.txt # 验证 @@ -876,7 +911,7 @@ curl -s $BASE/api/v1/auth/me -b admin.txt | jq '{email, needs_setup}' **预期:** - [ ] `email` 变为 `admin@real.com` - [ ] `needs_setup` 变为 `false` -- [ ] 后续重启控制台不再有 warning +- [ ] 后续登录使用新密码 #### TC-UPG-18: 长期未用后 JWT 密钥轮换 @@ -890,8 +925,8 @@ make stop && make dev **预期:** - [ ] 服务正常启动 -- [ ] 旧密码仍可登录(密码存在 DB,与 JWT 密钥无关) -- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)— 但因为从未登录过也没有旧 token +- [ ] 账号密码仍可登录(密码存在 DB,与 JWT 密钥无关) +- [ ] 旧的 JWT token 失效(密钥变了签名不匹配) --- @@ -910,7 +945,7 @@ for i in 1 2 3; do done # 检查 admin 数量 -sqlite3 backend/.deer-flow/users.db \ +sqlite3 backend/.deer-flow/data/deerflow.db \ "SELECT COUNT(*) FROM users WHERE system_role='admin';" ``` @@ -1055,7 +1090,7 @@ curl -s -X POST $BASE/api/v1/auth/register \ wait # 检查用户数 -sqlite3 backend/.deer-flow/users.db \ +sqlite3 backend/.deer-flow/data/deerflow.db \ "SELECT COUNT(*) FROM users WHERE email='race@example.com';" ``` @@ -1165,13 +1200,16 @@ curl -s -w "%{http_code}" -X DELETE "$BASE/api/threads/$TID" \ ```bash cd backend python -m app.gateway.auth.reset_admin -# 记录密码 P1 +cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p1.txt +P1=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p1.txt) python -m app.gateway.auth.reset_admin -# 记录密码 P2 +cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p2.txt +P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt) ``` **预期:** +- [ ] `.deer-flow/admin_initial_credentials.txt` 每次都会被重写,文件权限为 `0600` - [ ] P1 ≠ P2(每次生成新随机密码) - [ ] P1 不可用,只有 P2 有效 - [ ] `token_version` 递增了 2 @@ -1324,7 +1362,8 @@ done ```bash GW=http://localhost:8001 -for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local /api/v1/auth/register; do +for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local \ + /api/v1/auth/register /api/v1/auth/initialize /api/v1/auth/logout; do echo "$path: $(curl -s -w '%{http_code}' -o /dev/null $GW$path)" done # 预期: 200 或 405/422(方法不对但不是 401) @@ -1399,9 +1438,9 @@ done > > 前置条件: > - `.env` 中设置 `AUTH_JWT_SECRET`(否则每次容器重启 session 全部失效) -> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `users.db`) +> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `deerflow.db`) -#### TC-DOCKER-01: users.db 通过 volume 持久化 +#### TC-DOCKER-01: deerflow.db 通过 volume 持久化 ```bash # 启动容器 @@ -1416,13 +1455,13 @@ curl -s -X POST $BASE/api/v1/auth/register \ -H "Content-Type: application/json" \ -d '{"email":"docker-test@example.com","password":"DockerTest1!"}' -w "\nHTTP %{http_code}" -# 检查宿主机上的 users.db -ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db -sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db \ +# 检查宿主机上的 deerflow.db +ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db +sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db \ "SELECT email FROM users WHERE email='docker-test@example.com';" ``` -**预期:** users.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。 +**预期:** deerflow.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。 #### TC-DOCKER-02: 重启容器后 session 保持 @@ -1466,22 +1505,24 @@ done **已知限制:** In-process rate limiter 不跨 worker 共享。生产环境如需精确限速,需要 Redis 等外部存储。 -#### TC-DOCKER-04: IM 渠道不经过 auth +#### TC-DOCKER-04: IM 渠道使用内部认证 ```bash -# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 通信 -# 不走 nginx,不经过 AuthMiddleware +# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 调 Gateway +# 请求携带 process-local internal auth header,并带匹配的 CSRF cookie/header # 验证方式:检查 gateway 日志中 channel manager 的请求不包含 auth 错误 docker logs deer-flow-gateway 2>&1 | grep -E "ChannelManager|channel" | head -10 ``` -**预期:** 无 auth 相关错误。渠道通过 `langgraph-sdk` 直连 LangGraph Server(`http://langgraph:2024`),不走 auth 层。 +**预期:** 无 auth 相关错误。渠道不依赖浏览器 cookie;服务端通过内部认证头把请求归入 `default` 用户桶。 -#### TC-DOCKER-05: admin 密码写入 0600 凭证文件(不再走日志) +#### TC-DOCKER-05: reset_admin 密码写入 0600 凭证文件(不再走日志) ```bash -# 凭证文件写在挂载到宿主机的 DEER_FLOW_HOME 下 +# 首次启动不会自动生成 admin 密码。先重置已有 admin,凭据文件写在挂载到宿主机的 DEER_FLOW_HOME 下。 +docker exec deer-flow-gateway python -m app.gateway.auth.reset_admin --email docker-test@example.com + ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/admin_initial_credentials.txt # 预期文件权限: -rw------- (0600) @@ -1512,14 +1553,15 @@ sleep 15 docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l # 预期: 0 -# auth 流程正常 +# auth 流程正常:未登录受保护接口返回 401 curl -s -w "%{http_code}" -o /dev/null $BASE/api/models # 预期: 401 -curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<日志密码>" \ +curl -s -X POST $BASE/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ -c cookies.txt -w "\nHTTP %{http_code}" -# 预期: 200 +# 预期: 201 ``` ### 7.4 补充边界用例 @@ -1587,13 +1629,15 @@ curl -s -D - -X POST $BASE/api/v1/auth/login/local \ #### TC-EDGE-05: HTTP 无 max_age / HTTPS 有 max_age ```bash +GW=http://localhost:8001 + # HTTP -curl -s -D - -X POST $BASE/api/v1/auth/login/local \ +curl -s -D - -X POST $GW/api/v1/auth/login/local \ -d "username=admin@example.com&password=正确密码" 2>/dev/null \ | grep "access_token=" | grep -oi "max-age=[0-9]*" || echo "NO max-age (HTTP session cookie)" -# HTTPS -curl -s -D - -X POST $BASE/api/v1/auth/login/local \ +# HTTPS:直连 Gateway 才能用 X-Forwarded-Proto 模拟 HTTPS;nginx 会覆盖该 header +curl -s -D - -X POST $GW/api/v1/auth/login/local \ -H "X-Forwarded-Proto: https" \ -d "username=admin@example.com&password=正确密码" 2>/dev/null \ | grep "access_token=" | grep -oi "max-age=[0-9]*" @@ -1712,10 +1756,10 @@ curl -s -X POST $BASE/api/threads \ -b cookies.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"metadata":{"owner_id":"victim-user-id"}}' | jq .metadata.owner_id + -d '{"metadata":{"owner_id":"victim-user-id","user_id":"victim-user-id"}}' | jq .metadata ``` -**预期:** 返回的 `metadata.owner_id` 应为当前登录用户的 ID,不是请求中注入的 `victim-user-id`。服务端应覆盖客户端提供的 `user_id`。 +**预期:** 返回的 `metadata` 不包含 `owner_id` 或 `user_id`。真实所有权写入 `threads_meta.user_id`,不从客户端 metadata 接收,也不通过 metadata 回显。 #### 7.5.6 HTTP Method 探测 @@ -1796,6 +1840,6 @@ cd backend && PYTHONPATH=. uv run pytest \ # 核心接口冒烟 curl -s $BASE/health # 200 curl -s $BASE/api/models # 401 (无 cookie) -curl -s -X POST $BASE/api/v1/auth/setup-status # 200 +curl -s $BASE/api/v1/auth/setup-status # 200 curl -s $BASE/api/v1/auth/me -b cookies.txt # 200 (有 cookie) ``` diff --git a/backend/docs/AUTH_UPGRADE.md b/backend/docs/AUTH_UPGRADE.md index 344c488c4..75fe8b3cb 100644 --- a/backend/docs/AUTH_UPGRADE.md +++ b/backend/docs/AUTH_UPGRADE.md @@ -2,13 +2,16 @@ DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。 +完整设计见 [AUTH_DESIGN.md](AUTH_DESIGN.md)。 + ## 核心概念 认证模块采用**始终强制**策略: -- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志 +- 首次启动时不会自动创建账号;首次访问 `/setup` 时由操作者创建第一个 admin 账号 - 认证从一开始就是强制的,无竞争窗口 -- 历史对话(升级前创建的 thread)自动迁移到 admin 名下 +- 已有 admin 后,服务启动时会把历史对话(升级前创建且缺少 `user_id` 的 thread)迁移到 admin 名下 +- 新数据按用户隔离:thread、workspace/uploads/outputs、memory、自定义 agent 都归属当前用户 ## 升级步骤 @@ -25,39 +28,41 @@ cd backend && make install make dev ``` -控制台会输出: +如果没有 admin 账号,控制台只会提示: ``` ============================================================ - Admin account created on first boot - Email: admin@deerflow.dev - Password: aB3xK9mN_pQ7rT2w - Change it after login: Settings → Account + First boot detected — no admin account exists. + Visit /setup to complete admin account creation. ============================================================ ``` -如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。 +首次启动不会在日志里打印随机密码,也不会写入默认 admin。这样避免启动日志泄露凭据,也避免在操作者创建账号前出现可被猜测的默认身份。 -### 3. 登录 +### 3. 创建 admin -访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。 +访问 `http://localhost:2026/setup`,填写邮箱和密码创建第一个 admin 账号。创建成功后会自动登录并进入 workspace。 -### 4. 修改密码 +如果这是从无认证版本升级,创建 admin 后重启一次服务,让启动迁移把缺少 `user_id` 的历史 thread 归属到 admin。 -登录后进入 Settings → Account → Change Password。 +### 4. 登录 + +后续访问 `http://localhost:2026/login`,使用已创建的邮箱和密码登录。 ### 5. 添加用户(可选) -其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。 +其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话、上传文件、输出文件、memory 和自定义 agent。 ## 安全机制 | 机制 | 说明 | |------|------| | JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 | -| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` | +| CSRF Double Submit Cookie | 受保护的 POST/PUT/PATCH/DELETE 请求需携带 `X-CSRF-Token`;登录/注册/初始化/登出走 auth 端点 Origin 校验 | | bcrypt 密码哈希 | 密码不以明文存储 | -| 多租户隔离 | 用户只能访问自己的 thread | +| Thread owner filter | `threads_meta.user_id` 由服务端认证上下文写入,搜索、读取、更新、删除默认按当前用户过滤 | +| 文件系统隔离 | 线程数据写入 `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/`,sandbox 内统一映射为 `/mnt/user-data/` | +| Memory / agent 隔离 | 用户 memory 和自定义 agent 写入 `{base_dir}/users/{user_id}/...`;旧共享 agent 只作为只读兼容回退 | | HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 | ## 常见操作 @@ -74,22 +79,26 @@ python -m app.gateway.auth.reset_admin python -m app.gateway.auth.reset_admin --email user@example.com ``` -会输出新的随机密码。 +会把新的随机密码写入 `.deer-flow/admin_initial_credentials.txt`,文件权限为 `0600`。命令行只输出文件路径,不输出明文密码。 ### 完全重置 -删除用户数据库,重启后自动创建新 admin: +删除统一 SQLite 数据库,重启后重新访问 `/setup` 创建新 admin: ```bash -rm -f backend/.deer-flow/users.db -# 重启服务,控制台输出新密码 +rm -f backend/.deer-flow/data/deerflow.db +# 重启服务后访问 http://localhost:2026/setup ``` ## 数据存储 | 文件 | 内容 | |------|------| -| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) | +| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库(users、threads_meta、runs、feedback 等应用数据) | +| `.deer-flow/users/{user_id}/threads/{thread_id}/user-data/` | 用户线程的 workspace、uploads、outputs | +| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory | +| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | +| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | | `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) | ### 生产环境建议 @@ -111,19 +120,21 @@ python -c "import secrets; print(secrets.token_urlsafe(32))" | `/api/v1/auth/me` | GET | 获取当前用户信息 | | `/api/v1/auth/change-password` | POST | 修改密码 | | `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 | +| `/api/v1/auth/initialize` | POST | 首次初始化第一个 admin(仅无 admin 时可调用) | ## 兼容性 -- **标准模式**(`make dev`):完全兼容,admin 自动创建 +- **标准模式**(`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化 - **Gateway 模式**(`make dev-pro`):完全兼容 -- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载 -- **IM 渠道**(Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层 +- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载 +- **IM 渠道**(Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶 - **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响 ## 故障排查 | 症状 | 原因 | 解决 | |------|------|------| -| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` | +| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` | +| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin | | 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 | | 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 | diff --git a/backend/docs/README.md b/backend/docs/README.md index da566005d..27e33f854 100644 --- a/backend/docs/README.md +++ b/backend/docs/README.md @@ -8,6 +8,7 @@ This directory contains detailed documentation for the DeerFlow backend. |----------|-------------| | [ARCHITECTURE.md](ARCHITECTURE.md) | System architecture overview | | [API.md](API.md) | Complete API reference | +| [AUTH_DESIGN.md](AUTH_DESIGN.md) | User authentication, CSRF, and per-user isolation design | | [CONFIGURATION.md](CONFIGURATION.md) | Configuration options | | [SETUP.md](SETUP.md) | Quick setup guide | @@ -42,6 +43,7 @@ docs/ ├── README.md # This file ├── ARCHITECTURE.md # System architecture ├── API.md # API reference +├── AUTH_DESIGN.md # User authentication and isolation design ├── CONFIGURATION.md # Configuration guide ├── SETUP.md # Setup instructions ├── FILE_UPLOAD.md # File upload feature From 506be8bffda8413ee0506f198ff47def931294db Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Tue, 12 May 2026 23:15:11 +0800 Subject: [PATCH 03/12] docs: clarify LangGraph compatibility entrypoints (#2914) --- backend/README.md | 4 ++++ backend/app/gateway/langgraph_auth.py | 12 ++++++++---- backend/docs/ARCHITECTURE.md | 3 ++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/backend/README.md b/backend/README.md index 18d89c2be..8c61e2db2 100644 --- a/backend/README.md +++ b/backend/README.md @@ -242,6 +242,10 @@ backend/ └── Dockerfile # Container build ``` +`langgraph.json` is not the default service entrypoint. The scripts and Docker +deployments run the Gateway embedded runtime; the file is kept for LangGraph +tooling, Studio, or direct LangGraph Server compatibility. + --- ## Configuration diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py index 38e020150..202fab2d5 100644 --- a/backend/app/gateway/langgraph_auth.py +++ b/backend/app/gateway/langgraph_auth.py @@ -1,8 +1,12 @@ -"""LangGraph Server auth handler — shares JWT logic with Gateway. +"""LangGraph compatibility auth handler — shares JWT logic with Gateway. -Loaded by LangGraph Server via langgraph.json ``auth.path``. -Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway, -so both modes validate tokens with the same secret and rules. +The default DeerFlow runtime is embedded in the FastAPI Gateway; scripts and +Docker deployments do not load this module. It is retained for LangGraph +tooling, Studio, or direct LangGraph Server compatibility through +``langgraph.json``'s ``auth.path``. + +When that compatibility path is used, this module reuses the same JWT and CSRF +rules as Gateway so both modes validate sessions consistently. Two layers: 1. @auth.authenticate — validates JWT cookie, extracts user_id, diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index f1557a6fb..47859cc9c 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -63,7 +63,8 @@ The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for - Tool execution orchestration - SSE streaming for real-time responses -**Graph registry**: `langgraph.json` remains available for tooling and Studio compatibility. +**Graph registry**: `langgraph.json` remains available for tooling, Studio, or direct LangGraph Server compatibility. +It is not the default service entrypoint; scripts and Docker deployments run the Gateway embedded runtime. ```json { From 68d8caec1f6b543fa7936d8a0c382f33726e00b0 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Tue, 12 May 2026 23:18:54 +0800 Subject: [PATCH 04/12] fix(agents): make update_agent honor runtime.context user_id like setup_agent (#2867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(agents): make update_agent honor runtime.context user_id like setup_agent PR #2784 hardened setup_agent to prefer runtime.context["user_id"] (set by inject_authenticated_user_context from the auth-validated request) over the contextvar, so an agent created during the bootstrap flow always lands under users//agents/. update_agent was left calling get_effective_user_id() unconditionally — the same class of bug that produced issues #2782 / #2862 still applies whenever the contextvar is not available on the executing task (background work, future cross-process drivers, checkpoint resume on a different task). In that regime update_agent silently routes writes to users/default/agents/, corrupting the shared default bucket and losing the user's edit. Extract the resolution policy into a shared resolve_runtime_user_id helper on deerflow.runtime.user_context and route both setup_agent and update_agent through it so the two halves of the lifecycle stay in lockstep. Add load-bearing end-to-end tests that drive a real langchain.agents create_agent graph with a fake LLM, exercising the full pipeline: HTTP wire format -> app.gateway.services.start_run config-assembly -> deerflow.runtime.runs.worker._build_runtime_context -> langchain.agents create_agent graph -> ToolNode dispatch (sync + async + sub-graph + ContextThreadPoolExecutor) -> setup_agent / update_agent The negative-control tests intentionally land in users/default/ to prove the positive tests are actually load-bearing rather than vacuously passing. The new test_update_agent_e2e_user_isolation suite included a test that failed against main and now passes after this fix. * style: ruff format on new e2e tests * test(e2e): real-server HTTP test driving setup_agent through the full ASGI stack Adds tests/test_setup_agent_http_e2e_real_server.py — a single load-bearing test that drives the entire FastAPI gateway through starlette.testclient. TestClient with no mocks above the LLM: - lifespan boots (config, sqlite engine, LangGraph runtime, channels) - POST /api/v1/auth/register (real password hash, real sqlite write, issues access_token + csrf_token cookies) - POST /api/threads (real thread_meta + checkpoint creation) - POST /api/threads/{id}/runs/stream with the exact wire shape the React frontend sends (assistant_id + input + config + context with agent_name/is_bootstrap) - AuthMiddleware -> CSRFMiddleware -> require_permission -> start_run -> inject_authenticated_user_context -> asyncio.create_task(run_agent) -> worker._build_runtime_context -> Runtime injection -> ToolNode dispatch -> real setup_agent - Asserts SOUL.md is under users//agents// and NOT under users/default/agents//. DEER_FLOW_HOME and the sqlite path are redirected into tmp_path so the test never touches the real .deer-flow directory or developer database. The only patch above the LLM boundary is replacing create_chat_model with a fake that emits a single setup_agent tool_call. This is the "真实验证" answer: it reproduces what curl-against-uvicorn would do, minus the network socket layer. * test: address Copilot review on user-isolation e2e tests - Drop "currently expected to FAIL" wording from update_agent e2e docstring and header (Copilot review): the fix is in this PR, the test pins the corrected behaviour rather than driving a future change. - Rephrase the assertion failure messages from "BUG:" to "REGRESSION:" to match the test's role on the fixed branch. - Bound _drain_stream with a wall-clock timeout, a max-bytes cap, and an early break on the "event: end" SSE frame (Copilot review). Stops the test from hanging on a stuck run or runaway heartbeat loop. - Replace the misleading "patch both module aliases" comment with an explanation of why patching lead_agent.agent.create_chat_model is the only correct target (Copilot review): lead_agent rebinds the symbol into its own namespace at import time, so patching deerflow.models is too late. * test(refactor): address WillemJiang review on user-isolation e2e tests - Extract the duplicated FakeToolCallingModel (and a build_single_tool_call_model helper) into tests/_agent_e2e_helpers.py. All three e2e files now import from the shared module instead of redefining the shim locally. - Convert the manual p.start() / p.stop() try/finally blocks in test_update_agent_e2e_user_isolation.py to contextlib.ExitStack so patch lifecycle is Pythonic and exception-safe. - Lift the isolated_app fixture's private-attribute resets into a named _reset_process_singletons helper with a comment block explaining why each singleton has to be invalidated for true e2e isolation, and why raising=False is intentional. Makes the fragility visible and the intent self-documenting rather than leaving the resets inline as opaque monkeypatch calls. Net change: -59 lines (143 -> 84) across the three test files, with every assertion intact. Full suite remains 69 passed / lint clean. * test(e2e): make real-server test self-supply its config CI's actions/checkout only ships config.example.yaml (the real config.yaml is gitignored), so the production config-discovery search (./config.yaml -> ../config.yaml -> $DEER_FLOW_CONFIG_PATH) finds nothing and the test fails at lifespan boot with FileNotFoundError. The dev-machine run passed only because a local config.yaml happened to exist. Write a minimal AppConfig-valid yaml into tmp_path and pin DEER_FLOW_CONFIG_PATH to it. The yaml carries just what the schema requires (a single fake-test-model entry, LocalSandboxProvider, sqlite database). The LLM never gets instantiated because the test patches create_chat_model on the lead agent module, so the api_key/base_url stay placeholders. Verified by hiding the local config.yaml to mirror the CI checkout — the test now passes in both environments. --- .../harness/deerflow/runtime/user_context.py | 28 ++ .../tools/builtins/setup_agent_tool.py | 11 +- .../tools/builtins/update_agent_tool.py | 12 +- backend/tests/_agent_e2e_helpers.py | 68 +++ .../test_setup_agent_e2e_user_isolation.py | 429 ++++++++++++++++++ .../test_setup_agent_http_e2e_real_server.py | 326 +++++++++++++ .../test_update_agent_e2e_user_isolation.py | 253 +++++++++++ 7 files changed, 1114 insertions(+), 13 deletions(-) create mode 100644 backend/tests/_agent_e2e_helpers.py create mode 100644 backend/tests/test_setup_agent_e2e_user_isolation.py create mode 100644 backend/tests/test_setup_agent_http_e2e_real_server.py create mode 100644 backend/tests/test_update_agent_e2e_user_isolation.py diff --git a/backend/packages/harness/deerflow/runtime/user_context.py b/backend/packages/harness/deerflow/runtime/user_context.py index ffe4be690..cfbb68c94 100644 --- a/backend/packages/harness/deerflow/runtime/user_context.py +++ b/backend/packages/harness/deerflow/runtime/user_context.py @@ -109,6 +109,34 @@ def get_effective_user_id() -> str: return str(user.id) +def resolve_runtime_user_id(runtime: object | None) -> str: + """Single source of truth for a tool/middleware's effective user_id. + + Resolution order (most authoritative first): + 1. ``runtime.context["user_id"]`` — set by ``inject_authenticated_user_context`` + in the gateway from the auth-validated ``request.state.user``. This is + the only source that survives boundaries where the contextvar may have + been lost (background tasks scheduled outside the request task, + worker pools that don't copy_context, future cross-process drivers). + 2. The ``_current_user`` ContextVar — set by the auth middleware at + request entry. Reliable for in-task work; copied by ``asyncio`` + child tasks and by ``ContextThreadPoolExecutor``. + 3. ``DEFAULT_USER_ID`` — last-resort fallback so unauthenticated + CLI / migration / test paths keep working without raising. + + Tools that persist user-scoped state (custom agents, memory, uploads) + MUST call this instead of ``get_effective_user_id()`` directly so they + benefit from the runtime.context channel that ``setup_agent`` already + relies on. + """ + context = getattr(runtime, "context", None) + if isinstance(context, dict): + ctx_user_id = context.get("user_id") + if ctx_user_id: + return str(ctx_user_id) + return get_effective_user_id() + + # --------------------------------------------------------------------------- # Sentinel-based user_id resolution # --------------------------------------------------------------------------- diff --git a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py index 2f796b005..dfbcf8b6e 100644 --- a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py @@ -7,19 +7,12 @@ from langgraph.types import Command from deerflow.config.agents_config import validate_agent_name from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.user_context import resolve_runtime_user_id from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) -def _get_runtime_user_id(runtime: Runtime) -> str: - context_user_id = runtime.context.get("user_id") if runtime.context else None - if context_user_id: - return str(context_user_id) - return get_effective_user_id() - - @tool(parse_docstring=True) def setup_agent( soul: str, @@ -45,7 +38,7 @@ def setup_agent( if agent_name: # Custom agents are persisted under the current user's bucket so # different users do not see each other's agents. - user_id = _get_runtime_user_id(runtime) + user_id = resolve_runtime_user_id(runtime) agent_dir = paths.user_agent_dir(user_id, agent_name) else: # Default agent (no agent_name): SOUL.md lives at the global base dir. diff --git a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py index b2dc8ca72..18500a248 100644 --- a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py @@ -27,7 +27,7 @@ from langgraph.types import Command from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.app_config import get_app_config from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.user_context import resolve_runtime_user_id from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -118,9 +118,13 @@ def update_agent( return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.") # Resolve the active user so that updates only affect this user's agent. - # ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context - # is set (matching how memory and thread storage behave). - user_id = get_effective_user_id() + # ``resolve_runtime_user_id`` prefers ``runtime.context["user_id"]`` (set by + # the gateway from the auth-validated request) and falls back to the + # contextvar, then DEFAULT_USER_ID. This matches setup_agent so a user + # creating an agent and later refining it always touches the same files, + # even if the contextvar gets lost across an async/thread boundary + # (issue #2782 / #2862 class of bugs). + user_id = resolve_runtime_user_id(runtime) # Reject an unknown ``model`` *before* touching the filesystem. Otherwise # ``_resolve_model_name`` silently falls back to the default at runtime diff --git a/backend/tests/_agent_e2e_helpers.py b/backend/tests/_agent_e2e_helpers.py new file mode 100644 index 000000000..2f28390a9 --- /dev/null +++ b/backend/tests/_agent_e2e_helpers.py @@ -0,0 +1,68 @@ +"""Shared helpers for user-isolation e2e tests on the custom-agent tooling. + +Centralises the small fake-LLM shim and a few test-data builders that the +three e2e files in this PR (``test_setup_agent_e2e_user_isolation``, +``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``) +all need. The shim is what lets a real ``langchain.agents.create_agent`` +graph run without an API key — every other layer in those tests is real +production code, which is the entire point of the test design. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage +from langchain_core.runnables import Runnable + + +class FakeToolCallingModel(FakeMessagesListChatModel): + """FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent. + + ``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to + expose the tool schemas to the model; the upstream fake raises + ``NotImplementedError`` there. We just return ``self`` because we + drive deterministic tool_call output via ``responses=...``, no schema + handling needed. + """ + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + +def build_single_tool_call_model( + *, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str = "call_e2e_1", + final_text: str = "done", +) -> FakeToolCallingModel: + """Build a fake model that emits exactly one tool_call then finishes. + + Two-turn behaviour, identical across our e2e tests: + turn 1 → AIMessage with a single tool_call for *tool_name* + turn 2 → AIMessage with *final_text* (terminates the agent loop) + """ + return FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": tool_args, + "id": tool_call_id, + "type": "tool_call", + } + ], + ), + AIMessage(content=final_text), + ] + ) diff --git a/backend/tests/test_setup_agent_e2e_user_isolation.py b/backend/tests/test_setup_agent_e2e_user_isolation.py new file mode 100644 index 000000000..034d4da84 --- /dev/null +++ b/backend/tests/test_setup_agent_e2e_user_isolation.py @@ -0,0 +1,429 @@ +"""End-to-end verification for issue #2862 (and the regression of #2782). + +Goal: prove — without trusting any single layer's claim — that an authenticated +user creating a custom agent through the real ``setup_agent`` tool, driven by a +real LangGraph ``create_agent`` graph, ends up with files under +``users//agents/`` and **not** under ``users/default/agents/...``. + +We intentionally exercise the full pipeline: + + HTTP body shape (mimics LangGraph SDK wire format) + -> app.gateway.services.start_run config-assembly chain + -> deerflow.runtime.runs.worker._build_runtime_context + -> langchain.agents.create_agent graph + -> ToolNode dispatch + -> setup_agent tool + +The only thing we mock is the LLM (FakeMessagesListChatModel) — every layer +that handles ``user_id`` is the real production code path. If the +``user_id`` propagation is broken anywhere in this chain, these tests will +fail. + +These tests intentionally ``no_auto_user`` so that the ``contextvar`` +fallback would put files into ``default/`` if propagation breaks. +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch +from uuid import UUID + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel +from langchain_core.messages import AIMessage, HumanMessage + +from app.gateway.services import ( + build_run_config, + inject_authenticated_user_context, + merge_run_context_overrides, +) +from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context + +# --------------------------------------------------------------------------- +# Helpers — real production code paths +# --------------------------------------------------------------------------- + + +def _make_request(user_id_str: str | None) -> SimpleNamespace: + """Build a fake FastAPI Request that carries an authenticated user.""" + if user_id_str is None: + user = None + else: + # User.id is UUID in production; honour that + user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") + return SimpleNamespace(state=SimpleNamespace(user=user)) + + +def _assemble_config( + *, + body_config: dict | None, + body_context: dict | None, + request_user_id: str | None, + thread_id: str = "thread-e2e", + assistant_id: str = "lead_agent", +) -> dict: + """Replay the **exact** start_run config-assembly sequence.""" + config = build_run_config(thread_id, body_config, None, assistant_id=assistant_id) + merge_run_context_overrides(config, body_context) + inject_authenticated_user_context(config, _make_request(request_user_id)) + return config + + +def _make_paths_mock(tmp_path: Path): + """Mirror the production paths.user_agent_dir signature.""" + from unittest.mock import MagicMock + + paths = MagicMock() + paths.base_dir = tmp_path + paths.agent_dir = lambda name: tmp_path / "agents" / name + paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name + return paths + + +# --------------------------------------------------------------------------- +# L1-L3: HTTP wire format → start_run → worker._build_runtime_context +# --------------------------------------------------------------------------- + + +class TestConfigAssembly: + """Covers L1-L3: validate that user_id reaches runtime_ctx for every wire shape.""" + + def test_typical_wire_format_user_id_in_runtime_ctx(self): + """Real frontend: body.config={recursion_limit}, body.context={agent_name,...}.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent", "is_bootstrap": True, "mode": "flash"}, + request_user_id="11111111-2222-3333-4444-555555555555", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555" + assert runtime_ctx["agent_name"] == "myagent" + + def test_body_context_none_still_injects_user_id(self): + """If frontend omits body.context entirely, inject must still create it.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context=None, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_body_context_empty_dict_still_injects_user_id(self): + """body.context={} (falsy) path: inject must still produce user_id.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={}, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_body_config_already_contains_context_field(self): + """body.config={'context': {...}} (LG 0.6 alt wire): inject still wins.""" + config = _assemble_config( + body_config={"context": {"agent_name": "myagent"}, "recursion_limit": 1000}, + body_context=None, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_client_supplied_user_id_is_overridden(self): + """Spoofed client user_id must be overwritten by inject (auth-trusted source).""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent", "user_id": "spoofed"}, + request_user_id="11111111-2222-3333-4444-555555555555", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555" + + def test_unauthenticated_request_does_not_inject(self): + """If request.state.user is missing (impossible under fail-closed auth, but + verify defensively), inject must not write user_id and runtime_ctx must + therefore lack it — forcing the tool fallback path to reveal itself.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent"}, + request_user_id=None, + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert "user_id" not in runtime_ctx + + +# --------------------------------------------------------------------------- +# L4-L7: Real LangGraph create_agent driving the real setup_agent tool +# --------------------------------------------------------------------------- + + +def _build_real_bootstrap_graph(authenticated_user_id: str): + """Construct a real LangGraph using create_agent + the real setup_agent tool. + + The LLM is faked (FakeMessagesListChatModel) so we don't need an API key. + Everything else — ToolNode dispatch, runtime injection, middleware — is + the real production code path. + """ + from langchain.agents import create_agent + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + # First model turn: emit a tool_call for setup_agent + # Second model turn (after tool result): final answer (terminates the loop) + fake_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": { + "soul": "# My E2E Agent\n\nA SOUL written by the model.", + "description": "End-to-end test agent", + }, + "id": "call_setup_1", + "type": "tool_call", + } + ], + ), + AIMessage(content=f"Done. Agent created for user {authenticated_user_id}."), + ] + ) + + graph = create_agent( + model=fake_model, + tools=[setup_agent], + system_prompt="You are a bootstrap agent. Call setup_agent immediately.", + ) + return graph + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_real_graph_real_setup_agent_writes_to_authenticated_user_dir(tmp_path: Path): + """The smoking-gun test for issue #2862. + + Under no_auto_user (contextvar = empty), if user_id propagation through + runtime.context is broken, setup_agent will fall back to DEFAULT_USER_ID + and write to users/default/agents/... The assertion that this directory + DOES NOT exist is what makes this test load-bearing. + """ + from langgraph.runtime import Runtime + + auth_uid = "abcdef01-2345-6789-abcd-ef0123456789" + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "e2e-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-1", + ) + + # Replay worker.run_agent's runtime construction. This is the key step: + # it is what makes ToolRuntime.context contain user_id when the tool + # actually fires. + runtime_ctx = _build_runtime_context("thread-e2e-1", "run-1", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_real_bootstrap_graph(auth_uid) + + # Patch get_paths only (the file-system rooting); everything else is real + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Drive the real graph. This goes through real ToolNode + real Runtime merge. + final_state = await graph.ainvoke( + {"messages": [HumanMessage(content="Create an agent named e2e-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "e2e-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "e2e-agent" + + # Load-bearing assertions: + assert expected_dir.exists(), f"Agent directory not found at the authenticated user's path. Expected: {expected_dir}. tmp_path tree: {[str(p) for p in tmp_path.rglob('*')]}" + assert (expected_dir / "SOUL.md").read_text() == "# My E2E Agent\n\nA SOUL written by the model." + assert (expected_dir / "config.yaml").exists() + assert not default_dir.exists(), "REGRESSION: agent landed under users/default/. user_id propagation broke somewhere between HTTP layer and ToolRuntime.context." + + # And final state should reflect tool success + last = final_state["messages"][-1] + assert "Done" in (last.content if isinstance(last.content, str) else str(last.content)) + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_inject_failure_falls_back_to_default_proving_test_is_load_bearing(tmp_path: Path): + """Negative control: if inject does NOT happen (no user in request), and + contextvar is empty (no_auto_user), setup_agent must land in default/. + + This proves the positive test is actually load-bearing — i.e. it would + have failed before PR #2784, not passed accidentally. + """ + from langgraph.runtime import Runtime + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "fallback-agent", "is_bootstrap": True}, + request_user_id=None, # no auth — inject is a no-op + thread_id="thread-e2e-2", + ) + + runtime_ctx = _build_runtime_context("thread-e2e-2", "run-2", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_real_bootstrap_graph("does-not-matter") + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + await graph.ainvoke( + {"messages": [HumanMessage(content="Create fallback-agent")]}, + config=config, + ) + + default_dir = tmp_path / "users" / "default" / "agents" / "fallback-agent" + assert default_dir.exists(), "Negative control failed: even without inject + contextvar, agent did not land in default/. The test infrastructure may not be reproducing the bug condition." + + +# --------------------------------------------------------------------------- +# L5: Sub-graph runtime propagation (the task tool case) +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_subgraph_invocation_preserves_user_id_in_runtime(tmp_path: Path): + """When a parent graph invokes a child graph (the pattern used by + subagents), parent_runtime.merge() must keep user_id intact. + + We construct a child graph that contains setup_agent and call it from + a parent graph's tool. If LangGraph re-creates the Runtime and drops + user_id at the sub-graph boundary, this fails. + """ + from langchain.agents import create_agent + from langgraph.runtime import Runtime + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + auth_uid = "deadbeef-0000-1111-2222-333344445555" + + # Inner graph: same as the bootstrap flow + inner_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": {"soul": "# Inner", "description": "subgraph"}, + "id": "call_inner_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="inner done"), + ] + ) + inner_graph = create_agent( + model=inner_model, + tools=[setup_agent], + system_prompt="inner", + ) + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "subgraph-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-3", + ) + runtime_ctx = _build_runtime_context("thread-e2e-3", "run-3", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Direct sub-graph invoke (mimics what a subagent invocation looks like + # — distinct ainvoke call, but parent config carries the same runtime). + await inner_graph.ainvoke( + {"messages": [HumanMessage(content="Create subgraph-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "subgraph-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "subgraph-agent" + assert expected_dir.exists() + assert not default_dir.exists() + + +# --------------------------------------------------------------------------- +# L6: Sync tool path through ContextThreadPoolExecutor +# --------------------------------------------------------------------------- + + +def test_sync_tool_dispatch_through_thread_pool_uses_runtime_context(tmp_path: Path): + """setup_agent is a sync function. When dispatched through ToolNode's + ContextThreadPoolExecutor, runtime.context must still carry user_id — + not via thread-local copy_context (which only carries contextvars), but + because it was passed in as the ToolRuntime constructor argument. + """ + from langchain.agents import create_agent + from langgraph.runtime import Runtime + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + auth_uid = "11112222-3333-4444-5555-666677778888" + + fake_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": {"soul": "# Sync", "description": "sync path"}, + "id": "call_sync_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="sync done"), + ] + ) + graph = create_agent(model=fake_model, tools=[setup_agent], system_prompt="sync") + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "sync-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-4", + ) + runtime_ctx = _build_runtime_context("thread-e2e-4", "run-4", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Use SYNC invoke to hit the ContextThreadPoolExecutor path + graph.invoke( + {"messages": [HumanMessage(content="Create sync-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "sync-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "sync-agent" + assert expected_dir.exists() + assert not default_dir.exists() diff --git a/backend/tests/test_setup_agent_http_e2e_real_server.py b/backend/tests/test_setup_agent_http_e2e_real_server.py new file mode 100644 index 000000000..950d040a0 --- /dev/null +++ b/backend/tests/test_setup_agent_http_e2e_real_server.py @@ -0,0 +1,326 @@ +"""Real HTTP end-to-end verification for issue #2862's setup_agent path. + +This test drives the **entire** FastAPI gateway through ``starlette.testclient.TestClient``: + + starlette.testclient.TestClient (real ASGI stack) + -> AuthMiddleware (real cookie parsing, real JWT decode) + -> /api/v1/auth/register endpoint (real password hash + sqlite write) + -> /api/threads/{id}/runs/stream endpoint (real start_run config-assembly) + -> background asyncio.create_task(run_agent) (real worker, real Runtime) + -> langchain.agents.create_agent graph (real, with fake LLM) + -> ToolNode dispatch (real) + -> setup_agent tool (real file I/O) + +The only mock is the LLM (no API key needed). Every layer that participates +in ``user_id`` propagation — auth, ContextVar, ``inject_authenticated_user_context``, +``worker._build_runtime_context``, ``Runtime.merge`` — is the real production +code path. If the chain is broken at any layer, this test fails. + +This is what "真实验证" looks like for a server that lives behind authentication: +register a user, log in (cookie), POST to /runs/stream, wait for the run to +finish, then read the filesystem. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model + + +def _build_fake_create_chat_model(agent_name: str): + """Return a callable matching the real ``create_chat_model`` signature. + + Whenever the lead agent constructs a chat model during the bootstrap flow, + we hand it a fake that emits a single setup_agent tool_call on its first + turn, then a benign final answer on its second turn. + """ + + def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel: + return build_single_tool_call_model( + tool_name="setup_agent", + tool_args={ + "soul": f"# Real HTTP E2E SOUL for {agent_name}", + "description": "real-http-e2e agent", + }, + tool_call_id="call_real_http_1", + final_text=f"Agent {agent_name} created via real HTTP e2e.", + ) + + return fake_create_chat_model + + +@pytest.fixture +def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Stand up an isolated DeerFlow data root + config under tmp_path. + + - Sets ``DEER_FLOW_HOME`` so paths land under tmp_path, not the real + ``.deer-flow`` directory. + - Stages a copy of the project's ``config.yaml`` (or ``config.example.yaml`` + on a fresh CI checkout where ``config.yaml`` is gitignored) and pins + ``DEER_FLOW_CONFIG_PATH`` to it, so lifespan boot doesn't depend on the + developer's local config layout. + - Sets a placeholder OPENAI_API_KEY because the config has + ``$OPENAI_API_KEY`` that gets resolved at parse time; the LLM itself is + mocked, so any non-empty value works. + """ + home = tmp_path / "deer-flow-home" + home.mkdir() + monkeypatch.setenv("DEER_FLOW_HOME", str(home)) + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used-because-llm-is-mocked") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + # Hermetic config: do not depend on whether the dev machine has a real + # ``config.yaml`` at the repo root. CI's ``actions/checkout`` only ships + # ``config.example.yaml`` (and its ``models:`` list is commented out, so + # AppConfig validation would reject it). Write a minimal, self-sufficient + # config to tmp_path and pin ``DEER_FLOW_CONFIG_PATH`` to it. + staged_config = tmp_path / "config.yaml" + staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config)) + + return home + + +# Minimal config that satisfies AppConfig + LeadAgent's _resolve_model_name. +# The model `use` path must resolve to a real class for config parsing to +# succeed; the test patches ``create_chat_model`` on the lead agent module, +# so the model is never actually instantiated. SandboxConfig.use is required +# at schema level; LocalSandboxProvider is the only sandbox that runs without +# Docker. +_MINIMAL_CONFIG_YAML = """\ +log_level: info +models: + - name: fake-test-model + display_name: Fake Test Model + use: langchain_openai:ChatOpenAI + model: gpt-4o-mini + api_key: $OPENAI_API_KEY + base_url: $OPENAI_API_BASE +sandbox: + use: deerflow.sandbox.local:LocalSandboxProvider +agents_api: + enabled: true +database: + backend: sqlite +""" + + +def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None: + """Reset every process-wide cache that would survive across tests. + + This fixture stands up a full FastAPI app + sqlite DB + LangGraph runtime + inside ``tmp_path``. To get true per-test isolation we have to invalidate + a handful of module-level caches that production normally never resets, + so they pick up our test-only ``DEER_FLOW_HOME`` and sqlite path: + + - ``deerflow.config.app_config`` caches the parsed ``config.yaml``. + - ``deerflow.config.paths`` caches the ``Paths`` singleton derived from + ``DEER_FLOW_HOME`` at first access. + - ``deerflow.persistence.engine`` caches the SQLAlchemy engine and + session factory after the first call to ``init_engine_from_config``. + + ``raising=False`` keeps the fixture resilient if upstream renames or + drops one of these attributes — the test will simply skip that reset + instead of failing with a confusing AttributeError, and the next test + to call ``get_app_config()``/``get_paths()`` will surface the real + incompatibility loudly. + """ + from deerflow.config import app_config as app_config_module + from deerflow.config import paths as paths_module + from deerflow.persistence import engine as engine_module + + for module, attr in ( + (app_config_module, "_app_config"), + (app_config_module, "_app_config_path"), + (app_config_module, "_app_config_mtime"), + (paths_module, "_paths_singleton"), + (engine_module, "_engine"), + (engine_module, "_session_factory"), + ): + monkeypatch.setattr(module, attr, None, raising=False) + + +@pytest.fixture +def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch): + """Build a fresh FastAPI app inside a clean DEER_FLOW_HOME. + + Each test gets its own sqlite DB and checkpoint store under ``tmp_path``, + with no cross-test contamination. + """ + _reset_process_singletons(monkeypatch) + + # Re-resolve the config from the test-only DEER_FLOW_HOME and pin its + # sqlite path into tmp_path so the lifespan-time engine init lands there. + from deerflow.config import app_config as app_config_module + + cfg = app_config_module.get_app_config() + cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db") + + from app.gateway.app import create_app + + return create_app() + + +def _drain_stream(response, *, timeout: float = 30.0, max_bytes: int = 4 * 1024 * 1024) -> str: + """Consume an SSE response body until the run terminates and return the text. + + Bounded to keep the test fail-fast: + - Stops as soon as an ``event: end`` SSE frame is observed (the gateway + sends this when the background run finishes — see ``services.format_sse`` + and ``StreamBridge.publish_end``). + - Stops at ``timeout`` seconds wall-clock so a stuck run / runaway heartbeat + loop surfaces a real failure instead of hanging pytest. + - Stops at ``max_bytes`` so a runaway producer can't OOM the test process. + """ + import time as _time + + deadline = _time.monotonic() + timeout + body = b"" + for chunk in response.iter_bytes(): + body += chunk + if b"event: end" in body: + break + if len(body) >= max_bytes: + break + if _time.monotonic() >= deadline: + break + return body.decode("utf-8", errors="replace") + + +def _wait_for_file(path: Path, *, timeout: float = 10.0) -> bool: + """Block until *path* exists or *timeout* elapses. + + The run completes inside ``asyncio.create_task`` after start_run returns, + so the test must wait for the background task to flush its writes. + """ + import time as _time + + deadline = _time.monotonic() + timeout + while _time.monotonic() < deadline: + if path.exists(): + return True + _time.sleep(0.05) + return False + + +@pytest.mark.no_auto_user +def test_real_http_create_agent_lands_in_authenticated_user_dir( + isolated_app: Any, + isolated_deer_flow_home: Path, + monkeypatch: pytest.MonkeyPatch, +): + """The full real-server contract test. + + 1. Register a real user via POST /api/v1/auth/register (also auto-logs in) + 2. POST to /api/threads/{tid}/runs/stream with the **exact** body shape the + frontend (LangGraph SDK) sends during the bootstrap flow. + 3. Wait for the background run to finish. + 4. Assert SOUL.md exists under users//agents//. + 5. Assert NOTHING exists under users/default/agents//. + """ + # ``deerflow.agents.lead_agent.agent`` imports ``create_chat_model`` with + # ``from deerflow.models import create_chat_model`` at module load time, + # rebinding the symbol into its own namespace. So the only patch that + # intercepts the call is the bound name on ``lead_agent.agent`` — patching + # ``deerflow.models.create_chat_model`` would be too late. + agent_name = "real-http-agent" + + from starlette.testclient import TestClient + + with ( + patch( + "deerflow.agents.lead_agent.agent.create_chat_model", + new=_build_fake_create_chat_model(agent_name), + ), + TestClient(isolated_app) as client, + ): + # --- 1. Register & auto-login --- + register = client.post( + "/api/v1/auth/register", + json={"email": "e2e-user@example.com", "password": "very-strong-password-123"}, + ) + assert register.status_code == 201, register.text + registered = register.json() + auth_uid = registered["id"] + # The endpoint sets both access_token (auth) and csrf_token (CSRF Double + # Submit Cookie) cookies; the TestClient cookie jar propagates them. + assert client.cookies.get("access_token"), "register endpoint must set session cookie" + csrf_token = client.cookies.get("csrf_token") + assert csrf_token, "register endpoint must set csrf_token cookie" + + # --- 2. Create a thread (require_existing=True on /runs/stream means + # we must call POST /api/threads first; the React frontend does the + # same via the LangGraph SDK's threads.create) --- + import uuid as _uuid + + thread_id = str(_uuid.uuid4()) + created = client.post( + "/api/threads", + json={"thread_id": thread_id, "metadata": {}}, + headers={"X-CSRF-Token": csrf_token}, + ) + assert created.status_code == 200, created.text + + # --- 3. POST /runs/stream with the bootstrap wire format --- + # This is the EXACT shape the React frontend sends after PR #2784: + # thread.submit(input, {config, context}) -> + # POST /api/threads/{id}/runs/stream body = + # {assistant_id, input, config, context} + body = { + "assistant_id": "lead_agent", + "input": { + "messages": [ + { + "role": "user", + "content": (f"The new custom agent name is {agent_name}. Help me design its SOUL.md before saving it."), + } + ] + }, + "config": {"recursion_limit": 50}, + "context": { + "agent_name": agent_name, + "is_bootstrap": True, + "mode": "flash", + "thinking_enabled": False, + "is_plan_mode": False, + "subagent_enabled": False, + }, + "stream_mode": ["values"], + } + # The /stream endpoint returns SSE; we drain it so the server-side + # background task (run_agent) gets to completion before we look at disk. + with client.stream( + "POST", + f"/api/threads/{thread_id}/runs/stream", + json=body, + headers={"X-CSRF-Token": csrf_token}, + ) as resp: + assert resp.status_code == 200, resp.read().decode() + transcript = _drain_stream(resp) + + # Sanity: the stream should have produced at least one event + assert "event:" in transcript, f"no SSE events in response: {transcript[:500]!r}" + + # --- 4. Verify filesystem outcome --- + expected_dir = isolated_deer_flow_home / "users" / auth_uid / "agents" / agent_name + default_dir = isolated_deer_flow_home / "users" / "default" / "agents" / agent_name + + # The setup_agent tool runs inside the background asyncio task spawned + # by start_run; SSE-drain typically waits for it, but we add a bounded + # poll to be robust against scheduler jitter. + assert _wait_for_file(expected_dir / "SOUL.md", timeout=15.0), ( + "SOUL.md did not appear under users//agents/. " + f"Expected: {expected_dir / 'SOUL.md'}. " + f"tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}. " + f"SSE transcript tail: {transcript[-1000:]!r}" + ) + + soul_text = (expected_dir / "SOUL.md").read_text() + assert agent_name in soul_text, f"unexpected SOUL content: {soul_text!r}" + + # The smoking-gun assertion: the agent must NOT have landed in default/ + assert not default_dir.exists(), f"REGRESSION: agent landed under users/default/{agent_name} instead of the authenticated user. Default-dir contents: {list(default_dir.rglob('*')) if default_dir.exists() else 'n/a'}" diff --git a/backend/tests/test_update_agent_e2e_user_isolation.py b/backend/tests/test_update_agent_e2e_user_isolation.py new file mode 100644 index 000000000..7fa725352 --- /dev/null +++ b/backend/tests/test_update_agent_e2e_user_isolation.py @@ -0,0 +1,253 @@ +"""End-to-end verification for update_agent's user_id resolution. + +PR #2784 hardened setup_agent to prefer runtime.context["user_id"] over the +contextvar. update_agent had the same latent gap: it unconditionally called +get_effective_user_id() at module level, so any scenario where the contextvar +was unavailable while runtime.context carried user_id (a background task +scheduled outside the request task, a worker pool that doesn't copy_context, +checkpoint resume on a different task) would silently route writes to +users/default/agents/... + +These tests are load-bearing under @no_auto_user (contextvar empty): + +- The negative-control test confirms the fixture actually puts the tool in + the regime where the contextvar fallback would land in users/default/. + Without that, the positive test would be vacuously satisfied. +- The positive test verifies update_agent honours runtime.context["user_id"] + injected by inject_authenticated_user_context in the gateway. Before the + fix in this PR, this test failed; now it passes. +""" + +from __future__ import annotations + +from contextlib import ExitStack +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +import yaml +from _agent_e2e_helpers import build_single_tool_call_model +from langchain_core.messages import HumanMessage + +from app.gateway.services import ( + build_run_config, + inject_authenticated_user_context, + merge_run_context_overrides, +) +from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context + + +def _make_request(user_id_str: str | None) -> SimpleNamespace: + user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") if user_id_str else None + return SimpleNamespace(state=SimpleNamespace(user=user)) + + +def _assemble_config(*, body_context: dict | None, request_user_id: str | None, thread_id: str) -> dict: + config = build_run_config(thread_id, {"recursion_limit": 50}, None, assistant_id="lead_agent") + merge_run_context_overrides(config, body_context) + inject_authenticated_user_context(config, _make_request(request_user_id)) + return config + + +def _seed_existing_agent(tmp_path: Path, user_id: str, agent_name: str, soul: str = "# Original"): + """Pre-create an agent on disk for update_agent to overwrite.""" + agent_dir = tmp_path / "users" / user_id / "agents" / agent_name + agent_dir.mkdir(parents=True, exist_ok=True) + (agent_dir / "config.yaml").write_text( + yaml.dump({"name": agent_name, "description": "old"}, allow_unicode=True), + encoding="utf-8", + ) + (agent_dir / "SOUL.md").write_text(soul, encoding="utf-8") + return agent_dir + + +def _make_paths_mock(tmp_path: Path): + paths = MagicMock() + paths.base_dir = tmp_path + paths.agent_dir = lambda name: tmp_path / "agents" / name + paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name + return paths + + +def _patch_update_agent_dependencies(tmp_path: Path): + """update_agent reads load_agent_config + get_app_config — stub them + minimally so the tool can run without a real config file or LLM.""" + fake_model_cfg = SimpleNamespace(name="fake-model") + fake_app_cfg = MagicMock() + fake_app_cfg.get_model_config = lambda name: fake_model_cfg if name == "fake-model" else None + + return [ + patch( + "deerflow.tools.builtins.update_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ), + patch( + "deerflow.tools.builtins.update_agent_tool.get_app_config", + return_value=fake_app_cfg, + ), + # load_agent_config (used by update_agent to read existing config) also + # reads paths via its own module-level get_paths reference. Patch it too + # or the tool returns "Agent does not exist" before touching disk. + patch( + "deerflow.config.agents_config.get_paths", + return_value=_make_paths_mock(tmp_path), + ), + ] + + +def _build_update_graph(*, soul_payload: str): + from langchain.agents import create_agent + + from deerflow.tools.builtins.update_agent_tool import update_agent + + fake_model = build_single_tool_call_model( + tool_name="update_agent", + tool_args={"soul": soul_payload, "description": "refined"}, + tool_call_id="call_update_1", + final_text="updated", + ) + return create_agent(model=fake_model, tools=[update_agent], system_prompt="updater") + + +# --------------------------------------------------------------------------- +# Negative control — proves the test environment puts update_agent in the +# regime where the contextvar fallback would land in default/. +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +def test_update_agent_falls_back_to_default_when_no_inject_and_no_contextvar(tmp_path: Path): + """No request.state.user, no contextvar — update_agent must look in + users/default/agents/. We seed the file there so the tool succeeds and + we know which directory it actually consulted.""" + from langgraph.runtime import Runtime + + _seed_existing_agent(tmp_path, "default", "fallback-target") + + config = _assemble_config( + body_context={"agent_name": "fallback-target"}, + request_user_id=None, # no auth, inject is no-op + thread_id="thread-update-1", + ) + runtime_ctx = _build_runtime_context("thread-update-1", "run-1", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# Fallback Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + graph.invoke( + {"messages": [HumanMessage(content="update fallback-target")]}, + config=config, + ) + + soul = (tmp_path / "users" / "default" / "agents" / "fallback-target" / "SOUL.md").read_text() + assert soul == "# Fallback Updated", "Sanity: tool should have written under default/" + + +# --------------------------------------------------------------------------- +# Regression guard — passes on this branch, would fail on main before the fix. +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +def test_update_agent_should_use_runtime_context_user_id_when_contextvar_missing(tmp_path: Path): + """update_agent prefers the authenticated user_id carried in + runtime.context (placed there by inject_authenticated_user_context) + over the contextvar — same contract as setup_agent (PR #2784). + + Before this PR's fix, update_agent unconditionally called + get_effective_user_id() and landed in default/ whenever the contextvar + was unavailable. This test pins the corrected behaviour. + """ + from langgraph.runtime import Runtime + + auth_uid = "abcdef01-2345-6789-abcd-ef0123456789" + + # Seed the agent in BOTH locations so we can prove which one was opened. + auth_dir = _seed_existing_agent(tmp_path, auth_uid, "shared-name", soul="# Auth Original") + default_dir = _seed_existing_agent(tmp_path, "default", "shared-name", soul="# Default Original") + + config = _assemble_config( + body_context={"agent_name": "shared-name"}, + request_user_id=auth_uid, + thread_id="thread-update-2", + ) + runtime_ctx = _build_runtime_context("thread-update-2", "run-2", config.get("context"), None) + assert runtime_ctx["user_id"] == auth_uid, "Pre-condition: inject must have placed user_id into runtime_ctx" + + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# Auth Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + graph.invoke( + {"messages": [HumanMessage(content="update shared-name")]}, + config=config, + ) + + auth_soul = (auth_dir / "SOUL.md").read_text() + default_soul = (default_dir / "SOUL.md").read_text() + + assert auth_soul == "# Auth Updated", f"REGRESSION: update_agent ignored runtime.context['user_id']={auth_uid!r} and routed the write to users/default/ instead. auth_soul={auth_soul!r}, default_soul={default_soul!r}" + assert default_soul == "# Default Original", "REGRESSION: update_agent corrupted the shared default-user agent. It should have written under the authenticated user's path." + + +# --------------------------------------------------------------------------- +# Positive — when contextvar IS the auth user (the normal HTTP case), things +# already work. Pin it as a regression guard so future refactors don't +# accidentally break the contextvar path in pursuit of the runtime-context fix. +# --------------------------------------------------------------------------- + + +def test_update_agent_uses_contextvar_when_present(tmp_path: Path, monkeypatch): + """The normal HTTP case: contextvar is set by auth_middleware. This must + keep working regardless of how runtime.context is populated.""" + from types import SimpleNamespace as _SN + + from deerflow.runtime.user_context import reset_current_user, set_current_user + + auth_uid = "11112222-3333-4444-5555-666677778888" + user = _SN(id=auth_uid, email="ctxvar@local") + + _seed_existing_agent(tmp_path, auth_uid, "ctxvar-agent", soul="# Original") + + from langgraph.runtime import Runtime + + config = _assemble_config( + body_context={"agent_name": "ctxvar-agent"}, + request_user_id=auth_uid, + thread_id="thread-update-3", + ) + runtime_ctx = _build_runtime_context("thread-update-3", "run-3", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# CtxVar Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + token = set_current_user(user) + try: + final = graph.invoke( + {"messages": [HumanMessage(content="update ctxvar-agent")]}, + config=config, + ) + finally: + reset_current_user(token) + + # surface the tool's reply for debug if it errored + tool_replies = [m.content for m in final["messages"] if getattr(m, "type", "") == "tool"] + soul = (tmp_path / "users" / auth_uid / "agents" / "ctxvar-agent" / "SOUL.md").read_text() + assert soul == "# CtxVar Updated", f"tool replies: {tool_replies}" From e9deb6c2f203d633b88578e6400c7fab4466ad86 Mon Sep 17 00:00:00 2001 From: He Wang Date: Tue, 12 May 2026 23:21:22 +0800 Subject: [PATCH 05/12] perf(harness): push thread metadata filters into SQL (#2865) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf(harness): push thread metadata filters into SQL Replace Python-side metadata filtering (5x overfetch + in-memory match) with database-side json_extract predicates so LIMIT/OFFSET pagination is exact regardless of match density. Co-Authored-By: Claude Opus 4 * fix(harness): add dialect-aware JsonMatch compiler for type-safe metadata SQL filters Replace SQLAlchemy JSON index/comparator APIs with a custom JsonMatch ColumnElement that compiles to json_type/json_extract on SQLite and jsonb_typeof/->>/-> on PostgreSQL. Tighten key validation regex to single-segment identifiers, handle None/bool/numeric value types with json_type-based discrimination, and strengthen test coverage for edge cases and discriminability. Co-Authored-By: Claude Opus 4 * fix(harness): address Copilot review comments on JSON metadata filters - Use json_typeof instead of jsonb_typeof in PostgreSQL compiler; the metadata_json column is JSON not JSONB so jsonb_typeof would error at runtime on any PostgreSQL backend - Align _is_safe_json_key with json_match's _KEY_CHARSET_RE so keys containing hyphens or leading digits are not silently skipped - Add thread_id as secondary ORDER BY in search() to make pagination deterministic when updated_at values collide; remove asyncio.sleep from the pagination regression test Co-Authored-By: Claude Sonnet 4 * fix(harness): address remaining review comments on metadata SQL filters - Remove _is_safe_json_key() and reuse json_match ValueError to avoid validator drift (Copilot #3217603895, #3217411616) - Raise ValueError when all metadata keys are rejected so callers never get silent unfiltered results (WillemJiang) - Fix integer precision: split int/float branches, bind int as Integer() with INTEGER/BIGINT CAST instead of float() coercion (Copilot #3217603972) - Fix jsonb_typeof -> json_typeof on JSON column (Copilot #3217411579) - Replace manual _cleanup() calls with async yield fixture so teardown always runs (Copilot #3217604019) - Remove asyncio.sleep(0.01) pagination ordering; use thread_id secondary sort instead (Copilot #3217411636) - Add type annotations to _bind/_build_clause/_compile_* and remove EOL comments from _Dialect fields (coding.mdc) - Expand test coverage: boolean/null/mixed-type/large-int precision, partial unsafe-key skip with caplog assertion Co-Authored-By: Claude Sonnet 4.6 * fix(harness): address third-round Copilot review comments on JsonMatch - Reject unsupported value types (list, dict, ...) in JsonMatch.__init__ with TypeError so inherit_cache=True never receives an unhashable value and callers get an explicit error instead of silent str() coercion (Copilot #3217933201) - Upgrade int bindparam from Integer() to BigInteger() to align with BIGINT CAST and avoid overflow on large integers (Copilot #3217933252) - Catch TypeError alongside ValueError in search() so non-string metadata keys are warned and skipped rather than raising unexpectedly (Copilot #3217933300) - Add three tests: json_match rejects unsupported value types, search() warns and raises on non-string key, search() warns and raises on unsupported value type Co-Authored-By: Claude Sonnet 4.6 * fix(harness): address fourth-round Copilot review comments on JsonMatch - Add CASE WHEN guard for PostgreSQL integer matching: json_typeof returns 'number' for both ints and floats; wrap CAST in CASE with regex guard '^-?[0-9]+$' so float rows never trigger CAST error (Copilot #3218413860) - Validate isinstance(key, str) before regex match in JsonMatch.__init__ so non-string keys raise ValueError consistently instead of TypeError from re.match (Copilot #3218413900) - Include exception message in metadata filter skip warning so callers can distinguish invalid key from unsupported value type (Copilot #3218413924) - Update tests: assert CASE WHEN guard in PG int compilation, cover non-string key ValueError in test_json_match_rejects_unsafe_key Co-Authored-By: Claude Sonnet 4.6 * fix(harness): align ThreadMetaStore.search() signature with sql.py implementation Use `dict[str, Any]` for `metadata` and `list[dict[str, Any]]` as return type in base class and MemoryThreadMetaStore to resolve an LSP signature mismatch; also correct a test docstring that cited the wrong exception type. Co-Authored-By: Claude Sonnet 4.6 * fix(harness): surface InvalidMetadataFilterError as HTTP 400 in search endpoint Replace bare ValueError with a domain-specific InvalidMetadataFilterError (subclass of ValueError) so the Gateway handler can catch it and return HTTP 400 instead of letting it bubble up as a 500. Co-Authored-By: Claude Opus 4 * fix(harness): sanitize metadata keys in log output to prevent log injection Use ascii() instead of %r to escape control characters in client-supplied metadata keys before logging, preventing multiline/forged log entries. Co-Authored-By: Claude Opus 4 * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(harness): validate metadata filters at API boundary and dedupe key/value rules - Add Pydantic ``field_validator`` on ``ThreadSearchRequest.metadata`` so unsafe keys / unsupported value types are rejected with HTTP 422 from both SQL and memory backends (closes Copilot review 3218830849). - Export ``validate_metadata_filter_key`` / ``validate_metadata_filter_value`` (and ``ALLOWED_FILTER_VALUE_TYPES``) from ``json_compat`` and have ``JsonMatch.__init__`` reuse them — the Gateway-side validator and the SQL-side ``JsonMatch`` constructor now share one admission rule and cannot drift. - Format ``InvalidMetadataFilterError`` rejected-keys list as a comma-separated plain string instead of a Python list repr so the surfaced HTTP 400 detail is readable (closes Copilot review 3218830899). - Update router tests to cover both 422 boundary paths plus the 400 defense-in-depth path when a backend still raises the error. Co-authored-by: Cursor * fix(harness): harden JsonMatch compile-time key validation against __init__ bypass Co-Authored-By: Claude Sonnet 4 * fix: address review feedback on metadata filter SQL push-down - Add signed 64-bit range check to validate_metadata_filter_value; give out-of-range ints a distinct TypeError message. - Replace assert guards in _compile_sqlite/_compile_pg with explicit if/raise so they survive python -O optimisation. Co-Authored-By: Claude Sonnet 4 --------- Co-authored-by: Claude Opus 4 Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Cursor --- backend/app/gateway/routers/threads.py | 38 +- .../deerflow/persistence/json_compat.py | 195 +++++++ .../persistence/thread_meta/__init__.py | 3 +- .../deerflow/persistence/thread_meta/base.py | 9 +- .../persistence/thread_meta/memory.py | 4 +- .../deerflow/persistence/thread_meta/sql.py | 46 +- backend/tests/test_thread_meta_repo.py | 504 +++++++++++++++--- backend/tests/test_threads_router.py | 54 ++ 8 files changed, 757 insertions(+), 96 deletions(-) create mode 100644 backend/packages/harness/deerflow/persistence/json_compat.py diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index cb048152e..e6f4fa2ae 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -90,6 +90,28 @@ class ThreadSearchRequest(BaseModel): offset: int = Field(default=0, ge=0, description="Pagination offset") status: str | None = Field(default=None, description="Filter by thread status") + @field_validator("metadata") + @classmethod + def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]: + """Reject filter entries the SQL backend cannot compile. + + Enforces consistent behaviour across SQL and memory backends. + See ``deerflow.persistence.json_compat`` for the shared validators. + """ + if not v: + return v + from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value + + bad_entries: list[str] = [] + for key, value in v.items(): + if not validate_metadata_filter_key(key): + bad_entries.append(f"{key!r} (unsafe key)") + elif not validate_metadata_filter_value(value): + bad_entries.append(f"{key!r} (unsupported value type {type(value).__name__})") + if bad_entries: + raise ValueError(f"Invalid metadata filter entries: {', '.join(bad_entries)}") + return v + class ThreadStateResponse(BaseModel): """Response model for thread state.""" @@ -294,14 +316,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th (SQL-backed for sqlite/postgres, Store-backed for memory mode). """ from app.gateway.deps import get_thread_store + from deerflow.persistence.thread_meta import InvalidMetadataFilterError repo = get_thread_store(request) - rows = await repo.search( - metadata=body.metadata or None, - status=body.status, - limit=body.limit, - offset=body.offset, - ) + try: + rows = await repo.search( + metadata=body.metadata or None, + status=body.status, + limit=body.limit, + offset=body.offset, + ) + except InvalidMetadataFilterError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc return [ ThreadResponse( thread_id=r["thread_id"], diff --git a/backend/packages/harness/deerflow/persistence/json_compat.py b/backend/packages/harness/deerflow/persistence/json_compat.py new file mode 100644 index 000000000..442b29e22 --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/json_compat.py @@ -0,0 +1,195 @@ +"""Dialect-aware JSON value matching for SQLAlchemy (SQLite + PostgreSQL).""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import BigInteger, Float, String, bindparam +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.expression import ColumnElement +from sqlalchemy.sql.visitors import InternalTraversal +from sqlalchemy.types import Boolean, TypeEngine + +# Key is interpolated into compiled SQL; restrict charset to prevent injection. +_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$") + +# Allowed value types for metadata filter values (same set accepted by JsonMatch). +ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str) + +# SQLite raises an overflow when binding values outside signed 64-bit range; +# PostgreSQL overflows during BIGINT cast. Reject at validation time instead. +_INT64_MIN = -(2**63) +_INT64_MAX = 2**63 - 1 + + +def validate_metadata_filter_key(key: object) -> bool: + """Return True if *key* is safe for use as a JSON metadata filter key. + + A key is "safe" when it is a string matching ``[A-Za-z0-9_-]+``. The + charset is restricted because the key is interpolated into the + compiled SQL path expression (``$.""`` / ``->`` literal), so any + laxer pattern would open a SQL/JSONPath injection surface. + """ + return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key)) + + +def validate_metadata_filter_value(value: object) -> bool: + """Return True if *value* is an allowed type for a JSON metadata filter. + + Matches the set of types ``_build_clause`` knows how to compile into + a dialect-portable predicate. Anything else (list/dict/bytes/...) is + intentionally rejected rather than silently coerced via ``str()`` — + silent coercion would (a) produce wrong matches and (b) break + SQLAlchemy's ``inherit_cache`` invariant when ``value`` is unhashable. + + Integer values are additionally restricted to the signed 64-bit range + ``[-2**63, 2**63 - 1]``: SQLite overflows when binding larger values + and PostgreSQL overflows during the ``BIGINT`` cast. + """ + if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES): + return False + if isinstance(value, int) and not isinstance(value, bool): + if not (_INT64_MIN <= value <= _INT64_MAX): + return False + return True + + +class JsonMatch(ColumnElement): + """Dialect-portable ``column[key] == value`` for JSON columns. + + Compiles to ``json_type``/``json_extract`` on SQLite and + ``json_typeof``/``->>`` on PostgreSQL, with type-safe comparison + that distinguishes bool vs int and NULL vs missing key. + + *key* must be a single literal key matching ``[A-Za-z0-9_-]+``. + *value* must be one of: ``None``, ``bool``, ``int`` (signed 64-bit), ``float``, ``str``. + """ + + inherit_cache = True + type = Boolean() + _is_implicitly_boolean = True + + _traverse_internals = [ + ("column", InternalTraversal.dp_clauseelement), + ("key", InternalTraversal.dp_string), + ("value", InternalTraversal.dp_plain_obj), + ] + + def __init__(self, column: ColumnElement, key: str, value: object) -> None: + if not validate_metadata_filter_key(key): + raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}") + if not validate_metadata_filter_value(value): + if isinstance(value, int) and not isinstance(value, bool): + raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}") + raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}") + self.column = column + self.key = key + self.value = value + super().__init__() + + +@dataclass(frozen=True) +class _Dialect: + """Per-dialect names used when emitting JSON type/value comparisons.""" + + null_type: str + num_types: tuple[str, ...] + num_cast: str + int_types: tuple[str, ...] + int_cast: str + # None for SQLite where json_type already returns 'integer'/'real'; + # regex literal for PostgreSQL where json_typeof returns 'number' for + # both ints and floats, so an extra guard prevents CAST errors on floats. + int_guard: str | None + string_type: str + bool_type: str | None + + +_SQLITE = _Dialect( + null_type="null", + num_types=("integer", "real"), + num_cast="REAL", + int_types=("integer",), + int_cast="INTEGER", + int_guard=None, + string_type="text", + bool_type=None, +) + +_PG = _Dialect( + null_type="null", + num_types=("number",), + num_cast="DOUBLE PRECISION", + int_types=("number",), + int_cast="BIGINT", + int_guard="'^-?[0-9]+$'", + string_type="string", + bool_type="boolean", +) + + +def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str: + param = bindparam(None, value, type_=sa_type) + return compiler.process(param, **kw) + + +def _type_check(typeof: str, types: tuple[str, ...]) -> str: + if len(types) == 1: + return f"{typeof} = '{types[0]}'" + quoted = ", ".join(f"'{t}'" for t in types) + return f"{typeof} IN ({quoted})" + + +def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str: + if value is None: + return f"{typeof} = '{dialect.null_type}'" + if isinstance(value, bool): + # bool check must precede int check — bool is a subclass of int in Python + bool_str = "true" if value else "false" + if dialect.bool_type is None: + return f"{typeof} = '{bool_str}'" + return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')" + if isinstance(value, int): + bp = _bind(compiler, value, BigInteger(), **kw) + if dialect.int_guard: + # CASE prevents CAST error when json_typeof = 'number' also matches floats + return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" + return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})" + if isinstance(value, float): + bp = _bind(compiler, value, Float(), **kw) + return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})" + bp = _bind(compiler, str(value), String(), **kw) + return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})" + + +@compiles(JsonMatch, "sqlite") +def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + path = f'$."{element.key}"' + typeof = f"json_type({col}, '{path}')" + extract = f"json_extract({col}, '{path}')" + return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw) + + +@compiles(JsonMatch, "postgresql") +def _compile_pg(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + typeof = f"json_typeof({col} -> '{element.key}')" + extract = f"({col} ->> '{element.key}')" + return _build_clause(compiler, typeof, extract, element.value, _PG, **kw) + + +@compiles(JsonMatch) +def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + raise NotImplementedError(f"JsonMatch supports only sqlite and postgresql; got dialect: {compiler.dialect.name}") + + +def json_match(column: ColumnElement, key: str, value: object) -> JsonMatch: + return JsonMatch(column, key, value) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py index 080ce8093..b5231f0f9 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.sql import ThreadMetaRepository @@ -14,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker __all__ = [ + "InvalidMetadataFilterError", "MemoryThreadMetaStore", "ThreadMetaRepository", "ThreadMetaRow", diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index c87c10a16..ed55ade8e 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -15,10 +15,15 @@ three-state semantics (see :mod:`deerflow.runtime.user_context`): from __future__ import annotations import abc +from typing import Any from deerflow.runtime.user_context import AUTO, _AutoSentinel +class InvalidMetadataFilterError(ValueError): + """Raised when all client-supplied metadata filter keys are rejected.""" + + class ThreadMetaStore(abc.ABC): @abc.abstractmethod async def create( @@ -40,12 +45,12 @@ class ThreadMetaStore(abc.ABC): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: pass @abc.abstractmethod diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index fbe66fdaf..4f642a938 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -69,12 +69,12 @@ class MemoryThreadMetaStore(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") filter_dict: dict[str, Any] = {} if metadata: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 688fbb247..0d3f587de 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -2,16 +2,20 @@ from __future__ import annotations +import logging from datetime import UTC, datetime from typing import Any from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.json_compat import json_match +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +logger = logging.getLogger(__name__) + class ThreadMetaRepository(ThreadMetaStore): def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: @@ -20,7 +24,7 @@ class ThreadMetaRepository(ThreadMetaStore): @staticmethod def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]: d = row.to_dict() - d["metadata"] = d.pop("metadata_json", {}) + d["metadata"] = d.pop("metadata_json", None) or {} for key in ("created_at", "updated_at"): val = d.get(key) if isinstance(val, datetime): @@ -104,39 +108,43 @@ class ThreadMetaRepository(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: """Search threads with optional metadata and status filters. Owner filter is enforced by default: caller must be in a user context. Pass ``user_id=None`` to bypass (migration/CLI). """ resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search") - stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) + stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc(), ThreadMetaRow.thread_id.desc()) if resolved_user_id is not None: stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id) if status: stmt = stmt.where(ThreadMetaRow.status == status) if metadata: - # When metadata filter is active, fetch a larger window and filter - # in Python. TODO(Phase 2): use JSON DB operators (Postgres @>, - # SQLite json_extract) for server-side filtering. - stmt = stmt.limit(limit * 5 + offset) - async with self._sf() as session: - result = await session.execute(stmt) - rows = [self._row_to_dict(r) for r in result.scalars()] - rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())] - return rows[offset : offset + limit] - else: - stmt = stmt.limit(limit).offset(offset) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] + applied = 0 + for key, value in metadata.items(): + try: + stmt = stmt.where(json_match(ThreadMetaRow.metadata_json, key, value)) + applied += 1 + except (ValueError, TypeError) as exc: + logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc) + if applied == 0: + # Comma-separated plain string (no list repr / nested + # quoting) so the 400 detail surfaced by the Gateway is + # easy for clients to read. Sorted for determinism. + rejected_keys = ", ".join(sorted(str(k) for k in metadata)) + raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}") + + stmt = stmt.limit(limit).offset(offset) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool: """Return True if the row exists and is owned (or filter bypassed).""" diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 3a6532567..1cef3752b 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -1,28 +1,25 @@ """Tests for ThreadMetaRepository (SQLAlchemy-backed).""" +import logging + import pytest -from deerflow.persistence.thread_meta import ThreadMetaRepository +from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository -async def _make_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine +@pytest.fixture +async def repo(tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return ThreadMetaRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - + yield ThreadMetaRepository(get_session_factory()) await close_engine() class TestThreadMetaRepository: @pytest.mark.anyio - async def test_create_and_get(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_and_get(self, repo): record = await repo.create("t1") assert record["thread_id"] == "t1" assert record["status"] == "idle" @@ -31,148 +28,523 @@ class TestThreadMetaRepository: fetched = await repo.get("t1") assert fetched is not None assert fetched["thread_id"] == "t1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_assistant_id(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_assistant_id(self, repo): record = await repo.create("t1", assistant_id="agent1") assert record["assistant_id"] == "agent1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_owner_and_display_name(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_owner_and_display_name(self, repo): record = await repo.create("t1", user_id="user1", display_name="My Thread") assert record["user_id"] == "user1" assert record["display_name"] == "My Thread" - await _cleanup() @pytest.mark.anyio - async def test_create_with_metadata(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_metadata(self, repo): record = await repo.create("t1", metadata={"key": "value"}) assert record["metadata"] == {"key": "value"} - await _cleanup() @pytest.mark.anyio - async def test_get_nonexistent(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_get_nonexistent(self, repo): assert await repo.get("nonexistent") is None - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_record_allows(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_record_allows(self, repo): assert await repo.check_access("unknown", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_matches(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_matches(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_mismatch(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_mismatch(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2") is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_owner_allows_all(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_owner_allows_all(self, repo): # Explicit user_id=None to bypass the new AUTO default that # would otherwise pick up the test user from the autouse fixture. await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_missing_row_denied(self, tmp_path): + async def test_check_access_strict_missing_row_denied(self, repo): """require_existing=True flips the missing-row case to *denied*. Closes the delete-idempotence cross-user gap: after a thread is deleted, the row is gone, and the permissive default would let any caller "claim" it as untracked. The strict mode demands a row. """ - repo = await _make_repo(tmp_path) assert await repo.check_access("never-existed", "user1", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_match_allowed(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_match_allowed(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_mismatch_denied(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): + async def test_check_access_strict_null_owner_still_allowed(self, repo): """Even in strict mode, a row with NULL user_id stays shared. The strict flag tightens the *missing row* case, not the *shared row* case — legacy pre-auth rows that survived a clean migration without an owner are still everyone's. """ - repo = await _make_repo(tmp_path) await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_update_status(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_status(self, repo): await repo.create("t1") await repo.update_status("t1", "busy") record = await repo.get("t1") assert record["status"] == "busy" - await _cleanup() @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete(self, repo): await repo.create("t1") await repo.delete("t1") assert await repo.get("t1") is None - await _cleanup() @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete_nonexistent_is_noop(self, repo): await repo.delete("nonexistent") # should not raise - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_merges(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_merges(self, repo): await repo.create("t1", metadata={"a": 1, "b": 2}) await repo.update_metadata("t1", {"b": 99, "c": 3}) record = await repo.get("t1") # Existing key preserved, overlapping key overwritten, new key added assert record["metadata"] == {"a": 1, "b": 99, "c": 3} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_on_empty(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_on_empty(self, repo): await repo.create("t1") await repo.update_metadata("t1", {"k": "v"}) record = await repo.get("t1") assert record["metadata"] == {"k": "v"} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_nonexistent_is_noop(self, repo): await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise - await _cleanup() + + # --- search with metadata filter (SQL push-down) --- + + @pytest.mark.anyio + async def test_search_metadata_filter_string(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + await repo.create("t3", metadata={"env": "prod", "region": "us"}) + + results = await repo.search(metadata={"env": "prod"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_numeric(self, repo): + await repo.create("t1", metadata={"priority": 1}) + await repo.create("t2", metadata={"priority": 2}) + await repo.create("t3", metadata={"priority": 1, "extra": "x"}) + + results = await repo.search(metadata={"priority": 1}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_multiple_keys(self, repo): + await repo.create("t1", metadata={"env": "prod", "region": "us"}) + await repo.create("t2", metadata={"env": "prod", "region": "eu"}) + await repo.create("t3", metadata={"env": "staging", "region": "us"}) + + results = await repo.search(metadata={"env": "prod", "region": "us"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_metadata_no_match(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "dev"}) + assert results == [] + + @pytest.mark.anyio + async def test_search_metadata_pagination_correct(self, repo): + """Regression: SQL push-down makes limit/offset exact even when most rows don't match.""" + for i in range(30): + meta = {"target": "yes"} if i % 3 == 0 else {"target": "no"} + await repo.create(f"t{i:03d}", metadata=meta) + + # Total matching rows: i in {0,3,6,9,12,15,18,21,24,27} = 10 rows + all_matches = await repo.search(metadata={"target": "yes"}, limit=100) + assert len(all_matches) == 10 + + # Paginate: first page + page1 = await repo.search(metadata={"target": "yes"}, limit=3, offset=0) + assert len(page1) == 3 + + # Paginate: second page + page2 = await repo.search(metadata={"target": "yes"}, limit=3, offset=3) + assert len(page2) == 3 + + # No overlap between pages + page1_ids = {r["thread_id"] for r in page1} + page2_ids = {r["thread_id"] for r in page2} + assert page1_ids.isdisjoint(page2_ids) + + # Last page + page_last = await repo.search(metadata={"target": "yes"}, limit=3, offset=9) + assert len(page_last) == 1 + + @pytest.mark.anyio + async def test_search_metadata_with_status_filter(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "prod"}) + await repo.update_status("t1", "busy") + + results = await repo.search(metadata={"env": "prod"}, status="busy") + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_without_metadata_still_works(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2") + + results = await repo.search(limit=10) + assert len(results) == 2 + + @pytest.mark.anyio + async def test_search_metadata_missing_key_no_match(self, repo): + """Rows without the requested metadata key should not match.""" + await repo.create("t1", metadata={"other": "val"}) + await repo.create("t2", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "prod"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t2" + + @pytest.mark.anyio + async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog): + """When ALL metadata keys are unsafe, raises InvalidMetadataFilterError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info: + await repo.search(metadata={"bad;key": "x"}) + assert any("bad;key" in r.message for r in caplog.records) + # Subclass of ValueError for backward compatibility + assert isinstance(exc_info.value, ValueError) + + @pytest.mark.anyio + async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog): + """Valid keys filter rows; only the invalid key is warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + results = await repo.search(metadata={"env": "prod", "bad;key": "x"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + assert any("bad;key" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_filter_boolean(self, repo): + """True matches only boolean true, not integer 1.""" + await repo.create("t1", metadata={"active": True}) + await repo.create("t2", metadata={"active": False}) + await repo.create("t3", metadata={"active": True, "extra": "x"}) + await repo.create("t4", metadata={"active": 1}) + + results = await repo.search(metadata={"active": True}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_none(self, repo): + """Only rows with explicit JSON null match; missing key does not.""" + await repo.create("t1", metadata={"tag": None}) + await repo.create("t2", metadata={"tag": "present"}) + await repo.create("t3", metadata={"other": "val"}) + + results = await repo.search(metadata={"tag": None}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + + @pytest.mark.anyio + async def test_search_metadata_non_string_key_skipped(self, repo, caplog): + """Non-string keys raise ValueError from isinstance check; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={1: "x"}) + assert any("1" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog): + """Unsupported value types (list, dict) raise TypeError; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"env": ["prod", "staging"]}) + + @pytest.mark.anyio + async def test_search_metadata_dotted_key_raises(self, repo, caplog): + """Dotted keys are rejected; when ALL keys are dotted, raises ValueError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"a.b": "anything"}) + assert any("a.b" in r.message for r in caplog.records) + + # --- dialect-aware type-safe filtering edge cases --- + + @pytest.mark.anyio + async def test_search_metadata_bool_vs_int_distinction(self, repo): + """True must not match 1; False must not match 0.""" + await repo.create("bool_true", metadata={"flag": True}) + await repo.create("bool_false", metadata={"flag": False}) + await repo.create("int_one", metadata={"flag": 1}) + await repo.create("int_zero", metadata={"flag": 0}) + + true_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": True})} + assert true_hits == {"bool_true"} + + false_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": False})} + assert false_hits == {"bool_false"} + + @pytest.mark.anyio + async def test_search_metadata_int_does_not_match_bool(self, repo): + """Integer 1 must not match boolean True.""" + await repo.create("bool_true", metadata={"val": True}) + await repo.create("int_one", metadata={"val": 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"val": 1})} + assert hits == {"int_one"} + + @pytest.mark.anyio + async def test_search_metadata_none_excludes_missing_key(self, repo): + """Filtering by None matches explicit JSON null only, not missing key or empty {}.""" + await repo.create("explicit_null", metadata={"k": None}) + await repo.create("missing_key", metadata={"other": "x"}) + await repo.create("empty_obj", metadata={}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"k": None})} + assert hits == {"explicit_null"} + + @pytest.mark.anyio + async def test_search_metadata_float_value(self, repo): + await repo.create("t1", metadata={"score": 3.14}) + await repo.create("t2", metadata={"score": 2.71}) + await repo.create("t3", metadata={"score": 3.14}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"score": 3.14})} + assert hits == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_mixed_types_same_key(self, repo): + """Each type query only matches its own type, even when the key is shared.""" + await repo.create("str_row", metadata={"x": "hello"}) + await repo.create("int_row", metadata={"x": 42}) + await repo.create("bool_row", metadata={"x": True}) + await repo.create("null_row", metadata={"x": None}) + + assert {r["thread_id"] for r in await repo.search(metadata={"x": "hello"})} == {"str_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": 42})} == {"int_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": True})} == {"bool_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": None})} == {"null_row"} + + @pytest.mark.anyio + async def test_search_metadata_large_int_precision(self, repo): + """Integers beyond float precision (> 2**53) must match exactly.""" + large = 2**53 + 1 + await repo.create("t1", metadata={"id": large}) + await repo.create("t2", metadata={"id": large - 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"id": large})} + assert hits == {"t1"} + + +class TestJsonMatchCompilation: + """Verify compiled SQL for both SQLite and PostgreSQL dialects.""" + + def test_json_match_compiles_sqlite(self): + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + cases = [ + (None, "json_type(t.data, '$.\"k\"') = 'null'"), + (True, "json_type(t.data, '$.\"k\"') = 'true'"), + (False, "json_type(t.data, '$.\"k\"') = 'false'"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: uses INTEGER cast for precision, type-check narrows to 'integer' only + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "= 'integer'" in sql + assert "INTEGER" in sql + assert "CAST" in sql + + # float: uses REAL cast, type-check spans 'integer' and 'real' + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "IN ('integer', 'real')" in sql + assert "REAL" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "'text'" in sql + + def test_json_match_compiles_pg(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + dialect = postgresql.dialect() + + cases = [ + (None, "json_typeof(t.data -> 'k') = 'null'"), + (True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"), + (False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: CASE guard prevents CAST error when 'number' also matches floats + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "BIGINT" in sql + assert "CASE WHEN" in sql + assert "'^-?[0-9]+$'" in sql + + # float: uses DOUBLE PRECISION cast + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "DOUBLE PRECISION" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'string'" in sql + + def test_json_match_rejects_unsafe_key(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_key in ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, bad_key, "x") + + # Non-string keys must also raise ValueError (not TypeError from re.match) + for non_str_key in [42, None, ("k",)]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, non_str_key, "x") + + def test_json_match_rejects_unsupported_value_type(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_value in [[], {}, object()]: + with pytest.raises(TypeError, match="JsonMatch value must be"): + json_match(t.c.data, "k", bad_value) + + def test_json_match_unsupported_dialect_raises(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import mysql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + expr = json_match(t.c.data, "k", "v") + + with pytest.raises(NotImplementedError, match="mysql"): + str(expr.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True})) + + def test_json_match_rejects_out_of_range_int(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + # boundary values must be accepted + json_match(t.c.data, "k", 2**63 - 1) + json_match(t.c.data, "k", -(2**63)) + + # one beyond each boundary must be rejected + for out_of_range in [2**63, -(2**63) - 1, 10**30]: + with pytest.raises(TypeError, match="out of signed 64-bit range"): + json_match(t.c.data, "k", out_of_range) + + def test_compiler_raises_on_escaped_key(self): + """Compiler raises ValueError even when __init__ validation is bypassed.""" + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + elem = json_match(t.c.data, "k", "v") + elem.key = "bad.key" # bypass __init__ to simulate -O stripping assert + + with pytest.raises(ValueError, match="Key escaped validation"): + str(elem.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + + with pytest.raises(ValueError, match="Key escaped validation"): + str(elem.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index daf0c0b13..9e37f3c86 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore from app.gateway.routers import threads from deerflow.config.paths import Paths +from deerflow.persistence.thread_meta import InvalidMetadataFilterError from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore _ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None assert entries, "expected at least one history entry" for entry in entries: assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry + + +# ── Metadata filter validation at API boundary ──────────────────────────────── + + +def test_search_threads_rejects_invalid_key_at_api_boundary() -> None: + """Keys that don't match [A-Za-z0-9_-]+ are rejected by the Pydantic + validator on ThreadSearchRequest.metadata — 422 from both backends. + """ + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"bad;key": "x"}}) + + assert response.status_code == 422 + + +def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None: + """Value types outside (None, bool, int, float, str) are rejected.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": ["a", "b"]}}) + + assert response.status_code == 422 + + +def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None: + """If the backend still raises InvalidMetadataFilterError (defense in + depth), the handler surfaces it as HTTP 400. + """ + app, _store, _checkpointer = _build_thread_app() + thread_store = app.state.thread_store + + async def _raise(**kwargs): + raise InvalidMetadataFilterError("rejected") + + with TestClient(app) as client: + with patch.object(thread_store, "search", side_effect=_raise): + response = client.post("/api/threads/search", json={"metadata": {"valid_key": "x"}}) + + assert response.status_code == 400 + assert "rejected" in response.json()["detail"] + + +def test_search_threads_succeeds_with_valid_metadata() -> None: + """Sanity check: valid metadata passes through without error.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": "prod"}}) + + assert response.status_code == 200 From 2a1ac06bf4ee0efdc6bc312d5b0548a321f591d6 Mon Sep 17 00:00:00 2001 From: Eilen Shin <136898293+Eilen6316@users.noreply.github.com> Date: Wed, 13 May 2026 15:49:34 +0800 Subject: [PATCH 06/12] fix(persistence): reuse token usage model grouping expression (#2910) --- .../harness/deerflow/persistence/run/sql.py | 5 +- backend/tests/test_run_repository.py | 48 +++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 430fbe4f6..5331451e3 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -223,10 +223,11 @@ class RunRepository(RunStore): """Aggregate token usage via a single SQL GROUP BY query.""" _completed = RunRow.status.in_(("success", "error")) _thread = RunRow.thread_id == thread_id + model_name = func.coalesce(RunRow.model_name, "unknown") stmt = ( select( - func.coalesce(RunRow.model_name, "unknown").label("model"), + model_name.label("model"), func.count().label("runs"), func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"), func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"), @@ -236,7 +237,7 @@ class RunRepository(RunStore): func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"), ) .where(_thread, _completed) - .group_by(func.coalesce(RunRow.model_name, "unknown")) + .group_by(model_name) ) async with self._sf() as session: diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 6fd534829..5e230e790 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -3,7 +3,10 @@ Uses a temp SQLite DB to test ORM-backed CRUD operations. """ +import re + import pytest +from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository @@ -278,3 +281,48 @@ class TestRunRepository: assert row4["model_name"] is None await _cleanup() + + @pytest.mark.anyio + async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self): + captured = [] + + class FakeResult: + def all(self): + return [] + + class FakeSession: + async def execute(self, stmt): + captured.append(stmt) + return FakeResult() + + class FakeSessionContext: + async def __aenter__(self): + return FakeSession() + + async def __aexit__(self, exc_type, exc, tb): + return None + + repo = RunRepository(lambda: FakeSessionContext()) + + agg = await repo.aggregate_tokens_by_thread("t1") + assert agg == { + "total_tokens": 0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_runs": 0, + "by_model": {}, + "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0}, + } + assert len(captured) == 1 + + stmt = captured[0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect())) + select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1) + model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)" + + select_match = re.search(model_expr_pattern + r" AS model", select_sql) + group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip()) + + assert select_match is not None + assert group_by_match is not None + assert select_match.group(1) == group_by_match.group(1) From f1a0ab699aee5642fccf9f0fc211b231d41e6b5d Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Wed, 13 May 2026 23:45:47 +0800 Subject: [PATCH 07/12] fix(tools): preserve tool_search promotions across re-entrant get_available_tools (#2885) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(tools): preserve tool_search promotions across re-entrant get_available_tools Closes #2884. ``get_available_tools`` used to unconditionally call ``reset_deferred_registry()`` and rebuild a fresh ``DeferredToolRegistry`` on every invocation. That works for the first call of a request (the ContextVar starts at its default of ``None``), but any RE-ENTRANT call during the same async context — e.g. ``task_tool`` building a subagent's toolset, or a custom middleware that rebuilds tools mid-run — wiped any ``tool_search`` promotions the parent agent had already made. The ``DeferredToolFilterMiddleware`` would then re-hide those tools from the next model call, leaving the agent able to see a tool's name (via the prior ``tool_search`` result that's still in conversation history) but unable to invoke it. Fix: when the ContextVar already holds a registry, reuse it instead of rebuilding. Fresh requests still get a fresh registry because each new graph run starts in a new asyncio task with the ContextVar at ``None``. ## Verification - Unit-level reproduction (``test_get_available_tools_resets_registry_wiping_promotion``): promote a tool in the registry, call ``get_available_tools`` again, assert the promotion is preserved. Fails on main, passes on this branch. - Graph-execution reproduction (two tests): drive a real ``langchain.agents.create_agent`` graph with the real ``DeferredToolFilterMiddleware`` through two model turns, including one that issues a re-entrant ``get_available_tools`` call to simulate the task_tool subagent path. - Real-LLM end-to-end (``test_deferred_tool_promotion_real_llm.py``, opt-in via ``ONEAPI_E2E=1``): drives the same flow against a real OpenAI-compatible model (verified on GPT-5.4-mini through the one-api gateway), watches the model call the promoted ``fake_calculator`` through the deferred-filter middleware, and asserts the right arithmetic result. Passes against the fixed branch. - Companion update to ``test_tool_deduplication.py``: dropped the ``@patch("deerflow.tools.tools.reset_deferred_registry")`` decorators because the symbol is no longer imported there. - Test fixtures in the new files patch ``deerflow.tools.tools.get_app_config`` with a minimal ``model_construct``-ed ``AppConfig`` instead of calling the real loader, so they never trigger ``_apply_singleton_configs`` and never leak ``_memory_config``/``_title_config``/… mutations into the rest of the suite. Full backend suite: 3208 passed / 14 skipped / 0 failed. ruff check + format clean. * fix(tools): address Copilot review on #2885 - tools.py: rewrite the reuse-path comment to spell out (a) why we don't reconcile the registry against the current ``mcp_tools`` snapshot — the MCP cache doesn't refresh mid-graph-run, the lead agent's ``ToolNode`` is already bound to the previous tool set anyway, and ``promote()`` drops the entry so a naive re-sync misclassifies promotions as new tools — and (b) why the log uses ``max(0, …)`` to avoid negative counts when the cache shrinks between snapshots. - Replace direct ``ts_mod._registry_var.set(None)`` in test fixtures with the public ``reset_deferred_registry()`` helper so tests don't couple to module internals. - Correct the docstring path in ``test_deferred_tool_registry_promotion.py`` to match the actual monkeypatch target (``deerflow.mcp.cache.get_cached_mcp_tools``). - Rename ``test_get_available_tools_resets_registry_wiping_promotion`` to ``test_get_available_tools_preserves_promotions_across_reentrant_calls`` so the test name describes the contract being asserted, not the bug it originally reproduced. Full backend suite: 3208 passed / 14 skipped. Real-LLM e2e: 1 passed. --- .../packages/harness/deerflow/tools/tools.py | 53 ++- .../test_deferred_tool_promotion_real_llm.py | 222 ++++++++++ .../test_deferred_tool_registry_promotion.py | 390 ++++++++++++++++++ backend/tests/test_tool_deduplication.py | 12 +- 4 files changed, 661 insertions(+), 16 deletions(-) create mode 100644 backend/tests/test_deferred_tool_promotion_real_llm.py create mode 100644 backend/tests/test_deferred_tool_registry_promotion.py diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 01bfce43f..5c97962fc 100644 --- a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig from deerflow.reflection import resolve_variable from deerflow.sandbox.security import is_host_bash_allowed from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool -from deerflow.tools.builtins.tool_search import reset_deferred_registry +from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.sync import make_sync_tool_wrapper logger = logging.getLogger(__name__) @@ -116,8 +116,6 @@ def get_available_tools( # made through the Gateway API (which runs in a separate process) are immediately # reflected when loading MCP tools. mcp_tools = [] - # Reset deferred registry upfront to prevent stale state from previous calls - reset_deferred_registry() if include_mcp: try: from deerflow.config.extensions_config import ExtensionsConfig @@ -135,12 +133,51 @@ def get_available_tools( from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool - registry = DeferredToolRegistry() - for t in mcp_tools: - registry.register(t) - set_deferred_registry(registry) + # Reuse the existing registry if one is already set for + # this async context. ``get_available_tools`` is + # re-entered whenever a subagent is spawned + # (``task_tool`` calls it to build the child agent's + # toolset), and previously we used to unconditionally + # rebuild the registry — wiping out the parent agent's + # tool_search promotions. The + # ``DeferredToolFilterMiddleware`` then re-hid those + # tools from subsequent model calls, leaving the agent + # able to see a tool's name but unable to invoke it + # (issue #2884). ``contextvars`` already gives us the + # lifetime semantics we want: a fresh request / graph + # run starts in a new asyncio task with the + # ContextVar at its default of ``None``, so reuse is + # only triggered for re-entrant calls inside one run. + # + # Intentionally NOT reconciling against the current + # ``mcp_tools`` snapshot. The MCP cache only refreshes + # on ``extensions_config.json`` mtime changes, which + # in practice happens between graph runs — not inside + # one. And even if a refresh did happen mid-run, the + # already-built lead agent's ``ToolNode`` still holds + # the *previous* tool set (LangGraph binds tools at + # graph construction time), so a brand-new MCP tool + # couldn't actually be invoked anyway. The + # ``DeferredToolRegistry`` doesn't retain the names + # of previously-promoted tools (``promote()`` drops + # the entry entirely), so re-syncing the registry + # against a fresh ``mcp_tools`` list would + # mis-classify those promotions as new tools and + # re-register them as deferred — exactly the bug + # this fix exists to prevent. + existing_registry = get_deferred_registry() + if existing_registry is None: + registry = DeferredToolRegistry() + for t in mcp_tools: + registry.register(t) + set_deferred_registry(registry) + logger.info(f"Tool search active: {len(mcp_tools)} tools deferred") + else: + mcp_tool_names = {t.name for t in mcp_tools} + still_deferred = len(existing_registry) + promoted_count = max(0, len(mcp_tool_names) - still_deferred) + logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted") builtin_tools.append(tool_search_tool) - logger.info(f"Tool search active: {len(mcp_tools)} tools deferred") except ImportError: logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.") except Exception as e: diff --git a/backend/tests/test_deferred_tool_promotion_real_llm.py b/backend/tests/test_deferred_tool_promotion_real_llm.py new file mode 100644 index 000000000..46ae24d41 --- /dev/null +++ b/backend/tests/test_deferred_tool_promotion_real_llm.py @@ -0,0 +1,222 @@ +"""Real-LLM end-to-end verification for issue #2884. + +Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI- +compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware`` +and the production ``get_available_tools`` pipeline. The only thing we mock is +the MCP tool source — we hand-roll two ``@tool``s and inject them through +``deerflow.mcp.cache.get_cached_mcp_tools``. + +The flow exercised: + 1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger`` + that re-enters ``get_available_tools`` on the same task — this is the + code path issue #2884 reports). It must call ``tool_search`` to + discover the deferred ``fake_calculator`` tool. + 2. Tool batch: ``tool_search`` promotes ``fake_calculator``; + ``fake_subagent_trigger`` re-enters ``get_available_tools``. + 3. Turn 2: the promoted ``fake_calculator`` schema must reach the model + so it can actually call it. Without this PR's fix, the re-entry wipes + the promotion and the model can no longer invoke the tool. + +Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every +test run. Run with:: + + ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \ + PYTHONPATH=. uv run pytest \ + tests/test_deferred_tool_promotion_real_llm.py -v -s +""" + +from __future__ import annotations + +import os + +import pytest +from langchain_core.messages import HumanMessage +from langchain_core.tools import tool as as_tool + +# --------------------------------------------------------------------------- +# Skip control: only run when explicitly opted in. +# --------------------------------------------------------------------------- + + +pytestmark = pytest.mark.skipif( + os.getenv("ONEAPI_E2E") != "1", + reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)", +) + + +# --------------------------------------------------------------------------- +# Fake "MCP" tools the agent should discover via tool_search. +# Keep them obviously synthetic so the model can pattern-match the search. +# --------------------------------------------------------------------------- + + +_calls: list[str] = [] + + +@as_tool +def fake_calculator(expression: str) -> str: + """Evaluate a tiny arithmetic expression like '2 + 2'. + + Reserved for the user — only call this if the user asks for arithmetic. + """ + _calls.append(f"fake_calculator:{expression}") + try: + # Trivially safe-eval just for the e2e check + allowed = set("0123456789+-*/() .") + if not set(expression) <= allowed: + return "expression contains disallowed characters" + return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307 + except Exception as e: + return f"error: {e}" + + +@as_tool +def fake_translator(text: str, target_lang: str) -> str: + """Translate text into the given language code. Decorative — not used.""" + _calls.append(f"fake_translator:{text}:{target_lang}") + return f"[{target_lang}] {text}" + + +# --------------------------------------------------------------------------- +# Pipeline wiring (same shape as the in-process tests). +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_registry_between_tests(): + from deerflow.tools.builtins.tool_search import reset_deferred_registry + + reset_deferred_registry() + yield + reset_deferred_registry() + + +def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: + from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig + + real_ext = ExtensionsConfig( + mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}, + ) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: real_ext), + ) + monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)) + + +def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Build a minimal mock AppConfig and patch the symbol — never call the + real loader, which would trigger ``_apply_singleton_configs`` and + permanently mutate cross-test singletons (memory, title, …).""" + from deerflow.config.app_config import AppConfig + from deerflow.config.tool_search_config import ToolSearchConfig + + mock_cfg = AppConfig.model_construct( + log_level="info", + models=[], + tools=[], + tool_groups=[], + sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"), + tool_search=ToolSearchConfig(enabled=True), + ) + monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg) + + +# --------------------------------------------------------------------------- +# Real-LLM e2e test +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch): + """End-to-end against a real OpenAI-compatible LLM. + + The model must: + Turn 1 — see ``tool_search`` (deferred tools aren't bound yet) and + batch-call BOTH ``tool_search(select:fake_calculator)`` AND + ``fake_subagent_trigger(...)``. + Turn 2 — call ``fake_calculator`` and finish. + + Pass criterion: ``fake_calculator`` actually gets invoked at the tool + layer — recorded in ``_calls`` — which proves the model received the + promoted schema after the re-entrant ``get_available_tools`` call. + """ + from langchain.agents import create_agent + from langchain_openai import ChatOpenAI + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator]) + _force_tool_search_enabled(monkeypatch) + _calls.clear() + + @as_tool + async def fake_subagent_trigger(prompt: str) -> str: + """Pretend to spawn a subagent. Internally rebuilds the toolset. + + Use this whenever the user asks you to delegate work — pass a short + description as ``prompt``. + """ + # ``task_tool`` does this internally. Whether the registry-reset that + # used to happen here actually leaks back to the parent task depends + # on asyncio's implicit context-copying semantics (gather creates + # child tasks with copied contexts, so reset_deferred_registry is + # task-local) — but the fix in this PR is what GUARANTEES the + # promotion sticks regardless of which integration path triggers a + # re-entrant ``get_available_tools`` call. + get_available_tools(subagent_enabled=False) + _calls.append(f"fake_subagent_trigger:{prompt}") + return "subagent completed" + + tools = get_available_tools() + [fake_subagent_trigger] + + model = ChatOpenAI( + model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"), + api_key=os.environ["OPENAI_API_KEY"], + base_url=os.environ["OPENAI_API_BASE"], + temperature=0, + max_retries=1, + ) + + system_prompt = ( + "You are a meticulous assistant. Available deferred tools include a " + "calculator and a translator — their schemas are hidden until you " + "search for them via tool_search.\n\n" + "Procedure for the user's request:\n" + " 1. Call tool_search with query 'select:fake_calculator' AND " + "in the SAME tool batch also call fake_subagent_trigger(prompt='go') " + "to delegate the side work. Put both tool_calls in your first response.\n" + " 2. After both tool messages come back, call fake_calculator with " + "the user's expression.\n" + " 3. Reply with just the numeric result." + ) + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt=system_prompt, + ) + + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]}, + config={"recursion_limit": 12}, + ) + + print("\n=== tool calls recorded ===") + for c in _calls: + print(f" {c}") + print("\n=== final message ===") + final_text = result["messages"][-1].content if result["messages"] else "(none)" + print(f" {final_text!r}") + + # The smoking-gun assertion: fake_calculator was actually invoked at the + # tool layer. This is only possible if the promoted schema reached the + # model in turn 2, despite the subagent-style re-entry in turn 1. + calc_calls = [c for c in _calls if c.startswith("fake_calculator:")] + assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}" + + # And the math should actually be done correctly (sanity that the LLM + # really used the result, not just hallucinated the answer). + assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}" diff --git a/backend/tests/test_deferred_tool_registry_promotion.py b/backend/tests/test_deferred_tool_registry_promotion.py new file mode 100644 index 000000000..23b7649ec --- /dev/null +++ b/backend/tests/test_deferred_tool_registry_promotion.py @@ -0,0 +1,390 @@ +"""Reproduce + regression-guard issue #2884. + +Hypothesis from the issue: + ``tools.tools.get_available_tools`` unconditionally calls + ``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry`` + every time it is invoked. If anything calls ``get_available_tools`` again + during the same async context (after the agent has promoted tools via + ``tool_search``), the promotion is wiped and the next model call hides the + tool's schema again. + +These tests pin two things: + +A. **At the unit boundary** — verify the failure mode directly. Promote a + tool in the registry, then call ``get_available_tools`` again and observe + that the ContextVar registry is reset and the promotion is lost. + +B. **At the graph-execution boundary** — drive a real ``create_agent`` graph + with the real ``DeferredToolFilterMiddleware`` through two model turns. + The first turn calls ``tool_search`` which promotes a tool. The second + turn must see that tool's schema in ``request.tools``. If + ``get_available_tools`` were to run again between the two turns and reset + the registry, the second turn's filter would strip the tool. + +Strategy: use the production ``deerflow.tools.tools.get_available_tools`` +unmodified; mock only the LLM and the MCP tool source. Patch +``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that +``get_available_tools`` resolves via lazy import) to return our fixture +tools so we don't need a real MCP server. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import tool as as_tool + + +class FakeToolCallingModel(FakeMessagesListChatModel): + """FakeMessagesListChatModel + no-op bind_tools so create_agent works.""" + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + +# --------------------------------------------------------------------------- +# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled +# --------------------------------------------------------------------------- + + +@as_tool +def fake_mcp_search(query: str) -> str: + """Pretend to search a knowledge base for the given query.""" + return f"results for {query}" + + +@as_tool +def fake_mcp_fetch(url: str) -> str: + """Pretend to fetch a page at the given URL.""" + return f"content of {url}" + + +@pytest.fixture(autouse=True) +def _supply_env(monkeypatch: pytest.MonkeyPatch): + """config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + +@pytest.fixture(autouse=True) +def _reset_deferred_registry_between_tests(): + """Each test must start with a clean ContextVar. + + The registry lives in a module-level ContextVar with no per-task isolation + in a synchronous test runner, so one test's promotion can leak into the + next and silently break filter assertions. + """ + from deerflow.tools.builtins.tool_search import reset_deferred_registry + + reset_deferred_registry() + yield + reset_deferred_registry() + + +def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: + """Make get_available_tools believe an MCP server is registered. + + Build a real ``ExtensionsConfig`` with one enabled MCP server entry so + that both ``AppConfig.from_file`` (which calls + ``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools`` + (which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``) + see a valid instance. Then point the MCP tool cache at our fixture tools. + """ + from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig + + real_ext = ExtensionsConfig( + mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}, + ) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: real_ext), + ) + monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)) + + +def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Force config.tool_search.enabled=True without touching the yaml. + + Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs`` + which permanently mutates module-level singletons (``_memory_config``, + ``_title_config``, …) to match the developer's ``config.yaml`` — even + after pytest restores our patch. That leaks across tests later in the + run that rely on those singletons' DEFAULTS (e.g. memory queue tests + require ``_memory_config.enabled = True``, which is the dataclass default + but FALSE in the actual yaml). + + Build a minimal mock AppConfig instead and never call the real loader. + """ + from deerflow.config.app_config import AppConfig + from deerflow.config.tool_search_config import ToolSearchConfig + + mock_cfg = AppConfig.model_construct( + log_level="info", + models=[], + tools=[], + tool_groups=[], + sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"), + tool_search=ToolSearchConfig(enabled=True), + ) + monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg) + + +# --------------------------------------------------------------------------- +# Section A — direct unit-level reproduction +# --------------------------------------------------------------------------- + + +def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch): + """Re-entrant ``get_available_tools()`` must preserve prior promotions. + + Step 1: call get_available_tools() — registers MCP tools as deferred. + Step 2: simulate the agent calling tool_search by promoting one tool. + Step 3: call get_available_tools() again (the same code path + ``task_tool`` exercises mid-run). + + Assertion: after step 3, the promoted tool is STILL promoted (not + re-deferred). On ``main`` before the fix, step 3's + ``reset_deferred_registry()`` wiped the promotion and re-registered + every MCP tool as deferred — this assertion fired with + ``REGRESSION (#2884)``. + """ + from deerflow.tools.builtins.tool_search import get_deferred_registry + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + # Step 1: first call — both MCP tools start deferred + get_available_tools() + reg1 = get_deferred_registry() + assert reg1 is not None + assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"} + + # Step 2: simulate tool_search promoting one of them + reg1.promote({"fake_mcp_search"}) + assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search" + + # Step 3: second call — registry must NOT silently undo the promotion + get_available_tools() + reg2 = get_deferred_registry() + assert reg2 is not None + deferred_after = {e.name for e in reg2.entries} + assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}" + + +# --------------------------------------------------------------------------- +# Section B — graph-execution reproduction +# --------------------------------------------------------------------------- + + +class _ToolSearchPromotingModel(FakeToolCallingModel): + """Two-turn model that: + + Turn 1 → emit a tool_call for ``tool_search`` (the real one) + Turn 2 → emit a tool_call for ``fake_mcp_search`` (the promoted tool) + + Records the tools it received on each turn so the test can inspect what + DeferredToolFilterMiddleware actually fed to ``bind_tools``. + """ + + bound_tools_per_turn: list[list[str]] = [] + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + # Record the tool names the model would see in this turn + names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools] + self.bound_tools_per_turn.append(names) + return self + + +def _build_promoting_model() -> _ToolSearchPromotingModel: + return _ToolSearchPromotingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "tool_search", + "args": {"query": "select:fake_mcp_search"}, + "id": "call_search_1", + "type": "tool_call", + } + ], + ), + AIMessage( + content="", + tool_calls=[ + { + "name": "fake_mcp_search", + "args": {"query": "hello"}, + "id": "call_mcp_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="all done"), + ] + ) + + +def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch): + """End-to-end: drive a real create_agent graph through two turns. + + Without the fix, the second-turn bind_tools call should NOT contain + fake_mcp_search (because DeferredToolFilterMiddleware sees it in the + registry and strips it). With the fix, the model sees the schema and can + invoke it. + """ + from langchain.agents import create_agent + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + tools = get_available_tools() + # Sanity: the assembled tool list includes the deferred tools (they're in + # bind_tools but DeferredToolFilterMiddleware strips deferred ones before + # they reach the model) + tool_names = {getattr(t, "name", "") for t in tools} + assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names + + model = _build_promoting_model() + model.bound_tools_per_turn = [] # reset class-level recorder + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt="bug-2884-repro", + ) + + graph.invoke({"messages": [HumanMessage(content="use the search tool")]}) + + # Turn 1: model should NOT see fake_mcp_search (it's deferred) + turn1 = set(model.bound_tools_per_turn[0]) + assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}" + assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}" + + # Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it. + # This is the load-bearing assertion for issue #2884. + assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}" + turn2 = set(model.bound_tools_per_turn[1]) + assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}" + + +# --------------------------------------------------------------------------- +# Section C — the actual issue #2884 trigger: a re-entrant +# get_available_tools call (e.g. when task_tool spawns a subagent) must not +# wipe the parent's promotion. +# --------------------------------------------------------------------------- + + +def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch): + """Issue #2884 in its real shape: a re-entrant get_available_tools call + (the same pattern that happens when ``task_tool`` builds a subagent's + toolset mid-run) must not wipe the parent agent's tool_search promotions. + + Turn 1's tool batch contains BOTH ``tool_search`` (which promotes + ``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls + ``get_available_tools`` again — exactly what ``task_tool`` does when it + builds a subagent's toolset). With the fix, turn 2's bind_tools sees the + promoted tool. Without the fix, the re-entry wipes the registry and + the filter re-hides it. + """ + from langchain.agents import create_agent + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + # The trigger tool simulates what task_tool does internally: rebuild the + # toolset by calling get_available_tools while the registry is live. + @as_tool + def fake_subagent_trigger(prompt: str) -> str: + """Pretend to spawn a subagent. Internally rebuilds the toolset.""" + get_available_tools(subagent_enabled=False) + return f"spawned subagent for: {prompt}" + + tools = get_available_tools() + [fake_subagent_trigger] + + bound_per_turn: list[list[str]] = [] + + class _Model(FakeToolCallingModel): + def bind_tools(self, tools_arg, **kwargs): # type: ignore[override] + bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg]) + return self + + model = _Model( + responses=[ + # Turn 1: do both in one batch — promote AND trigger the + # subagent-style rebuild. LangGraph executes them in order in the + # same agent step. + AIMessage( + content="", + tool_calls=[ + { + "name": "tool_search", + "args": {"query": "select:fake_mcp_search"}, + "id": "call_search_1", + "type": "tool_call", + }, + { + "name": "fake_subagent_trigger", + "args": {"prompt": "go"}, + "id": "call_trigger_1", + "type": "tool_call", + }, + ], + ), + # Turn 2: try to invoke the promoted tool. The model gets this + # turn only if turn 1's bind_tools recorded what the filter sent. + AIMessage( + content="", + tool_calls=[ + { + "name": "fake_mcp_search", + "args": {"query": "hello"}, + "id": "call_mcp_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="all done"), + ] + ) + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt="bug-2884-subagent-repro", + ) + graph.invoke({"messages": [HumanMessage(content="use the search tool")]}) + + # Turn 1 sanity: deferred tool not visible yet + assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0] + + # The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the + # re-entrant get_available_tools call that happened in turn 1's tool batch. + assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}" + turn2 = set(bound_per_turn[1]) + assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}" diff --git a/backend/tests/test_tool_deduplication.py b/backend/tests/test_tool_deduplication.py index ed9efffaf..f018fc57d 100644 --- a/backend/tests/test_tool_deduplication.py +++ b/backend/tests/test_tool_deduplication.py @@ -65,8 +65,7 @@ def _make_minimal_config(tools): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg): +def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): """Config-loaded async-only tools can still be invoked by sync clients.""" async def async_tool_impl(x: int) -> str: @@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg): +def test_no_duplicates_returned(mock_bash, mock_cfg): """get_available_tools() never returns two tools with the same name.""" mock_cfg.return_value = _make_minimal_config([]) @@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg): +def test_first_occurrence_wins(mock_bash, mock_cfg): """When duplicates exist, the first occurrence is kept.""" mock_cfg.return_value = _make_minimal_config([]) @@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog): +def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog): """A warning is logged for every skipped duplicate.""" import logging From eab7ae3d6283a51fbe759e761d39fce2308cc4a3 Mon Sep 17 00:00:00 2001 From: YuJitang Date: Wed, 13 May 2026 23:52:19 +0800 Subject: [PATCH 08/12] feat: stream subagent token usage to header via terminal task events (#2882) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: real-time subagent token usage display in header and per-turn Backend: - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out) - Report subagent usage to parent RunJournal for API totals - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution Frontend: - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata - Simplify selectHeaderTokenUsage (no subagentUsage parameter) - Per-turn inline badges show turn-specific usage via message accumulation - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming * fix: prevent header token double counting from baseline reset race onFinish, onError, and thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count. Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end. * fix: write back subagent tokens for all concurrent task tool calls TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls only the last ToolMessage had its cached subagent usage written back to the dispatch AIMessage.usage_metadata. Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (leak) and never appeared in the per-turn inline token display. Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other. * fix: clean up subagent usage cache entry on task cancellation When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict. * fix: address token usage review feedback * fix: handle missing config for subagent usage cache --------- Co-authored-by: Willem Jiang --- README.md | 2 +- backend/CLAUDE.md | 2 +- .../middlewares/token_usage_middleware.py | 61 ++++++- .../deerflow/tools/builtins/task_tool.py | 55 ++++++- .../tests/test_memory_queue_user_isolation.py | 5 +- backend/tests/test_task_tool_core_logic.py | 153 ++++++++++++++++++ backend/tests/test_token_usage_middleware.py | 49 +++++- .../messages/message-token-usage.tsx | 41 +++-- frontend/src/core/messages/usage.ts | 4 +- frontend/src/core/threads/hooks.ts | 18 ++- 10 files changed, 349 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 9ff1d501b..8248e8fe4 100644 --- a/README.md +++ b/README.md @@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl Complex tasks rarely fit in a single pass. DeerFlow decomposes them. -The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. +The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step. This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands. diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 67ee9cc7e..5e0aebfdb 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) -11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional) +11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py index f59e7f2b7..0d3607faf 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py @@ -9,7 +9,7 @@ from typing import Any, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.todo import Todo -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str: return "thinking" +def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool: + """Return True if the AIMessage contains a tool_call with the given id.""" + for tc in message.tool_calls or []: + if isinstance(tc, dict): + if tc.get("id") == tool_call_id: + return True + elif hasattr(tc, "id") and tc.id == tool_call_id: + return True + return False + + def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]: tool_calls = getattr(message, "tool_calls", None) or [] actions: list[dict[str, Any]] = [] @@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware): if not messages: return None + # Annotate subagent token usage onto the AIMessage that dispatched it. + # When a task tool completes, its usage is cached by tool_call_id. Detect + # the ToolMessage → search backward for the corresponding AIMessage → merge. + # Walk backward through consecutive ToolMessages before the new AIMessage + # so that multiple concurrent task tool calls all get their subagent tokens + # written back to the same dispatch message (merging into one update). + state_updates: dict[int, AIMessage] = {} + if len(messages) >= 2: + from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage + + idx = len(messages) - 2 + while idx >= 0: + tool_msg = messages[idx] + if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id: + break + + subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id) + if subagent_usage: + # Search backward from the ToolMessage to find the AIMessage + # that dispatched it. A single model response can dispatch + # multiple task tool calls, so we can't assume a fixed offset. + dispatch_idx = idx - 1 + while dispatch_idx >= 0: + candidate = messages[dispatch_idx] + if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id): + # Accumulate into an existing update for the same + # AIMessage (multiple task calls in one response), + # or merge fresh from the original message. + existing_update = state_updates.get(dispatch_idx) + prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {}) + merged = { + **prev, + "input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"], + "output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"], + "total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"], + } + state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged}) + break + dispatch_idx -= 1 + idx -= 1 + last = messages[-1] if not isinstance(last, AIMessage): + if state_updates: + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} return None usage = getattr(last, "usage_metadata", None) @@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware): additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {}) if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution: - return None + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs}) - return {"messages": [updated_msg]} + state_updates[len(messages) - 1] = updated_msg + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 861c45b45..cf9281ff4 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -26,6 +26,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can +# write it back to the triggering AIMessage's usage_metadata. +_subagent_usage_cache: dict[str, dict[str, int]] = {} + + +def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool: + if app_config is None: + try: + app_config = get_app_config() + except FileNotFoundError: + return False + return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False)) + + +def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None: + if enabled and usage: + _subagent_usage_cache[tool_call_id] = usage + + +def pop_cached_subagent_usage(tool_call_id: str) -> dict | None: + return _subagent_usage_cache.pop(tool_call_id, None) + def _is_subagent_terminal(result: Any) -> bool: """Return whether a background subagent result is safe to clean up.""" @@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None: return None +def _summarize_usage(records: list[dict] | None) -> dict | None: + """Summarize token usage records into a compact dict for SSE events.""" + if not records: + return None + return { + "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records), + "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records), + "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records), + } + + def _report_subagent_usage(runtime: Any, result: Any) -> None: """Report subagent token usage to the parent RunJournal, if available. @@ -177,6 +210,7 @@ async def task_tool( subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. """ runtime_app_config = _get_runtime_app_config(runtime) + cache_token_usage = _token_usage_cache_enabled(runtime_app_config) available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names() # Get subagent configuration @@ -312,27 +346,32 @@ async def task_tool( last_message_count = current_message_count # Check if task completed, failed, or timed out + usage = _summarize_usage(getattr(result, "token_usage_records", None)) if result.status == SubagentStatus.COMPLETED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_completed", "task_id": task_id, "result": result.result}) + writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") cleanup_background_task(task_id) return f"Task Succeeded. Result: {result.result}" elif result.status == SubagentStatus.FAILED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_failed", "task_id": task_id, "error": result.error}) + writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage}) logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") cleanup_background_task(task_id) return f"Task failed. Error: {result.error}" elif result.status == SubagentStatus.CANCELLED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_cancelled", "task_id": task_id, "error": result.error}) + writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") cleanup_background_task(task_id) return "Task cancelled by user." elif result.status == SubagentStatus.TIMED_OUT: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) + writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage}) logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") cleanup_background_task(task_id) return f"Task timed out. Error: {result.error}" @@ -351,7 +390,9 @@ async def task_tool( timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id}) + usage = _summarize_usage(getattr(result, "token_usage_records", None)) + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) + writer({"type": "task_timed_out", "task_id": task_id, "usage": usage}) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. @@ -374,4 +415,8 @@ async def task_tool( cleanup_background_task(task_id) else: _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) + _subagent_usage_cache.pop(tool_call_id, None) + raise + except Exception: + _subagent_usage_cache.pop(tool_call_id, None) raise diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py index cf068e095..79250817c 100644 --- a/backend/tests/test_memory_queue_user_isolation.py +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue +from deerflow.config.memory_config import MemoryConfig def test_conversation_context_has_user_id(): @@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none(): def test_queue_add_stores_user_id(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") assert len(q._queue) == 1 assert q._queue[0].user_id == "alice" @@ -26,7 +27,7 @@ def test_queue_add_stores_user_id(): def test_queue_process_passes_user_id_to_updater(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") mock_updater = MagicMock() diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 0591c0e8d..658968d65 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -59,12 +59,15 @@ def _make_result( ai_messages: list[dict] | None = None, result: str | None = None, error: str | None = None, + token_usage_records: list[dict] | None = None, ) -> SimpleNamespace: return SimpleNamespace( status=status, ai_messages=ai_messages or [], result=result, error=error, + token_usage_records=token_usage_records or [], + usage_reported=False, ) @@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch): assert len(report_calls) == 1 assert report_calls[0][1] is cancel_result assert cleanup_calls == ["tc-cancel-report"] + + +@pytest.mark.parametrize( + "status, expected_type", + [ + (FakeSubagentStatus.COMPLETED, "task_completed"), + (FakeSubagentStatus.FAILED, "task_failed"), + (FakeSubagentStatus.CANCELLED, "task_cancelled"), + (FakeSubagentStatus.TIMED_OUT, "task_timed_out"), + ], +) +def test_terminal_events_include_usage(monkeypatch, status, expected_type): + """Terminal task events include a usage summary from token_usage_records.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + records = [ + {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280}, + ] + result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-usage", + ) + + terminal_events = [e for e in events if e["type"] == expected_type] + assert len(terminal_events) == 1 + assert terminal_events[0]["usage"] == { + "input_tokens": 300, + "output_tokens": 130, + "total_tokens": 430, + } + + +def test_terminal_event_usage_none_when_no_records(monkeypatch): + """Terminal event has usage=None when token_usage_records is empty.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[]) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-no-records", + ) + + completed = [e for e in events if e["type"] == "task_completed"] + assert len(completed) == 1 + assert completed[0]["usage"] is None + + +def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch): + monkeypatch.setattr( + task_tool_module, + "get_app_config", + MagicMock(side_effect=FileNotFoundError("missing config")), + ) + + assert task_tool_module._token_usage_cache_enabled(None) is False + + +def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False)) + runtime = _make_runtime(app_config=app_config) + records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}] + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records) + + task_tool_module._subagent_usage_cache.clear() + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-disabled-cache", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None + + +def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True)) + runtime = _make_runtime(app_config=app_config) + + task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2} + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed"))) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + with pytest.raises(RuntimeError, match="poll failed"): + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-error", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-error") is None diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py index b24ff7b16..9686455c0 100644 --- a/backend/tests/test_token_usage_middleware.py +++ b/backend/tests/test_token_usage_middleware.py @@ -1,9 +1,10 @@ """Tests for TokenUsageMiddleware attribution annotations.""" +import importlib import logging from unittest.mock import MagicMock -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from deerflow.agents.middlewares.token_usage_middleware import ( TOKEN_USAGE_ATTRIBUTION_KEY, @@ -232,3 +233,49 @@ class TestTokenUsageMiddleware: "tool_call_id": "write_todos:remove", } ] + + def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch): + middleware = TokenUsageMiddleware() + first_dispatch = AIMessage( + content="", + tool_calls=[{"id": "task:first", "name": "task", "args": {}}], + ) + second_dispatch = AIMessage( + content="", + tool_calls=[ + {"id": "task:second-a", "name": "task", "args": {}}, + {"id": "task:second-b", "name": "task", "args": {}}, + ], + ) + messages = [ + first_dispatch, + ToolMessage(content="first", tool_call_id="task:first"), + second_dispatch, + ToolMessage(content="second-a", tool_call_id="task:second-a"), + ToolMessage(content="second-b", tool_call_id="task:second-b"), + AIMessage(content="done"), + ] + cached_usage = { + "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27}, + } + + task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool") + monkeypatch.setattr( + task_tool_module, + "pop_cached_subagent_usage", + lambda tool_call_id: cached_usage.pop(tool_call_id, None), + ) + + result = middleware.after_model({"messages": messages}, _make_runtime()) + + assert result is not None + usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)] + assert len(usage_updates) == 1 + updated = usage_updates[0] + assert updated.tool_calls == second_dispatch.tool_calls + assert updated.usage_metadata == { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + } diff --git a/frontend/src/components/workspace/messages/message-token-usage.tsx b/frontend/src/components/workspace/messages/message-token-usage.tsx index cc8d0debb..84f8a8057 100644 --- a/frontend/src/components/workspace/messages/message-token-usage.tsx +++ b/frontend/src/components/workspace/messages/message-token-usage.tsx @@ -12,13 +12,11 @@ function TokenUsageSummary({ inputTokens, outputTokens, totalTokens, - unavailable = false, }: { className?: string; inputTokens?: number; outputTokens?: number; totalTokens?: number; - unavailable?: boolean; }) { const { t } = useI18n(); @@ -33,21 +31,15 @@ function TokenUsageSummary({ {t.tokenUsage.label} - {!unavailable ? ( - <> - - {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} - - - {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} - - - {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} - - - ) : ( - {t.tokenUsage.unavailableShort} - )} + + {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} + + + {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} + + + {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} + ); } @@ -55,7 +47,7 @@ function TokenUsageSummary({ export function MessageTokenUsageList({ className, enabled = false, - isLoading = false, + isLoading: _isLoading = false, messages, }: { className?: string; @@ -63,7 +55,7 @@ export function MessageTokenUsageList({ isLoading?: boolean; messages: Message[]; }) { - if (!enabled || isLoading) { + if (!enabled) { return null; } @@ -75,13 +67,16 @@ export function MessageTokenUsageList({ const usage = accumulateUsage(aiMessages); + if (!usage) { + return null; + } + return ( ); } diff --git a/frontend/src/core/messages/usage.ts b/frontend/src/core/messages/usage.ts index 4679dffa5..01e3a59e1 100644 --- a/frontend/src/core/messages/usage.ts +++ b/frontend/src/core/messages/usage.ts @@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null { return hasUsage ? cumulative : null; } -function hasNonZeroUsage( +export function hasNonZeroUsage( usage: TokenUsage | null | undefined, ): usage is TokenUsage { return ( @@ -75,7 +75,7 @@ function hasNonZeroUsage( ); } -function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { +export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { return { inputTokens: base.inputTokens + delta.inputTokens, outputTokens: base.outputTokens + delta.outputTokens, diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 0ac790eb2..adf9dbbb6 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -296,7 +296,11 @@ export function useThreadStream({ onError(error) { setOptimisticMessages([]); toast.error(getStreamErrorMessage(error)); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ queryKey: threadTokenUsageQueryKey(threadIdRef.current), @@ -305,7 +309,11 @@ export function useThreadStream({ }, onFinish(state) { listeners.current.onFinish?.(state.values); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ @@ -339,7 +347,11 @@ export function useThreadStream({ useEffect(() => { startedRef.current = false; sendInFlightRef.current = false; - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); prevHumanMsgCountRef.current = latestMessageCountsRef.current.humanMessageCount; }, [threadId]); From 6e8e6a969be803227aa71cb6ba4c5d116910b4b7 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Wed, 13 May 2026 23:56:06 +0800 Subject: [PATCH 09/12] test: add blocking IO detector (#2924) * test: add blocking IO detector * test: add blocking IO probe option * test: harden blocking IO probe lifecycle * test: move blocking io detector to support --- backend/tests/conftest.py | 93 ++++++ backend/tests/support/__init__.py | 1 + backend/tests/support/detectors/__init__.py | 1 + .../tests/support/detectors/blocking_io.py | 287 ++++++++++++++++++ backend/tests/test_blocking_io_detector.py | 190 ++++++++++++ .../test_blocking_io_probe_integration.py | 22 ++ 6 files changed, 594 insertions(+) create mode 100644 backend/tests/support/__init__.py create mode 100644 backend/tests/support/detectors/__init__.py create mode 100644 backend/tests/support/detectors/blocking_io.py create mode 100644 backend/tests/test_blocking_io_detector.py create mode 100644 backend/tests/test_blocking_io_probe_integration.py diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index a357a3962..9bc8d4884 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import issues when unit-testing lightweight config/registry code in isolation. """ +from __future__ import annotations + import importlib.util import sys from pathlib import Path @@ -11,11 +13,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io # Make 'app' and 'deerflow' importable from any working directory sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts")) +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT) +_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector" + # Break the circular import chain that exists in production code: # deerflow.subagents.__init__ # -> .executor (SubagentExecutor, SubagentResult) @@ -56,6 +63,92 @@ def provisioner_module(): return module +@pytest.fixture() +def blocking_io_detector(): + """Fail a focused test if blocking calls run on the event loop thread.""" + with detect_blocking_io(fail_on_exit=True) as detector: + yield detector + + +def pytest_addoption(parser: pytest.Parser) -> None: + group = parser.getgroup("blocking-io") + group.addoption( + "--detect-blocking-io", + action="store_true", + default=False, + help="Collect blocking calls made while an asyncio event loop is running and report a summary.", + ) + group.addoption( + "--detect-blocking-io-fail", + action="store_true", + default=False, + help="Set a failing exit status when --detect-blocking-io records violations.", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe") + + +def pytest_sessionstart(session: pytest.Session) -> None: + if _blocking_io_probe_enabled(session.config): + _blocking_io_probe.clear() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_call(item: pytest.Item): + if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item): + yield + return + + detector = detect_blocking_io(fail_on_exit=False, stack_limit=18) + detector.__enter__() + setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector) + yield + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item: pytest.Item): + yield + + detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None) + if detector is None: + return + + try: + detector.__exit__(None, None, None) + _blocking_io_probe.record(item.nodeid, detector.violations) + finally: + delattr(item, _BLOCKING_IO_DETECTOR_ATTR) + + +def pytest_sessionfinish(session: pytest.Session) -> None: + if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK: + session.exitstatus = pytest.ExitCode.TESTS_FAILED + + +def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None: + if not _blocking_io_probe_enabled(terminalreporter.config): + return + + header, *details = _blocking_io_probe.format_summary().splitlines() + terminalreporter.write_sep("=", header) + for line in details: + terminalreporter.write_line(line) + + +def _blocking_io_probe_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_fail_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_probe_skipped(item: pytest.Item) -> bool: + return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None + + # --------------------------------------------------------------------------- # Auto-set user context for every test unless marked no_auto_user # --------------------------------------------------------------------------- diff --git a/backend/tests/support/__init__.py b/backend/tests/support/__init__.py new file mode 100644 index 000000000..38361eaf5 --- /dev/null +++ b/backend/tests/support/__init__.py @@ -0,0 +1 @@ +"""Shared test support helpers.""" diff --git a/backend/tests/support/detectors/__init__.py b/backend/tests/support/detectors/__init__.py new file mode 100644 index 000000000..cf9568cb6 --- /dev/null +++ b/backend/tests/support/detectors/__init__.py @@ -0,0 +1 @@ +"""Runtime and static detectors used by tests.""" diff --git a/backend/tests/support/detectors/blocking_io.py b/backend/tests/support/detectors/blocking_io.py new file mode 100644 index 000000000..c1adfd55a --- /dev/null +++ b/backend/tests/support/detectors/blocking_io.py @@ -0,0 +1,287 @@ +"""Test helper for detecting blocking calls on an asyncio event loop. + +The detector is intentionally test-only. It monkeypatches a small set of +well-known blocking entry points and their already-loaded module-level aliases, +then records calls only when they happen on a thread that is currently running +an asyncio event loop. Aliases captured in closures or default arguments remain +out of scope. +""" + +from __future__ import annotations + +import asyncio +import importlib +import sys +import traceback +from collections import Counter +from collections.abc import Callable, Iterable, Iterator +from contextlib import AbstractContextManager +from dataclasses import dataclass +from functools import wraps +from pathlib import Path +from types import TracebackType +from typing import Any + +BlockingCallable = Callable[..., Any] + + +@dataclass(frozen=True) +class BlockingCallSpec: + """Describes one blocking callable to wrap during a detector run.""" + + name: str + target: str + record_on_iteration: bool = False + + +@dataclass(frozen=True) +class BlockingCall: + """One blocking call observed on an asyncio event loop thread.""" + + name: str + target: str + stack: tuple[traceback.FrameSummary, ...] + + +DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = ( + BlockingCallSpec("time.sleep", "time:sleep"), + BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"), + BlockingCallSpec("httpx.Client.request", "httpx:Client.request"), + BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True), + BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"), + BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"), + BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"), +) + + +def _is_event_loop_thread() -> bool: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return False + return loop.is_running() + + +def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]: + module_name, attr_path = target.split(":", maxsplit=1) + owner: object = importlib.import_module(module_name) + parts = attr_path.split(".") + for part in parts[:-1]: + owner = getattr(owner, part) + + attr_name = parts[-1] + original = getattr(owner, attr_name) + return owner, attr_name, original + + +def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]: + return tuple(frame for frame in stack if frame.filename != __file__) + + +class BlockingIODetector(AbstractContextManager["BlockingIODetector"]): + """Record blocking calls made from async runtime code. + + By default the detector reports violations but does not fail on context + exit. Tests can set ``fail_on_exit=True`` or call + ``assert_no_blocking_calls()`` explicitly. + """ + + def __init__( + self, + specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS, + *, + fail_on_exit: bool = False, + patch_loaded_aliases: bool = True, + stack_limit: int = 12, + ) -> None: + self._specs = tuple(specs) + self._fail_on_exit = fail_on_exit + self._patch_loaded_aliases_enabled = patch_loaded_aliases + self._stack_limit = stack_limit + self._patches: list[tuple[object, str, BlockingCallable]] = [] + self._patch_keys: set[tuple[int, str]] = set() + self.violations: list[BlockingCall] = [] + self._active = False + + def __enter__(self) -> BlockingIODetector: + try: + self._active = True + alias_replacements: dict[int, BlockingCallable] = {} + for spec in self._specs: + owner, attr_name, original = _resolve_target(spec.target) + wrapper = self._wrap(spec, original) + self._patch_attribute(owner, attr_name, original, wrapper) + alias_replacements[id(original)] = wrapper + + if self._patch_loaded_aliases_enabled: + self._patch_loaded_module_aliases(alias_replacements) + except Exception: + self._restore() + self._active = False + raise + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback_value: TracebackType | None, + ) -> bool | None: + self._restore() + self._active = False + if exc_type is None and self._fail_on_exit: + self.assert_no_blocking_calls() + return None + + def _restore(self) -> None: + for owner, attr_name, original in reversed(self._patches): + setattr(owner, attr_name, original) + self._patches.clear() + self._patch_keys.clear() + + def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None: + key = (id(owner), attr_name) + if key in self._patch_keys: + return + setattr(owner, attr_name, replacement) + self._patches.append((owner, attr_name, original)) + self._patch_keys.add(key) + + def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None: + for module in tuple(sys.modules.values()): + namespace = getattr(module, "__dict__", None) + if not isinstance(namespace, dict): + continue + + for attr_name, value in tuple(namespace.items()): + replacement = replacements_by_id.get(id(value)) + if replacement is not None: + self._patch_attribute(module, attr_name, value, replacement) + + def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable: + @wraps(original) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if spec.record_on_iteration: + result = original(*args, **kwargs) + return self._wrap_iteration(spec, result) + self._record_if_blocking(spec) + return original(*args, **kwargs) + + return wrapper + + def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]: + iterator = iter(iterable) + reported = False + + while True: + if not reported: + reported = self._record_if_blocking(spec) + try: + yield next(iterator) + except StopIteration: + return + + def _record_if_blocking(self, spec: BlockingCallSpec) -> bool: + if self._active and _is_event_loop_thread(): + stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit)) + self.violations.append(BlockingCall(spec.name, spec.target, stack)) + return True + return False + + def assert_no_blocking_calls(self) -> None: + if self.violations: + raise AssertionError(format_blocking_calls(self.violations)) + + +class BlockingIOProbe: + """Collect detector output across tests and format a compact summary.""" + + def __init__(self, project_root: Path) -> None: + self._project_root = project_root.resolve() + self._observed: list[tuple[str, BlockingCall]] = [] + + @property + def violation_count(self) -> int: + return len(self._observed) + + @property + def test_count(self) -> int: + return len({nodeid for nodeid, _violation in self._observed}) + + def clear(self) -> None: + self._observed.clear() + + def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None: + for violation in violations: + self._observed.append((nodeid, violation)) + + def format_summary(self, *, limit: int = 30) -> str: + if not self._observed: + return "blocking io probe: no violations" + + call_sites: Counter[tuple[str, str, int, str, str]] = Counter() + for _nodeid, violation in self._observed: + frame = self._local_call_site(violation.stack) + if frame is None: + call_sites[(violation.name, "", 0, "", "")] += 1 + continue + + call_sites[ + ( + violation.name, + self._relative(frame.filename), + frame.lineno, + frame.name, + (frame.line or "").strip(), + ) + ] += 1 + + lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"] + for (name, filename, lineno, function, line), count in call_sites.most_common(limit): + lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}") + return "\n".join(lines) + + def _relative(self, filename: str) -> str: + try: + return str(Path(filename).resolve().relative_to(self._project_root)) + except ValueError: + return filename + + def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None: + local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")] + if local_frames: + return local_frames[-1] + + test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename] + return test_frames[-1] if test_frames else None + + +def detect_blocking_io( + specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS, + *, + fail_on_exit: bool = False, + patch_loaded_aliases: bool = True, + stack_limit: int = 12, +) -> BlockingIODetector: + """Create a detector context manager for a focused test scope.""" + + return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit) + + +def format_blocking_calls(violations: Iterable[BlockingCall]) -> str: + """Format detector output with enough stack context to locate call sites.""" + + lines = ["Blocking calls were executed on an asyncio event loop thread:"] + for index, violation in enumerate(violations, start=1): + lines.append(f"{index}. {violation.name} ({violation.target})") + lines.extend(_format_stack(violation.stack)) + return "\n".join(lines) + + +def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]: + for frame in stack: + location = f"{frame.filename}:{frame.lineno}" + lines = [f" at {frame.name} ({location})"] + if frame.line: + lines.append(f" {frame.line.strip()}") + yield from lines diff --git a/backend/tests/test_blocking_io_detector.py b/backend/tests/test_blocking_io_detector.py new file mode 100644 index 000000000..af44d746d --- /dev/null +++ b/backend/tests/test_blocking_io_detector.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio +import os +import time +from os import walk as imported_walk +from pathlib import Path +from time import sleep as imported_sleep + +import httpx +import pytest +import requests +from support.detectors.blocking_io import ( + BlockingCallSpec, + BlockingIOProbe, + detect_blocking_io, +) + +pytestmark = pytest.mark.asyncio + + +TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),) +REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),) +HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),) +OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),) +PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),) + + +async def test_records_time_sleep_on_event_loop() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_records_already_imported_sleep_alias_on_event_loop() -> None: + original_alias = imported_sleep + + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + imported_sleep(0) + + assert imported_sleep is original_alias + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_can_disable_loaded_alias_patching() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector: + imported_sleep(0) + + assert detector.violations == [] + + +async def test_does_not_record_time_sleep_offloaded_to_thread() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + await asyncio.to_thread(time.sleep, 0) + + assert detector.violations == [] + + +async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None: + await asyncio.to_thread(time.sleep, 0) + + assert blocking_io_detector.violations == [] + + +async def test_does_not_record_sync_call_without_running_event_loop() -> None: + def call_sleep() -> list[str]: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + return [violation.name for violation in detector.violations] + + assert await asyncio.to_thread(call_sleep) == [] + + +async def test_fail_on_exit_includes_call_site() -> None: + with pytest.raises(AssertionError) as exc_info: + with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True): + time.sleep(0) + + message = str(exc_info.value) + assert "time.sleep" in message + assert "test_fail_on_exit_includes_call_site" in message + + +async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str: + return f"{method}:{url}" + + monkeypatch.setattr(requests.sessions.Session, "request", fake_request) + + with detect_blocking_io(REQUESTS_ONLY) as detector: + assert requests.get("https://example.invalid") == "get:https://example.invalid" + + assert [violation.name for violation in detector.violations] == ["requests.Session.request"] + + +async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response: + return httpx.Response(200, request=httpx.Request(method, url)) + + monkeypatch.setattr(httpx.Client, "request", fake_request) + + with detect_blocking_io(HTTPX_ONLY) as detector: + with httpx.Client() as client: + response = client.get("https://example.invalid") + + assert response.status_code == 200 + assert [violation.name for violation in detector.violations] == ["httpx.Client.request"] + + +async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(os.walk(tmp_path)) + + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + original_alias = imported_walk + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(imported_walk(tmp_path)) + + assert imported_walk is original_alias + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None: + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + + assert list(walker) + assert detector.violations == [] + + +async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + assert await asyncio.to_thread(lambda: list(walker)) + + assert detector.violations == [] + + +async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None: + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector: + assert path.read_text(encoding="utf-8") == "content" + + assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"] + + +async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + summary = probe.format_summary() + + assert "blocking io probe: 1 violations across 1 tests" in summary + assert "pathlib.Path.read_text" in summary + + +async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + + assert probe.format_summary() == "blocking io probe: no violations" + + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + assert probe.violation_count == 1 + + probe.clear() + + assert probe.violation_count == 0 + assert probe.format_summary() == "blocking io probe: no violations" diff --git a/backend/tests/test_blocking_io_probe_integration.py b/backend/tests/test_blocking_io_probe_integration.py new file mode 100644 index 000000000..af7a31b9d --- /dev/null +++ b/backend/tests/test_blocking_io_probe_integration.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import time + +import pytest + +ORIGINAL_SLEEP = time.sleep + + +def replacement_sleep(seconds: float) -> None: + return None + + +def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(time, "sleep", replacement_sleep) + assert time.sleep is replacement_sleep + + +@pytest.mark.no_blocking_io_probe +def test_probe_restores_original_after_monkeypatch_teardown() -> None: + assert time.sleep is ORIGINAL_SLEEP + assert getattr(time.sleep, "__wrapped__", None) is None From ba864112a3b5e9029d6fe3f46ecb0abb1582d118 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 11:02:58 +0800 Subject: [PATCH 10/12] chore(deps): bump langsmith from 0.7.36 to 0.8.0 in /backend (#2943) Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.7.36 to 0.8.0. - [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases) - [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.7.36...v0.8.0) --- updated-dependencies: - dependency-name: langsmith dependency-version: 0.8.0 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- backend/uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/uv.lock b/backend/uv.lock index e144fb07e..cd6bc8543 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -2005,7 +2005,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.7.36" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -2018,9 +2018,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" }, + { url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" }, ] [package.optional-dependencies] From 722c690f4fc734c5057b6ed250e3ff6b93168313 Mon Sep 17 00:00:00 2001 From: LawranceLiao <32213920+kibabsquirrel@users.noreply.github.com> Date: Fri, 15 May 2026 10:26:35 +0800 Subject: [PATCH 11/12] fix(memory): isolate queued memory updates by agent (#2941) * fix(memory): isolate queued memory updates by agent * fix(memory): include user in queue identity * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Fix the lint error --------- Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../harness/deerflow/agents/memory/queue.py | 14 +++- .../agents/memory/summarization_hook.py | 3 + backend/tests/test_memory_queue.py | 84 ++++++++++++++++++- .../tests/test_memory_queue_user_isolation.py | 39 +++++++++ .../tests/test_summarization_middleware.py | 27 +++++- 5 files changed, 163 insertions(+), 4 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index b2a147bce..129a28c66 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -40,6 +40,15 @@ class MemoryUpdateQueue: self._timer: threading.Timer | None = None self._processing = False + @staticmethod + def _queue_key( + thread_id: str, + user_id: str | None, + agent_name: str | None, + ) -> tuple[str, str | None, str | None]: + """Return the debounce identity for a memory update target.""" + return (thread_id, user_id, agent_name) + def add( self, thread_id: str, @@ -115,8 +124,9 @@ class MemoryUpdateQueue: correction_detected: bool, reinforcement_detected: bool, ) -> None: + queue_key = self._queue_key(thread_id, user_id, agent_name) existing_context = next( - (context for context in self._queue if context.thread_id == thread_id), + (context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key), None, ) merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) @@ -130,7 +140,7 @@ class MemoryUpdateQueue: reinforcement_detected=merged_reinforcement_detected, ) - self._queue = [c for c in self._queue if c.thread_id != thread_id] + self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key] self._queue.append(context) def _reset_timer(self) -> None: diff --git a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py index dafa7d977..307548e0a 100644 --- a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py +++ b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py @@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_ from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent from deerflow.config.memory_config import get_memory_config +from deerflow.runtime.user_context import resolve_runtime_user_id def memory_flush_hook(event: SummarizationEvent) -> None: @@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None: correction_detected = detect_correction(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) + user_id = resolve_runtime_user_id(event.runtime) queue = get_memory_queue() queue.add_nowait( thread_id=event.thread_id, messages=filtered_messages, agent_name=event.agent_name, + user_id=user_id, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, ) diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 27808b0e8..3d62f0497 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -1,6 +1,6 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue from deerflow.config.memory_config import MemoryConfig @@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None: assert elapsed < 0.1 assert finished.is_set() is False assert finished.wait(1.0) is True + + +def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a") + queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b") + + assert queue.pending_count == 2 + assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"] + + +def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add( + thread_id="thread-1", + messages=["first"], + agent_name="agent-a", + correction_detected=True, + ) + queue.add( + thread_id="thread-1", + messages=["second"], + agent_name="agent-a", + correction_detected=False, + ) + + assert queue.pending_count == 1 + assert queue._queue[0].agent_name == "agent-a" + assert queue._queue[0].messages == ["second"] + assert queue._queue[0].correction_detected is True + + +def test_process_queue_updates_different_agents_in_same_thread_separately() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a") + queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b") + + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + + with ( + patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater), + patch("deerflow.agents.memory.queue.time.sleep"), + ): + queue.flush() + + assert mock_updater.update_memory.call_count == 2 + mock_updater.update_memory.assert_has_calls( + [ + call( + messages=["agent-a"], + thread_id="thread-1", + agent_name="agent-a", + correction_detected=False, + reinforcement_detected=False, + user_id=None, + ), + call( + messages=["agent-b"], + thread_id="thread-1", + agent_name="agent-b", + correction_detected=False, + reinforcement_detected=False, + user_id=None, + ), + ] + ) diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py index 79250817c..ce5d41210 100644 --- a/backend/tests/test_memory_queue_user_isolation.py +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -38,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater(): mock_updater.update_memory.assert_called_once() call_kwargs = mock_updater.update_memory.call_args.kwargs assert call_kwargs["user_id"] == "alice" + + +def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent(): + q = MemoryUpdateQueue() + + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): + q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice") + q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob") + + assert q.pending_count == 2 + assert [context.user_id for context in q._queue] == ["alice", "bob"] + assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]] + + +def test_queue_still_coalesces_updates_for_same_user_thread_and_agent(): + q = MemoryUpdateQueue() + + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): + q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice") + q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice") + + assert q.pending_count == 1 + assert q._queue[0].messages == ["second"] + assert q._queue[0].user_id == "alice" + assert q._queue[0].agent_name == "researcher" + + +def test_add_nowait_keeps_different_users_separate(): + q = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), + patch.object(q, "_schedule_timer"), + ): + q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice") + q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob") + + assert q.pending_count == 2 + assert [context.user_id for context in q._queue] == ["alice", "bob"] diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py index cbd94e434..9cd4fc725 100644 --- a/backend/tests/test_summarization_middleware.py +++ b/backend/tests/test_summarization_middleware.py @@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage: ) -def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace: +def _runtime( + thread_id: str | None = "thread-1", + agent_name: str | None = None, + user_id: str | None = None, +) -> SimpleNamespace: context = {} if thread_id is not None: context["thread_id"] = thread_id if agent_name is not None: context["agent_name"] = agent_name + if user_id is not None: + context["user_id"] = user_id return SimpleNamespace(context=context) @@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon queue.add_nowait.assert_called_once() assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent" + + +def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + SummarizationEvent( + messages_to_summarize=tuple(_messages()[:2]), + preserved_messages=(), + thread_id="main", + agent_name="researcher", + runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"), + ) + ) + + queue.add_nowait.assert_called_once() + assert queue.add_nowait.call_args.kwargs["user_id"] == "alice" From 45060a9ffcfbda8f0dc0427ae09f518e496f4f33 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 15 May 2026 04:32:09 +0200 Subject: [PATCH 12/12] fix(runtime): avoid postgres aggregate row lock (#2962) --- .../deerflow/runtime/events/store/db.py | 33 ++++++++++++++----- backend/tests/test_run_event_store.py | 33 +++++++++++++++++++ 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 9374769f3..b7e54754f 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -11,7 +11,7 @@ import logging from datetime import UTC, datetime from typing import Any -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow @@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore): user = get_current_user() return str(user.id) if user is not None else None + @staticmethod + async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None: + """Return the current max seq while serializing writers per thread. + + PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate + results are not lockable rows. As a release-safe workaround, take a + transaction-level advisory lock keyed by thread_id before reading the + aggregate. Other dialects keep the existing row-locking statement. + """ + stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id) + bind = session.get_bind() + dialect_name = bind.dialect.name if bind is not None else "" + + if dialect_name == "postgresql": + await session.execute( + text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"), + {"thread_id": thread_id}, + ) + return await session.scalar(stmt) + + return await session.scalar(stmt.with_for_update()) + async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 """Write a single event — low-frequency path only. @@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore): user_id = self._user_id_from_context() async with self._sf() as session: async with session.begin(): - # Use FOR UPDATE to serialize seq assignment within a thread. - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = (max_seq or 0) + 1 row = RunEventRow( thread_id=thread_id, @@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore): async with self._sf() as session: async with session.begin(): # Get max seq for the thread (assume all events in batch belong to same thread). - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. thread_id = events[0]["thread_id"] - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = max_seq or 0 rows = [] for e in events: diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index d2c78ccf0..17b796af7 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -268,6 +268,39 @@ class TestEdgeCases: class TestDbRunEventStore: """Tests for DbRunEventStore with temp SQLite.""" + @pytest.mark.anyio + async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self): + from sqlalchemy.dialects import postgresql + + from deerflow.runtime.events.store.db import DbRunEventStore + + class FakeSession: + def __init__(self): + self.dialect = postgresql.dialect() + self.execute_calls = [] + self.scalar_stmt = None + + def get_bind(self): + return self + + async def execute(self, stmt, params=None): + self.execute_calls.append((stmt, params)) + + async def scalar(self, stmt): + self.scalar_stmt = stmt + return 41 + + session = FakeSession() + + max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1") + + assert max_seq == 41 + assert session.execute_calls + assert session.execute_calls[0][1] == {"thread_id": "thread-1"} + assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0]) + compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect())) + assert "FOR UPDATE" not in compiled + @pytest.mark.anyio async def test_basic_crud(self, tmp_path): from deerflow.persistence.engine import close_engine, get_session_factory, init_engine