Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] dad3997459 fix(sandbox): cleanup dead containers and avoid lock-held liveness checks
Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/96707445-0f8b-4901-8ef3-d8e5667f8a05

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>
2026-05-11 00:09:09 +00:00
Willem Jiang b67c2a4e56 fix(sandbox): auto-restart crashed containers transparently (#2788)
When a sandbox container crashes (e.g. due to an internal error), the
  agent enters a connection-refused loop because AioSandboxProvider.get()
  returns a cached but dead sandbox object. Add a liveness check in get()
  that detects crashed containers via backend.is_alive() and evicts them
  from all caches, allowing ensure_sandbox_initialized() to transparently
  recreate a fresh container on the next acquire().

  The behavior is controlled by a new  config option
  (default: true). Set to false to skip health checks and preserve the
  old behavior of returning stale cached sandboxes.

  Closes #2788
2026-05-10 22:53:58 +08:00
165 changed files with 1148 additions and 10912 deletions
+2 -3
View File
@@ -9,9 +9,8 @@ JINA_API_KEY=your-jina-api-key
# InfoQuest API Key
INFOQUEST_API_KEY=your-infoquest-api-key
# Browser CORS allowlist for split-origin or port-forwarded deployments (comma-separated exact origins).
# Leave unset when using the unified nginx endpoint, e.g. http://localhost:2026.
# GATEWAY_CORS_ORIGINS=http://localhost:3000,http://127.0.0.1:3000
# CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001
# CORS_ORIGINS=http://localhost:3000
# Optional:
# FIRECRAWL_API_KEY=your-firecrawl-api-key
+19 -13
View File
@@ -46,12 +46,12 @@ Docker provides a consistent, isolated environment with all dependencies pre-con
All services will start with hot-reload enabled:
- Frontend changes are automatically reloaded
- Backend changes trigger automatic restart
- Gateway-hosted LangGraph-compatible runtime supports hot-reload
- LangGraph server supports hot-reload
4. **Access the application**:
- Web Interface: http://localhost:2026
- API Gateway: http://localhost:2026/api/*
- LangGraph-compatible API: http://localhost:2026/api/langgraph/*
- LangGraph: http://localhost:2026/api/langgraph/*
#### Docker Commands
@@ -94,7 +94,7 @@ Use these as practical starting points for development and review environments:
If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket:
```text
unable to get image 'deer-flow-gateway': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock
unable to get image 'deer-flow-dev-langgraph': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock
```
Recommended fix: add your current user to the `docker` group so Docker commands work without `sudo`.
@@ -131,8 +131,9 @@ Host Machine
Docker Compose (deer-flow-dev)
├→ nginx (port 2026) ← Reverse proxy
├→ web (port 3000) ← Frontend with hot-reload
├→ gateway (port 8001) ← Gateway API + LangGraph-compatible runtime with hot-reload
└→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode
├→ api (port 8001) ← Gateway API with hot-reload
├→ langgraph (port 2024) ← LangGraph server with hot-reload
└→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode
```
**Benefits of Docker Development**:
@@ -183,13 +184,17 @@ Required tools:
If you need to start services individually:
1. **Start backend service**:
1. **Start backend services**:
```bash
# Terminal 1: Start Gateway API + embedded agent runtime (port 8001)
# Terminal 1: Start LangGraph Server (port 2024)
cd backend
make dev
# Terminal 2: Start Frontend (port 3000)
# Terminal 2: Start Gateway API (port 8001)
cd backend
make gateway
# Terminal 3: Start Frontend (port 3000)
cd frontend
pnpm dev
```
@@ -207,10 +212,10 @@ If you need to start services individually:
The nginx configuration provides:
- Unified entry point on port 2026
- Rewrites `/api/langgraph/*` to Gateway's LangGraph-compatible API (8001)
- Routes `/api/langgraph/*` to LangGraph Server (2024)
- Routes other `/api/*` endpoints to Gateway API (8001)
- Routes non-API requests to Frontend (3000)
- Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist
- Centralized CORS handling
- SSE/streaming support for real-time agent responses
- Optimized timeouts for long-running operations
@@ -230,8 +235,8 @@ deer-flow/
│ └── nginx.local.conf # Nginx config for local dev
├── backend/ # Backend application
│ ├── src/
│ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001)
│ │ ├── agents/ # LangGraph agent runtime used by Gateway
│ │ ├── gateway/ # Gateway API (port 8001)
│ │ ├── agents/ # LangGraph agents (port 2024)
│ │ ├── mcp/ # Model Context Protocol integration
│ │ ├── skills/ # Skills system
│ │ └── sandbox/ # Sandbox execution
@@ -251,7 +256,8 @@ Browser
Nginx (port 2026) ← Unified entry point
├→ Frontend (port 3000) ← / (non-API requests)
→ Gateway API (port 8001) ← /api/* and /api/langgraph/* (LangGraph-compatible agent interactions)
→ Gateway API (port 8001) ← /api/models, /api/mcp, /api/skills, /api/threads/*/artifacts
└→ LangGraph Server (port 2024) ← /api/langgraph/* (agent interactions)
```
## Development Workflow
+1 -3
View File
@@ -245,8 +245,6 @@ make down # Stop and remove containers
Access: http://localhost:2026
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
#### Option 2: Local Development
@@ -628,7 +626,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
+3 -3
View File
@@ -228,7 +228,7 @@ make down # Stop and remove containers
```
> [!NOTE]
> Le runtime d'agent s'exécute actuellement dans la Gateway. nginx réécrit `/api/langgraph/*` vers l'API compatible LangGraph servie par la Gateway.
> Le serveur d'agents LangGraph fonctionne actuellement via `langgraph dev` (le serveur CLI open source).
Accès : http://localhost:2026
@@ -296,8 +296,8 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca
```yaml
channels:
# LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api)
langgraph_url: http://localhost:8001/api
# LangGraph Server URL (default: http://localhost:2024)
langgraph_url: http://localhost:2024
# Gateway API URL (default: http://localhost:8001)
gateway_url: http://localhost:8001
+3 -3
View File
@@ -181,7 +181,7 @@ make down # コンテナを停止して削除
```
> [!NOTE]
> Agentランタイムは現在Gateway内で実行されます。`/api/langgraph/*`はnginxによってGatewayのLangGraph-compatible APIへ書き換えられます。
> LangGraphエージェントサーバーは現在`langgraph dev`(オープンソースCLIサーバー)経由で実行されます。
アクセス: http://localhost:2026
@@ -249,8 +249,8 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート
```yaml
channels:
# LangGraph-compatible Gateway API base URL(デフォルト: http://localhost:8001/api
langgraph_url: http://localhost:8001/api
# LangGraphサーバーURL(デフォルト: http://localhost:2024
langgraph_url: http://localhost:2024
# Gateway API URL(デフォルト: http://localhost:8001
gateway_url: http://localhost:8001
+3 -3
View File
@@ -184,7 +184,7 @@ make down # 停止并移除容器
```
> [!NOTE]
> 当前 Agent 运行时嵌入在 Gateway 中运行,`/api/langgraph/*` 会由 nginx 重写到 Gateway 的 LangGraph-compatible API
> 当前 LangGraph agent server 通过开源 CLI 服务 `langgraph dev` 运行
访问地址:http://localhost:2026
@@ -254,8 +254,8 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应
```yaml
channels:
# LangGraph-compatible Gateway API base URL(默认:http://localhost:8001/api
langgraph_url: http://localhost:8001/api
# LangGraph Server URL(默认:http://localhost:2024
langgraph_url: http://localhost:2024
# Gateway API URL(默认:http://localhost:8001
gateway_url: http://localhost:8001
+3 -5
View File
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
@@ -207,8 +207,6 @@ Configuration priority:
FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWAY_ENABLE_DOCS=false` to disable `/docs`, `/redoc`, and `/openapi.json` in production (default: enabled).
CORS is same-origin by default when requests enter through nginx on port 2026. Split-origin or port-forwarded browser clients must opt in with `GATEWAY_CORS_ORIGINS` (comma-separated exact origins); Gateway `CORSMiddleware` and `CSRFMiddleware` both read that variable so browser CORS and auth-origin checks stay aligned.
**Routers**:
| Router | Endpoints |
@@ -225,7 +223,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
Proxied through nginx: `/api/langgraph/*` Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
### Sandbox System (`packages/harness/deerflow/sandbox/`)
@@ -245,7 +243,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
- `bash` - Execute commands with path translation and error handling
- `ls` - Directory listing (tree format, max 2 levels)
- `read_file` - Read file contents with optional line range
- `write_file` - Write/append to files, creates directories; overwrites by default and exposes the `append` argument in the model-facing schema for end-of-file writes
- `write_file` - Write/append to files, creates directories
- `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process
### Subagent System (`packages/harness/deerflow/subagents/`)
+4 -1
View File
@@ -56,8 +56,11 @@ export OPENAI_API_KEY="your-api-key"
### Run the Development Server
```bash
# Gateway API + embedded agent runtime
# Terminal 1: LangGraph server
make dev
# Terminal 2: Gateway API
make gateway
```
## Project Structure
+32 -28
View File
@@ -11,26 +11,31 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent
│ Nginx (Port 2026) │
│ Unified reverse proxy │
└───────┬──────────────────┬───────────┘
/api/langgraph/* │ /api/* (other)
rewritten to /api/* │
┌────────────────────────────────────────┐
Gateway API (8001)
FastAPI REST + agent runtime
Models, MCP, Skills, Memory, Uploads, │
Artifacts, Threads, Runs, Streaming
┌────────────────────────────────────┐
│ │ Lead Agent │ │
│ │ Middleware Chain, Tools, Subagents │ │
└────────────────────────────────────┘
└────────────────────────────────────────
/api/langgraph/* │ /api/* (other)
▼ ▼
┌────────────────────┐ ┌────────────────────────┐
│ LangGraph Server │ │ Gateway API (8001) │
(Port 2024) │ FastAPI REST
│ │
┌────────────────┐ │ │ Models, MCP, Skills,
│ Lead Agent │ │ │ Memory, Uploads,
│ ┌──────────┐ │ │ │ Artifacts
│ │Middleware│ │ │ └────────────────────────┘
│ │ Chain │ │
│ │ └──────────┘ │ │
│ │ ┌──────────┐ │ │
│ │ Tools │ │
│ │ └──────────┘ │ │
│ │ ┌──────────┐ │ │
│ │ │Subagents │ │ │
│ │ └──────────┘ │ │
│ └────────────────┘ │
└────────────────────┘
```
**Request Routing** (via Nginx):
- `/api/langgraph/*` Gateway LangGraph-compatible API - agent interactions, threads, streaming
- `/api/langgraph/*` → LangGraph Server - agent interactions, threads, streaming
- `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup
- `/` (non-API) → Frontend - Next.js web interface
@@ -74,7 +79,7 @@ Per-thread isolated execution with virtual path translation:
- **Skills path**: `/mnt/skills``deer-flow/skills/` directory
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
- **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`write_file` overwrites by default and exposes `append` for end-of-file writes; `bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
### Subagent System
@@ -188,7 +193,7 @@ export OPENAI_API_KEY="your-api-key-here"
**Full Application** (from project root):
```bash
make dev # Starts Gateway + Frontend + Nginx
make dev # Starts LangGraph + Gateway + Frontend + Nginx
```
Access at: http://localhost:2026
@@ -196,11 +201,14 @@ Access at: http://localhost:2026
**Backend Only** (from backend directory):
```bash
# Gateway API + embedded agent runtime
# Terminal 1: LangGraph server
make dev
# Terminal 2: Gateway API
make gateway
```
Direct access: Gateway at http://localhost:8001
Direct access: LangGraph at http://localhost:2024, Gateway at http://localhost:8001
---
@@ -236,16 +244,12 @@ backend/
│ └── utils/ # Utilities
├── docs/ # Documentation
├── tests/ # Test suite
├── langgraph.json # LangGraph graph registry for tooling/Studio compatibility
├── langgraph.json # LangGraph server configuration
├── pyproject.toml # Python dependencies
├── Makefile # Development commands
└── Dockerfile # Container build
```
`langgraph.json` is not the default service entrypoint. The scripts and Docker
deployments run the Gateway embedded runtime; the file is kept for LangGraph
tooling, Studio, or direct LangGraph Server compatibility.
---
## Configuration
@@ -358,8 +362,8 @@ If a provider is explicitly enabled but required credentials are missing, or the
```bash
make install # Install dependencies
make dev # Run Gateway API + embedded agent runtime (port 8001)
make gateway # Run Gateway API without reload (port 8001)
make dev # Run LangGraph server (port 2024)
make gateway # Run Gateway API (port 8001)
make lint # Run linter (ruff)
make format # Format code (ruff)
```
+28 -24
View File
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
@@ -8,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
@@ -62,7 +63,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
Subsequent boots (admin already exists):
- Runs the one-time "no-auth → with-auth" orphan thread migration for
existing LangGraph thread metadata that has no user_id.
existing LangGraph thread metadata that has no owner_id.
No SQL persistence migration is needed: the four user_id columns
(threads_meta, runs, run_events, feedback) only come into existence
@@ -177,7 +178,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Check admin bootstrap state and migrate orphan threads after admin exists.
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
@@ -218,9 +219,7 @@ def create_app() -> FastAPI:
Configured FastAPI application instance.
"""
config = get_gateway_config()
docs_url = "/docs" if config.enable_docs else None
redoc_url = "/redoc" if config.enable_docs else None
openapi_url = "/openapi.json" if config.enable_docs else None
docs_kwargs = {"docs_url": "/docs", "redoc_url": "/redoc", "openapi_url": "/openapi.json"} if config.enable_docs else {"docs_url": None, "redoc_url": None, "openapi_url": None}
app = FastAPI(
title="DeerFlow API Gateway",
@@ -240,14 +239,12 @@ API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execu
### Architecture
LangGraph-compatible requests are routed through nginx to this gateway.
This gateway provides runtime endpoints for agent runs plus custom endpoints for models, MCP configuration, skills, and artifacts.
LangGraph requests are handled by nginx reverse proxy.
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
""",
version="0.1.0",
lifespan=lifespan,
docs_url=docs_url,
redoc_url=redoc_url,
openapi_url=openapi_url,
**docs_kwargs,
openapi_tags=[
{
"name": "models",
@@ -310,18 +307,25 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: the unified nginx endpoint is same-origin by default. Split-origin
# browser clients must opt in with this explicit Gateway allowlist so CORS
# and CSRF origin checks share the same source of truth.
cors_origins = sorted(get_configured_cors_origins())
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
# Models API is mounted at /api/models
@@ -370,7 +374,7 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
app.include_router(runs.router)
@app.get("/health", tags=["health"])
async def health_check() -> dict[str, str]:
async def health_check() -> dict:
"""Health check endpoint.
Returns:
+1 -1
View File
@@ -28,7 +28,7 @@ class User(BaseModel):
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
# Auth lifecycle
needs_setup: bool = Field(default=False, description="True when a reset account must complete setup")
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
+3
View File
@@ -8,6 +8,7 @@ class GatewayConfig(BaseModel):
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
port: int = Field(default=8001, description="Port to bind the gateway server")
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
enable_docs: bool = Field(default=True, description="Enable Swagger/ReDoc/OpenAPI endpoints")
@@ -18,9 +19,11 @@ def get_gateway_config() -> GatewayConfig:
"""Get gateway config, loading from environment if available."""
global _gateway_config
if _gateway_config is None:
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
_gateway_config = GatewayConfig(
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
port=int(os.getenv("GATEWAY_PORT", "8001")),
cors_origins=cors_origins_str.split(","),
enable_docs=os.getenv("GATEWAY_ENABLE_DOCS", "true").lower() == "true",
)
return _gateway_config
+2 -7
View File
@@ -6,7 +6,7 @@ State-changing operations require CSRF protection.
import os
import secrets
from collections.abc import Awaitable, Callable
from collections.abc import Callable
from urllib.parse import urlsplit
from fastapi import Request, Response
@@ -106,11 +106,6 @@ def _configured_cors_origins() -> set[str]:
return origins
def get_configured_cors_origins() -> set[str]:
"""Return normalized explicit browser origins from GATEWAY_CORS_ORIGINS."""
return _configured_cors_origins()
def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
@@ -177,7 +172,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
+4 -8
View File
@@ -1,12 +1,8 @@
"""LangGraph compatibility auth handler — shares JWT logic with Gateway.
"""LangGraph Server auth handler — shares JWT logic with Gateway.
The default DeerFlow runtime is embedded in the FastAPI Gateway; scripts and
Docker deployments do not load this module. It is retained for LangGraph
tooling, Studio, or direct LangGraph Server compatibility through
``langgraph.json``'s ``auth.path``.
When that compatibility path is used, this module reuses the same JWT and CSRF
rules as Gateway so both modes validate sessions consistently.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate — validates JWT cookie, extracts user_id,
+1 -1
View File
@@ -305,7 +305,7 @@ async def login_local(
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
The first admin is created explicitly through /initialize. This endpoint creates regular users.
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
+6 -32
View File
@@ -90,28 +90,6 @@ class ThreadSearchRequest(BaseModel):
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
@field_validator("metadata")
@classmethod
def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]:
"""Reject filter entries the SQL backend cannot compile.
Enforces consistent behaviour across SQL and memory backends.
See ``deerflow.persistence.json_compat`` for the shared validators.
"""
if not v:
return v
from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value
bad_entries: list[str] = []
for key, value in v.items():
if not validate_metadata_filter_key(key):
bad_entries.append(f"{key!r} (unsafe key)")
elif not validate_metadata_filter_value(value):
bad_entries.append(f"{key!r} (unsupported value type {type(value).__name__})")
if bad_entries:
raise ValueError(f"Invalid metadata filter entries: {', '.join(bad_entries)}")
return v
class ThreadStateResponse(BaseModel):
"""Response model for thread state."""
@@ -316,18 +294,14 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
"""
from app.gateway.deps import get_thread_store
from deerflow.persistence.thread_meta import InvalidMetadataFilterError
repo = get_thread_store(request)
try:
rows = await repo.search(
metadata=body.metadata or None,
status=body.status,
limit=body.limit,
offset=body.offset,
)
except InvalidMetadataFilterError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
rows = await repo.search(
metadata=body.metadata or None,
status=body.status,
limit=body.limit,
offset=body.offset,
)
return [
ThreadResponse(
thread_id=r["thread_id"],
-19
View File
@@ -19,7 +19,6 @@ from langchain_core.messages import HumanMessage
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
@@ -268,23 +267,6 @@ async def start_run(
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
body_context = getattr(body, "context", None) or {}
model_name = body_context.get("model_name")
# Coerce non-string model_name values to str before truncation.
if model_name is not None and not isinstance(model_name, str):
model_name = str(model_name)
# Validate model against the allowlist when a model_name is provided.
if model_name:
app_config = get_app_config()
resolved = app_config.get_model_config(model_name)
if resolved is None:
raise HTTPException(
status_code=400,
detail=f"Model {model_name!r} is not in the configured model allowlist",
)
try:
record = await run_mgr.create_or_reject(
thread_id,
@@ -293,7 +275,6 @@ async def start_run(
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
+35 -52
View File
@@ -6,16 +6,16 @@ This document provides a complete reference for the DeerFlow backend APIs.
DeerFlow backend exposes two sets of APIs:
1. **LangGraph-compatible API** - Agent interactions, threads, and streaming (`/api/langgraph/*`)
1. **LangGraph API** - Agent interactions, threads, and streaming (`/api/langgraph/*`)
2. **Gateway API** - Models, MCP, skills, uploads, and artifacts (`/api/*`)
All APIs are accessed through the Nginx reverse proxy at port 2026.
## LangGraph-compatible API
## LangGraph API
Base URL: `/api/langgraph`
The public LangGraph-compatible API follows LangGraph SDK conventions. In the unified nginx deployment, Gateway owns `/api/langgraph/*` and translates those paths to its native `/api/*` run, thread, and streaming routers.
The LangGraph API is provided by the LangGraph server and follows the LangGraph SDK conventions.
### Threads
@@ -104,11 +104,17 @@ Content-Type: application/json
**Recursion Limit:**
`config.recursion_limit` caps the number of graph steps LangGraph will execute
in a single run. The unified Gateway path defaults to `100` in
`build_run_config` (see `backend/app/gateway/services.py`), which is a safer
starting point for plan-mode or subagent-heavy runs. Clients can still set
`recursion_limit` explicitly in the request body; increase it if you run deeply
nested subagent graphs.
in a single run. The `/api/langgraph/*` endpoints go straight to the LangGraph
server and therefore inherit LangGraph's native default of **25**, which is
too low for plan-mode or subagent-heavy runs — the agent typically errors out
with `GraphRecursionError` after the first round of subagent results comes
back, before the lead agent can synthesize the final answer.
DeerFlow's own Gateway and IM-channel paths mitigate this by defaulting to
`100` in `build_run_config` (see `backend/app/gateway/services.py`), but
clients calling the LangGraph API directly must set `recursion_limit`
explicitly in the request body. `100` matches the Gateway default and is a
safe starting point; increase it if you run deeply nested subagent graphs.
**Configurable Options:**
- `model_name` (string): Override the default model
@@ -535,28 +541,14 @@ All APIs return errors in a consistent format:
## Authentication
DeerFlow enforces authentication for all non-public HTTP routes. Public routes are limited to health/docs metadata and these public auth endpoints:
Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials.
- `POST /api/v1/auth/initialize` creates the first admin account when no admin exists.
- `POST /api/v1/auth/login/local` logs in with email/password and sets an HttpOnly `access_token` cookie.
- `POST /api/v1/auth/register` creates a regular `user` account and sets the session cookie.
- `POST /api/v1/auth/logout` clears the session cookie.
- `GET /api/v1/auth/setup-status` reports whether the first admin still needs to be created.
Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers.
The authenticated auth endpoints are:
- `GET /api/v1/auth/me` returns the current user.
- `POST /api/v1/auth/change-password` changes password, optionally changes email during setup, increments `token_version`, and reissues the cookie.
Protected state-changing requests also require the CSRF double-submit token: send the `csrf_token` cookie value as the `X-CSRF-Token` header. Login/register/initialize/logout are bootstrap auth endpoints: they are exempt from the double-submit token but still reject hostile browser `Origin` headers.
User isolation is enforced from the authenticated user context:
- Thread metadata is scoped by `threads_meta.user_id`; search/read/write/delete APIs only expose the current user's threads.
- Thread files live under `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/` and are exposed inside the sandbox as `/mnt/user-data/`.
- Memory and custom agents are stored under `{base_dir}/users/{user_id}/...`.
Note: MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers; that is separate from DeerFlow API authentication.
For production deployments, it is recommended to:
1. Use Nginx for basic auth or OAuth integration
2. Deploy behind a VPN or private network
3. Implement custom authentication middleware
---
@@ -575,13 +567,12 @@ location /api/ {
---
## Streaming Support
## WebSocket Support
Gateway's LangGraph-compatible API streams run events with Server-Sent Events (SSE):
The LangGraph server supports WebSocket connections for real-time streaming. Connect to:
```http
POST /api/langgraph/threads/{thread_id}/runs/stream
Accept: text/event-stream
```
ws://localhost:2026/api/langgraph/threads/{thread_id}/runs/stream
```
---
@@ -617,21 +608,13 @@ const response = await fetch('/api/models');
const data = await response.json();
console.log(data.models);
// Create a run and stream SSE events
const streamResponse = await fetch(`/api/langgraph/threads/${threadId}/runs/stream`, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
},
body: JSON.stringify({
input: { messages: [{ role: "user", content: "Hello" }] },
stream_mode: ["values", "messages-tuple", "custom"],
}),
});
const reader = streamResponse.body?.getReader();
// Decode and parse SSE frames from reader in your client code.
// Using EventSource for streaming
const eventSource = new EventSource(
`/api/langgraph/threads/${threadId}/runs/stream`
);
eventSource.onmessage = (event) => {
console.log(JSON.parse(event.data));
};
```
### cURL Examples
@@ -666,7 +649,7 @@ curl -X POST http://localhost:2026/api/langgraph/threads/abc123/runs \
}'
```
> The unified Gateway path defaults `config.recursion_limit` to 100 for
> plan-mode and subagent-heavy runs. Clients may still set
> `config.recursion_limit` explicitly — see the [Create Run](#create-run)
> section for details.
> The `/api/langgraph/*` endpoints bypass DeerFlow's Gateway and inherit
> LangGraph's native `recursion_limit` default of 25, which is too low for
> plan-mode or subagent runs. Set `config.recursion_limit` explicitly — see
> the [Create Run](#create-run) section for details.
+29 -29
View File
@@ -14,28 +14,30 @@ This document provides a comprehensive overview of the DeerFlow backend architec
│ Nginx (Port 2026) │
│ Unified Reverse Proxy Entry Point │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ /api/langgraph/* → Gateway LangGraph-compatible runtime (8001) │ │
│ │ /api/* → Gateway REST APIs (8001) │ │
│ │ /api/langgraph/* → LangGraph Server (2024) │ │
│ │ /api/* → Gateway API (8001) │ │
│ │ /* → Frontend (3000) │ │
│ └────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────┬────────────────────────────────────────┘
┌──────────────────────────────────────────────┐
┌─────────────────────────────────────────────┐ ┌─────────────────────┐
Gateway API │ │ Frontend │
│ (Port 8001) │ │ (Port 3000) │
│ │ │
│ - LangGraph-compatible runs/threads API │ │ - Next.js App │
│ - Embedded Agent Runtime │ │ - React UI │
│ - SSE Streaming │ │ - Chat Interface │
│ - Checkpointing │ │ │
- Models, MCP, Skills, Uploads, Artifacts │ │ │
- Thread Cleanup │ │ │
└─────────────────────────────────────────────┘ └─────────────────────┘
┌──────────────────────────────────────────────┐
┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐
LangGraph Server │ │ Gateway API │ │ Frontend │
(Port 2024) │ │ (Port 8001) │ │ (Port 3000) │
│ │ │ │ │
│ - Agent Runtime │ │ - Models API │ │ - Next.js App │
│ - Thread Mgmt │ │ - MCP Config │ │ - React UI │
│ - SSE Streaming │ │ - Skills Mgmt │ │ - Chat Interface │
│ - Checkpointing │ │ - File Uploads │ │ │
│ │ - Thread Cleanup │ │ │
│ │ - Artifacts │ │ │
└─────────────────────┘ └─────────────────────┘ └─────────────────────┘
│ ┌─────────────────┘
│ │
▼ ▼
┌──────────────────────────────────────────────────────────────────────────┐
│ Shared Configuration │
│ ┌─────────────────────────┐ ┌────────────────────────────────────────┐ │
@@ -50,9 +52,9 @@ This document provides a comprehensive overview of the DeerFlow backend architec
## Component Details
### Gateway Embedded Agent Runtime
### LangGraph Server
The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for robust multi-agent workflow orchestration. Nginx rewrites `/api/langgraph/*` to Gateway's native `/api/*` routes, so the public API remains compatible with LangGraph SDK clients without running a separate LangGraph server.
The LangGraph server is the core agent runtime, built on LangGraph for robust multi-agent workflow orchestration.
**Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent`
@@ -63,8 +65,7 @@ The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for
- Tool execution orchestration
- SSE streaming for real-time responses
**Graph registry**: `langgraph.json` remains available for tooling, Studio, or direct LangGraph Server compatibility.
It is not the default service entrypoint; scripts and Docker deployments run the Gateway embedded runtime.
**Configuration**: `langgraph.json`
```json
{
@@ -77,13 +78,12 @@ It is not the default service entrypoint; scripts and Docker deployments run the
### Gateway API
FastAPI application providing REST endpoints plus the public LangGraph-compatible `/api/langgraph/*` runtime routes.
FastAPI application providing REST endpoints for non-agent operations.
**Entry Point**: `app/gateway/app.py`
**Routers**:
- `models.py` - `/api/models` - Model listing and details
- `thread_runs.py` / `runs.py` - `/api/threads/{id}/runs`, `/api/runs/*` - LangGraph-compatible runs and streaming
- `mcp.py` - `/api/mcp` - MCP server configuration
- `skills.py` - `/api/skills` - Skills management
- `uploads.py` - `/api/threads/{id}/uploads` - File upload
@@ -91,7 +91,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl
- `artifacts.py` - `/api/threads/{id}/artifacts` - Artifact serving
- `suggestions.py` - `/api/threads/{id}/suggestions` - Follow-up suggestion generation
The web conversation delete flow first deletes Gateway-managed thread state through the LangGraph-compatible route, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`.
The web conversation delete flow is now split across both backend surfaces: LangGraph handles `DELETE /api/langgraph/threads/{thread_id}` for thread state, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`.
### Agent Architecture
@@ -353,10 +353,10 @@ SKILL.md Format:
POST /api/langgraph/threads/{thread_id}/runs
{"input": {"messages": [{"role": "user", "content": "Hello"}]}}
2. Nginx → Gateway API (8001)
`/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes
2. Nginx → LangGraph Server (2024)
Proxied to LangGraph server
3. Gateway embedded runtime
3. LangGraph Server
a. Load/create thread state
b. Execute middleware chain:
- ThreadDataMiddleware: Set up paths
@@ -412,7 +412,7 @@ SKILL.md Format:
### Thread Cleanup Flow
```
1. Client deletes conversation via the LangGraph-compatible Gateway route
1. Client deletes conversation via LangGraph
DELETE /api/langgraph/threads/{thread_id}
2. Web UI follows up with Gateway cleanup
-331
View File
@@ -1,331 +0,0 @@
# 用户认证与隔离设计
本文档描述 DeerFlow 当前内置认证模块的设计,而不是历史 RFC。它覆盖浏览器登录、API 认证、CSRF、用户隔离、首次初始化、密码重置、内部调用和升级迁移。
## 设计目标
认证模块的核心目标是把 DeerFlow 从“本地单用户工具”提升为“可多用户部署的 agent runtime”,并让用户身份贯穿 HTTP API、LangGraph-compatible runtime、文件系统、memory、自定义 agent 和反馈数据。
设计约束:
- 默认强制认证:除健康检查、文档和 auth bootstrap 端点外,HTTP 路由都必须有有效 session。
- 服务端持有所有权:客户端 metadata 不能声明 `user_id``owner_id`
- 隔离默认开启:repository(仓储)、文件路径、memory、agent 配置默认按当前用户解析。
- 旧数据可升级:无认证版本留下的 thread 可以在 admin 存在后迁移到 admin。
- 密码不进日志:首次初始化由操作者设置密码;`reset_admin` 只写 0600 凭据文件。
非目标:
- 当前 OAuth 端点只是占位,尚未实现第三方登录。
- 当前用户角色只有 `admin``user`,尚未实现细粒度 RBAC。
- 当前登录限速是进程内字典,多 worker 下不是全局精确限速。
## 核心模型
```mermaid
graph TB
classDef actor fill:#D8CFC4,stroke:#6E6259,color:#2F2A26;
classDef api fill:#C9D7D2,stroke:#5D706A,color:#21302C;
classDef state fill:#D7D3E8,stroke:#6B6680,color:#29263A;
classDef data fill:#E5D2C4,stroke:#806A5B,color:#30251E;
Browser["Browser — access_token cookie and csrf_token cookie"]:::actor
AuthMiddleware["AuthMiddleware — strict session gate"]:::api
CSRFMiddleware["CSRFMiddleware — double-submit token and Origin check"]:::api
AuthRoutes["Auth routes — initialize login register logout me change-password"]:::api
UserContext["Current user ContextVar — request-scoped identity"]:::state
Repositories["Repositories — AUTO resolves user_id from context"]:::state
Files["Filesystem — users/{user_id}/threads/{thread_id}/user-data"]:::data
Memory["Memory and agents — users/{user_id}/memory.json and agents"]:::data
Browser --> AuthMiddleware
Browser --> CSRFMiddleware
AuthMiddleware --> AuthRoutes
AuthMiddleware --> UserContext
UserContext --> Repositories
UserContext --> Files
UserContext --> Memory
```
### 用户表
用户记录定义在 `app.gateway.auth.models.User`,持久化到 `users` 表。关键字段:
| 字段 | 语义 |
|---|---|
| `id` | 用户主键,JWT `sub` 使用该值 |
| `email` | 唯一登录名 |
| `password_hash` | bcrypt hashOAuth 用户可为空 |
| `system_role` | `admin``user` |
| `needs_setup` | reset 后要求用户完成邮箱 / 密码设置 |
| `token_version` | 改密码或 reset 时递增,用于废弃旧 JWT |
### 运行时身份
认证成功后,`AuthMiddleware` 把用户同时写入:
- `request.state.user`
- `request.state.auth`
- `deerflow.runtime.user_context``ContextVar`
`ContextVar` 是这里的核心边界。上层 Gateway 负责写入身份,下层 persistence / file path 只读取结构化的当前用户,不反向依赖 `app.gateway.auth` 具体类型。
可以把 repository 调用的用户参数理解成一个三态 ADT:
```scala
enum UserScope:
case AutoFromContext
case Explicit(userId: String)
case BypassForMigration
```
对应 Python 实现是 `AUTO | str | None`
- `AUTO`:从 `ContextVar` 解析当前用户;没有上下文则抛错。
- `str`:显式指定用户,主要用于测试或管理脚本。
- `None`:跳过用户过滤,只允许迁移脚本或 admin CLI 使用。
## 登录与初始化流程
### 首次初始化
首次启动时,如果没有 admin,服务不会自动创建账号,只记录日志提示访问 `/setup`
流程:
1. 用户访问 `/setup`
2. 前端调用 `GET /api/v1/auth/setup-status`
3. 如果返回 `{"needs_setup": true}`,前端展示创建 admin 表单。
4. 表单提交 `POST /api/v1/auth/initialize`
5. 服务端确认当前没有 admin,创建 `system_role="admin"``needs_setup=false` 的用户。
6. 服务端设置 `access_token` HttpOnly cookie,用户进入 workspace。
`/api/v1/auth/initialize` 只在没有 admin 时可用。并发初始化由数据库唯一约束兜底,失败方返回 409。
### 普通登录
`POST /api/v1/auth/login/local` 使用 `OAuth2PasswordRequestForm`
- `username` 是邮箱。
- `password` 是密码。
- 成功后签发 JWT,放入 `access_token` HttpOnly cookie。
- 响应体只返回 `expires_in``needs_setup`,不返回 token。
登录失败会按客户端 IP 计数。IP 解析只在 TCP peer 属于 `AUTH_TRUSTED_PROXIES` 时信任 `X-Real-IP`,不使用 `X-Forwarded-For`
### 注册
`POST /api/v1/auth/register` 创建普通 `user`,并自动登录。
当前实现允许在没有 admin 时注册普通用户,但 `setup-status` 仍会返回 `needs_setup=true`,因为 admin 仍不存在。这是当前产品策略边界:如果后续要求“必须先初始化 admin 才能注册普通用户”,需要在 `/register` 增加 admin-exists gate。
### 改密码与 reset setup
`POST /api/v1/auth/change-password` 需要当前密码和新密码:
- 校验当前密码。
- 更新 bcrypt hash。
- `token_version += 1`,使旧 JWT 立即失效。
- 重新签发 cookie。
- 如果 `needs_setup=true` 且传了 `new_email`,则更新邮箱并清除 `needs_setup`
`python -m app.gateway.auth.reset_admin` 会:
- 找到 admin 或指定邮箱用户。
- 生成随机密码。
- 更新密码 hash。
- `token_version += 1`
- 设置 `needs_setup=true`
- 写入 `.deer-flow/admin_initial_credentials.txt`,权限 `0600`
命令行只输出凭据文件路径,不输出明文密码。
## HTTP 认证边界
`AuthMiddleware` 是 fail-closed(默认拒绝)的全局认证门。
公开路径:
- `/health`
- `/docs`
- `/redoc`
- `/openapi.json`
- `/api/v1/auth/login/local`
- `/api/v1/auth/register`
- `/api/v1/auth/logout`
- `/api/v1/auth/setup-status`
- `/api/v1/auth/initialize`
其余路径都要求有效 `access_token` cookie。存在 cookie 但 JWT 无效、过期、用户不存在或 `token_version` 不匹配时,直接返回 401,而不是让请求穿透到业务路由。
路由级别的 owner check 由 `require_permission(..., owner_check=True)` 完成:
- 读类请求允许旧的未追踪 legacy thread 兼容读取。
- 写 / 删除类请求使用 `require_existing=True`,要求 thread row 存在且属于当前用户,避免删除后缺 row 导致其他用户误通过。
## CSRF 设计
DeerFlow 使用 Double Submit Cookie
- 服务端设置 `csrf_token` cookie。
- 前端 state-changing 请求发送同值 `X-CSRF-Token` header。
- 服务端用 `secrets.compare_digest` 比较 cookie/header。
需要 CSRF 的方法:
- `POST`
- `PUT`
- `DELETE`
- `PATCH`
auth bootstrap 端点(login/register/initialize/logout)不要求 double-submit token,因为首次调用时浏览器还没有 token;但这些端点会校验 browser `Origin`,拒绝 hostile Origin,避免 login CSRF / session fixation。
## 用户隔离
### Thread metadata
Thread metadata 存在 `threads_meta`,关键隔离字段是 `user_id`
创建 thread 时:
- 客户端传入的 `metadata.user_id``metadata.owner_id` 会被剥离。
- `ThreadMetaRepository.create(..., user_id=AUTO)``ContextVar` 解析真实用户。
- `/api/threads/search` 默认只返回当前用户的 thread。
读取 / 修改 / 删除时:
- `get()` 默认按当前用户过滤。
- `check_access()` 用于路由 owner check。
- 对其他用户的 thread 返回 404,避免泄露资源存在性。
### 文件系统
当前线程文件布局:
```text
{base_dir}/users/{user_id}/threads/{thread_id}/user-data/
├── workspace/
├── uploads/
└── outputs/
```
agent 在 sandbox 内看到统一虚拟路径:
```text
/mnt/user-data/workspace
/mnt/user-data/uploads
/mnt/user-data/outputs
```
`ThreadDataMiddleware` 使用 `get_effective_user_id()` 解析当前用户并生成线程路径。没有认证上下文时会落到 `default` 用户桶,主要用于内部调用、嵌入式 client 或无 HTTP 的本地执行路径。
### Memory
默认 memory 存储:
```text
{base_dir}/users/{user_id}/memory.json
{base_dir}/users/{user_id}/agents/{agent_name}/memory.json
```
有用户上下文时,空或相对 `memory.storage_path` 都使用上述 per-user 默认路径;只有绝对 `memory.storage_path` 会视为显式 opt-out(退出) per-user isolation,所有用户共享该路径。无用户上下文的 legacy 路径仍会把相对 `storage_path` 解析到 `Paths.base_dir` 下。
### 自定义 agent
用户自定义 agent 写入:
```text
{base_dir}/users/{user_id}/agents/{agent_name}/
├── config.yaml
├── SOUL.md
└── memory.json
```
旧布局 `{base_dir}/agents/{agent_name}/` 只作为只读兼容回退。更新或删除旧共享 agent 会要求先运行迁移脚本。
## 内部调用与 IM 渠道
IM channel worker 不是浏览器用户,不持有浏览器 cookie。它们通过 Gateway 内部认证:
- 请求带 `X-DeerFlow-Internal-Token`
- 同时带匹配的 CSRF cookie/header。
- 服务端识别为内部用户,`id="default"``system_role="internal"`
这意味着 channel 产生的数据默认进入 `default` 用户桶。这个选择适合“平台级 bot 身份”,但不是“每个 IM 用户单独隔离”。如果后续要做到外部 IM 用户隔离,需要把外部 platform user 映射到 DeerFlow user,并让 channel manager 设置对应的 scoped identity。
## LangGraph-compatible 认证
Gateway 内嵌 runtime 路径由 `AuthMiddleware``CSRFMiddleware` 保护。
仓库仍保留 `app.gateway.langgraph_auth`,用于 LangGraph Server 直连模式:
- `@auth.authenticate` 校验 JWT cookie、CSRF、用户存在性和 `token_version`
- `@auth.on` 在写入 metadata 时注入 `user_id`,并在读路径返回 `{"user_id": current_user}` 过滤条件。
这保证 Gateway 路由和 LangGraph-compatible 直连模式使用同一 JWT 语义。
## 升级与迁移
从无认证版本升级时,可能存在没有 `user_id` 的历史 thread。
当前策略:
1. 首次启动如果没有 admin,只提示访问 `/setup`,不迁移。
2. 操作者创建 admin。
3. 后续启动时,`_ensure_admin_user()` 找到 admin,并把 LangGraph store 中缺少 `metadata.user_id` 的 thread 迁移到 admin。
文件系统旧布局迁移由脚本处理:
```bash
cd backend
PYTHONPATH=. python scripts/migrate_user_isolation.py --dry-run
PYTHONPATH=. python scripts/migrate_user_isolation.py --user-id <target-user-id>
```
迁移脚本覆盖 legacy `memory.json``threads/``agents/` 到 per-user layout。
## 安全不变量
必须长期保持的不变量:
- JWT 只在 HttpOnly cookie 中传输,不出现在响应 JSON。
- 任何非 public HTTP 路由都不能只靠“cookie 存在”放行,必须严格验证 JWT。
- `token_version` 不匹配必须拒绝,保证改密码 / reset 后旧 session 失效。
- 客户端 metadata 中的 `user_id` / `owner_id` 必须剥离。
- repository 默认 `AUTO` 必须从当前用户上下文解析,不能静默退化成全局查询。
- 只有迁移脚本和 admin CLI 可以显式传 `user_id=None` 绕过隔离。
- 本地文件路径必须通过 `Paths` 和 sandbox path validation 解析,不能拼接未校验的用户输入。
- 捕获认证、迁移、后台任务异常必须记录日志;不能空 catch。
## 已知边界
| 边界 | 当前行为 | 后续方向 |
|---|---|---|
| 无 admin 时注册普通用户 | 允许注册普通 `user` | 如产品要求先初始化 admin,给 `/register` 加 gate |
| 登录限速 | 进程内 dict,单 worker 精确,多 worker 近似 | Redis / DB-backed rate limiter |
| OAuth | 端点占位,未实现 | 接入 provider 并统一 `token_version` / role 语义 |
| IM 用户隔离 | channel 使用 `default` 内部用户 | 建立外部用户到 DeerFlow user 的映射 |
| 绝对 memory path | 显式共享 memory | UI / docs 明确提示 opt-out 风险 |
## 相关文件
| 文件 | 职责 |
|---|---|
| `app/gateway/auth_middleware.py` | 全局认证门、JWT 严格验证、写入 user context |
| `app/gateway/csrf_middleware.py` | CSRF double-submit 和 auth Origin 校验 |
| `app/gateway/routers/auth.py` | initialize/login/register/logout/me/change-password |
| `app/gateway/auth/jwt.py` | JWT 创建与解析 |
| `app/gateway/auth/reset_admin.py` | 密码 reset CLI |
| `app/gateway/auth/credential_file.py` | 0600 凭据文件写入 |
| `app/gateway/authz.py` | 路由权限与 owner check |
| `deerflow/runtime/user_context.py` | 当前用户 ContextVar 与 `AUTO` sentinel |
| `deerflow/persistence/thread_meta/` | thread metadata owner filter |
| `deerflow/config/paths.py` | per-user filesystem layout |
| `deerflow/agents/middlewares/thread_data_middleware.py` | run 时解析用户线程目录 |
| `deerflow/agents/memory/storage.py` | per-user memory storage |
| `deerflow/config/agents_config.py` | per-user custom agents |
| `app/channels/manager.py` | IM channel 内部认证调用 |
| `scripts/migrate_user_isolation.py` | legacy 数据迁移到 per-user layout |
| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库,包含 users / threads_meta / runs / feedback 等表 |
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
+6 -6
View File
@@ -24,11 +24,11 @@ All other test plan sections were executed against either:
| Case | Title | What it covers | Why not run |
|---|---|---|---|
| TC-DOCKER-01 | `deerflow.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` |
| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` |
| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
## Coverage already provided by non-Docker tests
@@ -41,8 +41,8 @@ the test cases that ran on sg_dev or local:
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies |
| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP |
| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
## Reproduction steps when Docker becomes available
@@ -72,6 +72,6 @@ Then run TC-DOCKER-01..06 from the test plan as written.
about *container packaging* details (bind mounts, multi-worker, log
collection), not about whether the auth code paths work.
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
the current reset flow (`reset_admin` → 0600 credentials file, no log leak).
the post-simplify reality (credentials file → 0600 file, no log leak).
The old "grep 'Password:' in docker logs" expectation would have failed
silently and given a false sense of coverage.
+105 -149
View File
@@ -19,7 +19,7 @@
```bash
# 清除已有数据
rm -f backend/.deer-flow/data/deerflow.db
rm -f backend/.deer-flow/users.db
# 选择模式启动
make dev # 标准模式
@@ -28,11 +28,10 @@ make dev-pro # Gateway 模式
```
**验证点:**
- [ ] 控制台输出 admin 邮箱或明文密码
- [ ] 控制台提示 `First boot detected — no admin account exists.`
- [ ] 控制台提示访问 `/setup` 完成 admin 创建
- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": true}`
- [ ] 前端访问 `/login` 会跳转 `/setup`
- [ ] 控制台输出 admin 邮箱和随机密码
- [ ] 密码格式为 `secrets.token_urlsafe(16)` 的 22 字符字符串
- [ ] 邮箱为 `admin@deerflow.dev`
- [ ] 提示 `Change it after login: Settings -> Account`
### 1.2 非首次启动
@@ -43,8 +42,7 @@ make dev
**验证点:**
- [ ] 控制台不输出密码
- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": false}`
- [ ] 已登录用户如果 `needs_setup=True`,访问 workspace 会被引导到 `/setup` 完成改邮箱 / 改密码流程
- [ ] 如果 admin 仍 `needs_setup=True`,控制台有 warning 提示
### 1.3 环境变量配置
@@ -78,22 +76,19 @@ make dev
curl -s $BASE/api/v1/auth/setup-status | jq .
```
**预期:**
- 干净数据库且尚未初始化 admin:返回 `{"needs_setup": true}`
- 已存在 admin:返回 `{"needs_setup": false}`
**预期:** 返回 `{"needs_setup": false}`admin 在启动时已自动创建,`count_users() > 0`)。仅在启动完成前的极短窗口内可能返回 `true`
#### TC-API-02: 首次初始化 Admin
#### TC-API-02: Admin 首次登录
```bash
curl -s -X POST $BASE/api/v1/auth/initialize \
-H "Content-Type: application/json" \
-d '{"email":"admin@example.com","password":"AdminPass1!"}' \
curl -s -X POST $BASE/api/v1/auth/login/local \
-d "username=admin@deerflow.dev&password=<控制台密码>" \
-c cookies.txt | jq .
```
**预期:**
- 状态码 201
- Body: `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}`
- 状态码 200
- Body: `{"expires_in": 604800, "needs_setup": true}`
- `cookies.txt` 包含 `access_token`HttpOnly)和 `csrf_token`(非 HttpOnly
#### TC-API-03: 获取当前用户
@@ -102,9 +97,9 @@ curl -s -X POST $BASE/api/v1/auth/initialize \
curl -s $BASE/api/v1/auth/me -b cookies.txt | jq .
```
**预期:** `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}`
**预期:** `{"id": "...", "email": "admin@deerflow.dev", "system_role": "admin", "needs_setup": true}`
#### TC-API-04: 改密码流程
#### TC-API-04: Setup 流程(改邮箱 + 改密码
```bash
CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}')
@@ -112,36 +107,13 @@ curl -s -X POST $BASE/api/v1/auth/change-password \
-b cookies.txt \
-H "Content-Type: application/json" \
-H "X-CSRF-Token: $CSRF" \
-d '{"current_password":"AdminPass1!","new_password":"NewPass123!"}' | jq .
-d '{"current_password":"<控制台密码>","new_password":"NewPass123!","new_email":"admin@example.com"}' | jq .
```
**预期:**
- 状态码 200
- `{"message": "Password changed successfully"}`
- 再调 `/auth/me` `admin@example.com``needs_setup` `false`
#### TC-API-04a: reset_admin 后的 Setup 流程(改邮箱 + 改密码)
```bash
cd backend
python -m app.gateway.auth.reset_admin --email admin@example.com
# 从 .deer-flow/admin_initial_credentials.txt 读取 reset 后密码
curl -s -X POST $BASE/api/v1/auth/login/local \
-d "username=admin@example.com&password=<凭据文件密码>" \
-c cookies.txt | jq .
CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}')
curl -s -X POST $BASE/api/v1/auth/change-password \
-b cookies.txt \
-H "Content-Type: application/json" \
-H "X-CSRF-Token: $CSRF" \
-d '{"current_password":"<凭据文件密码>","new_password":"AdminPass2!","new_email":"admin2@example.com"}' | jq .
```
**预期:**
- 登录返回 `{"expires_in": 604800, "needs_setup": true}`
- `change-password``/auth/me` 邮箱变为 `admin2@example.com``needs_setup` 变为 `false`
- 再调 `/auth/me` 邮箱变`admin@example.com``needs_setup` `false`
#### TC-API-05: 普通用户注册
@@ -521,7 +493,7 @@ curl -s -X POST $BASE/api/v1/auth/register \
```bash
# 检查数据库
sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM users LIMIT 3;"
sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMIT 3;"
```
**预期:** `password_hash``$2b$` 开头(bcrypt 格式)
@@ -534,25 +506,24 @@ sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM us
### 4.1 首次登录流程
#### TC-UI-01: 无 admin 时访问 workspace 跳转 setup
#### TC-UI-01: 访问首页跳转登录
1. 打开 `http://localhost:2026/workspace`
2. **预期:** 自动跳转到 `/setup`
2. **预期:** 自动跳转到 `/login`
#### TC-UI-02: Setup 页面创建 admin
#### TC-UI-02: Login 页面
1. 输入 admin 邮箱、密码、确认密码
2. 点击 Create Admin Account
1. 输入 admin 邮箱和控制台密码
2. 点击 Login
3. **预期:** 跳转到 `/setup`(因为 `needs_setup=true`
#### TC-UI-03: Setup 页面
1. 输入新邮箱、控制台密码(current)、新密码、确认密码
2. 点击 Complete Setup
3. **预期:** 跳转到 `/workspace`
4. 刷新页面不跳回 `/setup`
#### TC-UI-03: 已初始化后 Login 页面
1. 退出登录后访问 `/login`
2. 输入 admin 邮箱和密码
3. 点击 Login
4. **预期:** 跳转到 `/workspace`
#### TC-UI-04: Setup 密码不匹配
1. 新密码和确认密码不一致
@@ -631,7 +602,7 @@ sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM us
#### TC-UI-15: reset_admin 后重新登录
1. 执行 `cd backend && python -m app.gateway.auth.reset_admin`
2. `.deer-flow/admin_initial_credentials.txt` 读取新密码登录
2. 使用新密码登录
3. **预期:** 跳转到 `/setup` 页面(`needs_setup` 被重置为 true
4. 旧 session 已失效
@@ -674,28 +645,18 @@ make install
make dev
```
#### TC-UPG-01: 首次启动等待 admin 初始化
#### TC-UPG-01: 首次启动创建 admin
**预期:**
- [ ] 控制台输出 admin 邮箱随机密码
- [ ] 访问 `/setup` 可创建第一个 admin
- [ ] 控制台输出 admin 邮箱`admin@deerflow.dev`)和随机密码
- [ ] 无报错,正常启动
#### TC-UPG-02: 旧 Thread 迁移到 admin
```bash
# 创建第一个 admin
curl -s -X POST http://localhost:2026/api/v1/auth/initialize \
-H "Content-Type: application/json" \
-d '{"email":"admin@example.com","password":"AdminPass1!"}' \
-c cookies.txt
# 重启一次:启动迁移只在已有 admin 的启动路径执行
make stop && make dev
# 登录 admin
curl -s -X POST http://localhost:2026/api/v1/auth/login/local \
-d "username=admin@example.com&password=AdminPass1!" \
-d "username=admin@deerflow.dev&password=<控制台密码>" \
-c cookies.txt
# 查看 thread 列表
@@ -709,8 +670,8 @@ curl -s -X POST http://localhost:2026/api/threads/search \
**预期:**
- [ ] 返回的 thread 数量 ≥ 旧版创建的数量
- [ ] 控制台日志有 `Migrated N orphan LangGraph thread(s) to admin`
- [ ] thread 只对 admin 可见
- [ ] 控制台日志有 `Migrated N orphaned thread(s) to admin`
- [ ] 每个 thread `metadata.owner_id` 都已被设为 admin 的 ID
#### TC-UPG-03: 旧 Thread 内容完整
@@ -722,7 +683,7 @@ curl -s http://localhost:2026/api/threads/<old-thread-id> \
**预期:**
- [ ] `metadata.title` 保留原值(如 `old-thread-1`
- [ ] 响应不回显服务端保留的 `user_id` / `owner_id`
- [ ] `metadata.owner_id` 已填充
#### TC-UPG-04: 新用户看不到旧 Thread
@@ -745,19 +706,18 @@ curl -s -X POST http://localhost:2026/api/threads/search \
### 5.3 数据库 Schema 兼容
#### TC-UPG-05: 无 deerflow.db 时创建 schema 但不创建默认用户
#### TC-UPG-05: 无 users.db 时自动创建
```bash
ls -la backend/.deer-flow/data/deerflow.db
sqlite3 backend/.deer-flow/data/deerflow.db "SELECT COUNT(*) FROM users;"
ls -la backend/.deer-flow/users.db
```
**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup``token_version`;未调用 `/initialize` 前用户数为 0
**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup``token_version`
#### TC-UPG-06: deerflow.db WAL 模式
#### TC-UPG-06: users.db WAL 模式
```bash
sqlite3 backend/.deer-flow/data/deerflow.db "PRAGMA journal_mode;"
sqlite3 backend/.deer-flow/users.db "PRAGMA journal_mode;"
```
**预期:** 返回 `wal`
@@ -808,9 +768,9 @@ make dev
```
**预期:**
- [ ] 服务正常启动(忽略 `deerflow.db`,无 auth 相关代码不报错)
- [ ] 服务正常启动(忽略 `users.db`,无 auth 相关代码不报错)
- [ ] 旧对话数据仍然可访问
- [ ] `deerflow.db` 文件残留但不影响运行
- [ ] `users.db` 文件残留但不影响运行
#### TC-UPG-12: 再次升级到 auth 分支
@@ -821,47 +781,51 @@ make dev
```
**预期:**
- [ ] 识别已有 `deerflow.db`,不重新创建 admin
- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `deerflow.db`
- [ ] 识别已有 `users.db`,不重新创建 admin
- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `users.db`
### 5.7 Admin 初始化与 reset_admin
### 5.7 休眠 Admin初始密码未使用/未更改)
> 首次启动生成默认 admin,也不在日志输出密码。忘记密码时走 `reset_admin`,新密码写入 0600 凭据文件
> 首次启动生成 admin + 随机密码,但运维未登录、未改密码
> 密码只在首次启动的控制台闪过一次,后续启动不再显示。
#### TC-UPG-13: 未初始化 admin 时重启不创建默认账号
#### TC-UPG-13: 重启后自动重置密码并打印
```bash
rm -f backend/.deer-flow/data/deerflow.db
# 首次启动,记录密码
rm -f backend/.deer-flow/users.db
make dev
# 控制台输出密码 P0,不登录
make stop
# 隔了几天,再次启动
make dev
curl -s $BASE/api/v1/auth/setup-status | jq .
# 控制台输出新密码 P1
```
**预期:**
- [ ] 控制台输出密码
- [ ] `setup-status` 仍为 `{"needs_setup": true}`
- [ ] 访问 `/setup` 仍可创建第一个 admin
- [ ] 控制台输出 `Admin account setup incomplete — password reset`
- [ ] 输出新密码 P1P0 已失效)
- [ ] 用 P1 可以登录,P0 不可以
- [ ] 登录后 `needs_setup=true`,跳转 `/setup`
- [ ] `token_version` 递增(旧 session 如有也失效)
#### TC-UPG-14: 密码丢失 — reset_admin 写入凭据文件
#### TC-UPG-14: 密码丢失 — 无需 CLI,重启即可
```bash
python -m app.gateway.auth.reset_admin --email admin@example.com
ls -la backend/.deer-flow/admin_initial_credentials.txt
cat backend/.deer-flow/admin_initial_credentials.txt
# 忘记了控制台密码 → 直接重启服务
make stop && make dev
# 控制台自动输出新密码
```
**预期:**
- [ ] 命令行只输出凭据文件路径,不输出明文密码
- [ ] 凭据文件权限为 `0600`
- [ ] 凭据文件包含 email + password 行
- [ ] 该用户下次登录返回 `needs_setup=true`
- [ ] 无需 `reset_admin`,重启服务即可拿到新密码
- [ ] `reset_admin` CLI 仍然可用作手动备选方案
#### TC-UPG-15: 未初始化 admin 期间普通用户注册策略边界
#### TC-UPG-15: 休眠 admin 期间普通用户注册
```bash
# admin 尚不存在,普通用户尝试注册
# admin 存在但从未登录,普通用户注册
curl -s -X POST $BASE/api/v1/auth/register \
-H "Content-Type: application/json" \
-d '{"email":"earlybird@example.com","password":"EarlyPass1!"}' \
@@ -869,11 +833,11 @@ curl -s -X POST $BASE/api/v1/auth/register \
```
**预期:**
- [ ] 当前代码允许注册普通用户并自动登录201,角色为 `user`
- [ ] `setup-status` 仍为 `{"needs_setup": true}`,因为 admin 仍不存在
- [ ] 这是一个产品策略边界:若要求“必须先有 admin”,需要在 `/register` 增加 admin-exists gate
- [ ] 注册成功201,角色为 `user`
- [ ] 无法提权为 admin
- [ ] 普通用户的数据与 admin 隔离
#### TC-UPG-16: 普通用户数据与后续 admin 隔离
#### TC-UPG-16: 休眠 admin 不影响后续操作
```bash
# 普通用户正常创建 thread、发消息
@@ -885,13 +849,14 @@ curl -s -X POST $BASE/api/threads \
-d '{"metadata":{}}' | jq .thread_id
```
**预期:** 普通用户正常创建 thread;后续 admin 创建后,搜索不到该普通用户 thread
**预期:** 正常创建,不受休眠 admin 影响
#### TC-UPG-17: reset_admin 完成 Setup
#### TC-UPG-17: 休眠 admin 最终完成 Setup
```bash
# 运维终于登录
curl -s -X POST $BASE/api/v1/auth/login/local \
-d "username=admin@example.com&password=<凭据文件密码>" \
-d "username=admin@deerflow.dev&password=<P0或P1>" \
-c admin.txt | jq .needs_setup
# 预期: true
@@ -901,7 +866,7 @@ curl -s -X POST $BASE/api/v1/auth/change-password \
-b admin.txt \
-H "Content-Type: application/json" \
-H "X-CSRF-Token: $CSRF" \
-d '{"current_password":"<凭据文件密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \
-d '{"current_password":"<密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \
-c admin.txt
# 验证
@@ -911,7 +876,7 @@ curl -s $BASE/api/v1/auth/me -b admin.txt | jq '{email, needs_setup}'
**预期:**
- [ ] `email` 变为 `admin@real.com`
- [ ] `needs_setup` 变为 `false`
- [ ] 后续登录使用新密码
- [ ] 后续重启控制台不再有 warning
#### TC-UPG-18: 长期未用后 JWT 密钥轮换
@@ -925,8 +890,8 @@ make stop && make dev
**预期:**
- [ ] 服务正常启动
- [ ] 账号密码仍可登录(密码存在 DB,与 JWT 密钥无关)
- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)
- [ ] 密码仍可登录(密码存在 DB,与 JWT 密钥无关)
- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)— 但因为从未登录过也没有旧 token
---
@@ -945,7 +910,7 @@ for i in 1 2 3; do
done
# 检查 admin 数量
sqlite3 backend/.deer-flow/data/deerflow.db \
sqlite3 backend/.deer-flow/users.db \
"SELECT COUNT(*) FROM users WHERE system_role='admin';"
```
@@ -1090,7 +1055,7 @@ curl -s -X POST $BASE/api/v1/auth/register \
wait
# 检查用户数
sqlite3 backend/.deer-flow/data/deerflow.db \
sqlite3 backend/.deer-flow/users.db \
"SELECT COUNT(*) FROM users WHERE email='race@example.com';"
```
@@ -1200,16 +1165,13 @@ curl -s -w "%{http_code}" -X DELETE "$BASE/api/threads/$TID" \
```bash
cd backend
python -m app.gateway.auth.reset_admin
cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p1.txt
P1=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p1.txt)
# 记录密码 P1
python -m app.gateway.auth.reset_admin
cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p2.txt
P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt)
# 记录密码 P2
```
**预期:**
- [ ] `.deer-flow/admin_initial_credentials.txt` 每次都会被重写,文件权限为 `0600`
- [ ] P1 ≠ P2(每次生成新随机密码)
- [ ] P1 不可用,只有 P2 有效
- [ ] `token_version` 递增了 2
@@ -1362,8 +1324,7 @@ done
```bash
GW=http://localhost:8001
for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local \
/api/v1/auth/register /api/v1/auth/initialize /api/v1/auth/logout; do
for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local /api/v1/auth/register; do
echo "$path: $(curl -s -w '%{http_code}' -o /dev/null $GW$path)"
done
# 预期: 200 或 405/422(方法不对但不是 401
@@ -1438,9 +1399,9 @@ done
>
> 前置条件:
> - `.env` 中设置 `AUTH_JWT_SECRET`(否则每次容器重启 session 全部失效)
> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `deerflow.db`
> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `users.db`
#### TC-DOCKER-01: deerflow.db 通过 volume 持久化
#### TC-DOCKER-01: users.db 通过 volume 持久化
```bash
# 启动容器
@@ -1455,13 +1416,13 @@ curl -s -X POST $BASE/api/v1/auth/register \
-H "Content-Type: application/json" \
-d '{"email":"docker-test@example.com","password":"DockerTest1!"}' -w "\nHTTP %{http_code}"
# 检查宿主机上的 deerflow.db
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db
sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db \
# 检查宿主机上的 users.db
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db
sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db \
"SELECT email FROM users WHERE email='docker-test@example.com';"
```
**预期:** deerflow.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。
**预期:** users.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。
#### TC-DOCKER-02: 重启容器后 session 保持
@@ -1505,24 +1466,22 @@ done
**已知限制:** In-process rate limiter 不跨 worker 共享。生产环境如需精确限速,需要 Redis 等外部存储。
#### TC-DOCKER-04: IM 渠道使用内部认证
#### TC-DOCKER-04: IM 渠道不经过 auth
```bash
# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 调 Gateway
# 请求携带 process-local internal auth header,并带匹配的 CSRF cookie/header
# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 通信
# 不走 nginx,不经过 AuthMiddleware
# 验证方式:检查 gateway 日志中 channel manager 的请求不包含 auth 错误
docker logs deer-flow-gateway 2>&1 | grep -E "ChannelManager|channel" | head -10
```
**预期:** 无 auth 相关错误。渠道不依赖浏览器 cookie;服务端通过内部认证头把请求归入 `default` 用户桶
**预期:** 无 auth 相关错误。渠道通过 `langgraph-sdk` 直连 LangGraph Server`http://langgraph:2024`),不走 auth 层
#### TC-DOCKER-05: reset_admin 密码写入 0600 凭证文件(不再走日志)
#### TC-DOCKER-05: admin 密码写入 0600 凭证文件(不再走日志)
```bash
# 首次启动不会自动生成 admin 密码。先重置已有 admin,凭据文件写在挂载到宿主机的 DEER_FLOW_HOME 下
docker exec deer-flow-gateway python -m app.gateway.auth.reset_admin --email docker-test@example.com
# 凭证文件写在挂载到宿主机的 DEER_FLOW_HOME 下
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/admin_initial_credentials.txt
# 预期文件权限: -rw------- (0600)
@@ -1553,15 +1512,14 @@ sleep 15
docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l
# 预期: 0
# auth 流程正常:未登录受保护接口返回 401
# auth 流程正常
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
# 预期: 401
curl -s -X POST $BASE/api/v1/auth/initialize \
-H "Content-Type: application/json" \
-d '{"email":"admin@example.com","password":"AdminPass1!"}' \
curl -s -X POST $BASE/api/v1/auth/login/local \
-d "username=admin@deerflow.dev&password=<日志密码>" \
-c cookies.txt -w "\nHTTP %{http_code}"
# 预期: 201
# 预期: 200
```
### 7.4 补充边界用例
@@ -1629,15 +1587,13 @@ curl -s -D - -X POST $BASE/api/v1/auth/login/local \
#### TC-EDGE-05: HTTP 无 max_age / HTTPS 有 max_age
```bash
GW=http://localhost:8001
# HTTP
curl -s -D - -X POST $GW/api/v1/auth/login/local \
curl -s -D - -X POST $BASE/api/v1/auth/login/local \
-d "username=admin@example.com&password=正确密码" 2>/dev/null \
| grep "access_token=" | grep -oi "max-age=[0-9]*" || echo "NO max-age (HTTP session cookie)"
# HTTPS:直连 Gateway 才能用 X-Forwarded-Proto 模拟 HTTPSnginx 会覆盖该 header
curl -s -D - -X POST $GW/api/v1/auth/login/local \
# HTTPS
curl -s -D - -X POST $BASE/api/v1/auth/login/local \
-H "X-Forwarded-Proto: https" \
-d "username=admin@example.com&password=正确密码" 2>/dev/null \
| grep "access_token=" | grep -oi "max-age=[0-9]*"
@@ -1756,10 +1712,10 @@ curl -s -X POST $BASE/api/threads \
-b cookies.txt \
-H "Content-Type: application/json" \
-H "X-CSRF-Token: $CSRF" \
-d '{"metadata":{"owner_id":"victim-user-id","user_id":"victim-user-id"}}' | jq .metadata
-d '{"metadata":{"owner_id":"victim-user-id"}}' | jq .metadata.owner_id
```
**预期:** 返回的 `metadata` 不包含 `owner_id` `user_id`真实所有权写入 `threads_meta.user_id`,不从客户端 metadata 接收,也不通过 metadata 回显
**预期:** 返回的 `metadata.owner_id` 应为当前登录用户的 ID,不是请求中注入的 `victim-user-id`服务端应覆盖客户端提供的 `user_id`
#### 7.5.6 HTTP Method 探测
@@ -1840,6 +1796,6 @@ cd backend && PYTHONPATH=. uv run pytest \
# 核心接口冒烟
curl -s $BASE/health # 200
curl -s $BASE/api/models # 401 (无 cookie)
curl -s $BASE/api/v1/auth/setup-status # 200
curl -s -X POST $BASE/api/v1/auth/setup-status # 200
curl -s $BASE/api/v1/auth/me -b cookies.txt # 200 (有 cookie)
```
+24 -35
View File
@@ -2,16 +2,13 @@
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
完整设计见 [AUTH_DESIGN.md](AUTH_DESIGN.md)。
## 核心概念
认证模块采用**始终强制**策略:
- 首次启动时不会自动创建账号;首次访问 `/setup` 时由操作者创建第一个 admin 账号
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
- 认证从一开始就是强制的,无竞争窗口
- 已有 admin 后,服务启动时会把历史对话(升级前创建且缺少 `user_id` 的 thread)迁移到 admin 名下
- 新数据按用户隔离:thread、workspace/uploads/outputs、memory、自定义 agent 都归属当前用户
- 历史对话(升级前创建的 thread自动迁移到 admin 名下
## 升级步骤
@@ -28,41 +25,39 @@ cd backend && make install
make dev
```
如果没有 admin 账号,控制台只会提示
控制台会输出
```
============================================================
First boot detected — no admin account exists.
Visit /setup to complete admin account creation.
Admin account created on first boot
Email: admin@deerflow.dev
Password: aB3xK9mN_pQ7rT2w
Change it after login: Settings → Account
============================================================
```
首次启动不会在日志里打印随机密码,也不会写入默认 admin。这样避免启动日志泄露凭据,也避免在操作者创建账号前出现可被猜测的默认身份
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台
### 3. 创建 admin
### 3. 登录
访问 `http://localhost:2026/setup`,填写邮箱和密码创建第一个 admin 账号。创建成功后会自动登录并进入 workspace
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录
如果这是从无认证版本升级,创建 admin 后重启一次服务,让启动迁移把缺少 `user_id` 的历史 thread 归属到 admin。
### 4. 修改密码
### 4. 登录
后续访问 `http://localhost:2026/login`,使用已创建的邮箱和密码登录。
登录后进入 Settings → Account → Change Password。
### 5. 添加用户(可选)
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话、上传文件、输出文件、memory 和自定义 agent
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
## 安全机制
| 机制 | 说明 |
|------|------|
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
| CSRF Double Submit Cookie | 受保护的 POST/PUT/PATCH/DELETE 请求需携带 `X-CSRF-Token`;登录/注册/初始化/登出走 auth 端点 Origin 校验 |
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
| bcrypt 密码哈希 | 密码不以明文存储 |
| Thread owner filter | `threads_meta.user_id` 由服务端认证上下文写入,搜索、读取、更新、删除默认按当前用户过滤 |
| 文件系统隔离 | 线程数据写入 `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/`sandbox 内统一映射为 `/mnt/user-data/` |
| Memory / agent 隔离 | 用户 memory 和自定义 agent 写入 `{base_dir}/users/{user_id}/...`;旧共享 agent 只作为只读兼容回退 |
| 多租户隔离 | 用户只能访问自己的 thread |
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
## 常见操作
@@ -79,26 +74,22 @@ python -m app.gateway.auth.reset_admin
python -m app.gateway.auth.reset_admin --email user@example.com
```
新的随机密码写入 `.deer-flow/admin_initial_credentials.txt`,文件权限为 `0600`。命令行只输出文件路径,不输出明文密码
输出新的随机密码。
### 完全重置
删除统一 SQLite 数据库,重启后重新访问 `/setup` 创建新 admin
删除用户数据库,重启后自动创建新 admin
```bash
rm -f backend/.deer-flow/data/deerflow.db
# 重启服务后访问 http://localhost:2026/setup
rm -f backend/.deer-flow/users.db
# 重启服务,控制台输出新密码
```
## 数据存储
| 文件 | 内容 |
|------|------|
| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库(users、threads_meta、runs、feedback 等应用数据 |
| `.deer-flow/users/{user_id}/threads/{thread_id}/user-data/` | 用户线程的 workspace、uploads、outputs |
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色 |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
### 生产环境建议
@@ -120,21 +111,19 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
| `/api/v1/auth/change-password` | POST | 修改密码 |
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
| `/api/v1/auth/initialize` | POST | 首次初始化第一个 admin(仅无 admin 时可调用) |
## 兼容性
- **标准模式**`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化
- **标准模式**`make dev`):完全兼容admin 自动创建
- **Gateway 模式**`make dev-pro`):完全兼容
- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶
- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
## 故障排查
| 症状 | 原因 | 解决 |
|------|------|------|
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` |
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
-2
View File
@@ -8,7 +8,6 @@ This directory contains detailed documentation for the DeerFlow backend.
|----------|-------------|
| [ARCHITECTURE.md](ARCHITECTURE.md) | System architecture overview |
| [API.md](API.md) | Complete API reference |
| [AUTH_DESIGN.md](AUTH_DESIGN.md) | User authentication, CSRF, and per-user isolation design |
| [CONFIGURATION.md](CONFIGURATION.md) | Configuration options |
| [SETUP.md](SETUP.md) | Quick setup guide |
@@ -43,7 +42,6 @@ docs/
├── README.md # This file
├── ARCHITECTURE.md # System architecture
├── API.md # API reference
├── AUTH_DESIGN.md # User authentication and isolation design
├── CONFIGURATION.md # Configuration guide
├── SETUP.md # Setup instructions
├── FILE_UPLOAD.md # File upload feature
-401
View File
@@ -1,401 +0,0 @@
# Storage Package Design
## Background
DeerFlow currently has several persistence responsibilities spread across app, gateway, runtime, and legacy persistence modules. This makes the persistence boundary difficult to reason about and creates several migration risks:
- Routers and runtime services can accidentally depend on concrete persistence implementations instead of stable contracts.
- User/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are initialized through different paths.
- Some persistence behavior is duplicated between memory, SQLite, and PostgreSQL-oriented code paths.
- Incremental migration is hard because app-level code and storage-level code are coupled.
- Adding or validating another SQL backend requires touching app/runtime code instead of a storage-owned package.
The storage package is introduced to make application data persistence a package-level capability with explicit contracts, a clear boundary, and SQL backend compatibility.
## Goals
- Provide a standalone `packages/storage` package for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a shared persistence construction flow.
- Keep LangGraph checkpointer initialization compatible with the same database backend.
- Expose repository contracts as the only package-level data access boundary.
- Let the app layer depend on app-owned adapters under `app.infra.storage`, not on storage DB implementation classes.
- Allow the app/gateway migration to happen in small steps without forcing a large rewrite.
## Non-Goals
- This design does not remove legacy persistence in the first PR.
- This design does not move routers directly onto storage package models.
- This design does not make app routers own SQLAlchemy sessions.
- Cron persistence is intentionally out of scope for the storage package foundation.
- Memory backend is not part of the durable storage package. Memory compatibility, if still needed by app runtime, belongs outside `packages/storage`.
## Storage Design Principles
### Package-Owned Durable Storage
`packages/storage` owns durable application data persistence. It defines:
- configuration shape for storage-backed persistence
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- compatibility helpers for config-driven initialization
The package should be usable without importing `app.gateway`, routers, auth providers, or runtime-specific gateway objects.
### SQL Backend Compatibility
The package supports three SQL backends:
- SQLite for local/single-node deployments
- PostgreSQL for production multi-node deployments
- MySQL for deployments that standardize on MySQL
Backend-specific differences are handled inside the storage package:
- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filtering across SQLite/PostgreSQL/MySQL
- SQL dialect behavior around locking, aggregation, and JSON type semantics
### Unified Persistence Bundle
Storage initialization returns an `AppPersistence` bundle:
```python
@dataclass(slots=True)
class AppPersistence:
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: Callable[[], Awaitable[None]]
aclose: Callable[[], Awaitable[None]]
```
The app runtime can initialize persistence once, call `setup()`, and then inject:
- `checkpointer`
- `session_factory`
- repository adapters
This keeps checkpointer and application data aligned to the same backend without requiring routers to understand database configuration.
## Package Layout
```text
backend/packages/storage/
store/
config/
storage_config.py
app_config.py
persistence/
factory.py
types.py
base_model.py
json_compat.py
drivers/
sqlite.py
postgres.py
mysql.py
repositories/
contracts/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
models/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
db/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
factory.py
```
## Persistence Construction
The primary storage entrypoint is:
```python
from store.persistence import create_persistence_from_storage_config
persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```
For app-level compatibility with existing database config shape:
```python
from store.persistence import create_persistence_from_database_config
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```
Expected app startup flow:
```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
Expected app shutdown flow:
```python
await app.state.persistence.aclose()
```
## Repository Contract Design
Repository contracts are the storage package's public data access boundary. They live under `store.repositories.contracts` and are re-exported from `store.repositories`.
The key contract groups are:
- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`
Each contract owns:
- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, such as `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain-specific exceptions when needed, such as `InvalidMetadataFilterError`
Repository construction is session-based:
```python
from store.repositories import build_run_repository
async with persistence.session_factory() as session:
repo = build_run_repository(session)
run = await repo.get_run(run_id)
```
This keeps transaction ownership explicit. The storage package does not hide commits or session lifecycle inside global singletons.
## App/Infra Calling Contract
The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.
`app.infra.storage` is responsible for:
- receiving `session_factory` from FastAPI runtime initialization
- owning session lifecycle for app-facing repository methods
- translating storage DTOs to app/gateway DTOs only when needed
- preserving the existing app-facing names during migration
- depending on storage repository protocols, not concrete DB classes
Expected adapter pattern:
```python
class StorageRunRepository(RunRepositoryProtocol):
def __init__(self, session_factory):
self._session_factory = session_factory
async def get_run(self, run_id: str):
async with self._session_factory() as session:
repo = build_run_repository(session)
return await repo.get_run(run_id)
```
For gateway compatibility, app state can keep existing names while the implementation changes:
```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
The app-facing objects may expose legacy method names during migration, but their internal data access should go through storage contracts.
## Boundary Rules
### Allowed Calls
Storage package callers may use:
```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```
App layer callers should use:
```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```
### Prohibited Calls
App/gateway/router/auth code must not import:
```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```
Routers must not:
- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit/rollback storage transactions directly unless explicitly scoped by an infra adapter
- depend on storage SQLAlchemy model classes
Storage package code must not import:
```python
import app.gateway
import app.infra
import deerflow.runtime
```
The dependency direction is:
```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```
The reverse direction is forbidden.
## Checkpointer Compatibility
The storage persistence bundle initializes the LangGraph checkpointer alongside application data persistence.
Backend-specific notes:
- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and requires a string `postgresql://...` connection URL.
- MySQL uses `langgraph-checkpoint-mysql` and requires a string MySQL connection URL.
SQLAlchemy may use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors expect plain string connection URLs. This conversion belongs inside the storage driver implementation.
## JSON Metadata Filtering
Thread metadata search supports dialect-aware JSON filtering through `store.persistence.json_compat`.
The matcher supports:
- `None`
- `bool`
- `int`
- `float`
- `str`
It rejects:
- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside signed 64-bit range
This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and explicit JSON `null` not matching a missing key.
## Step-by-Step Implementation Plan
### Step 1: Introduce Storage Package Foundation
- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Add package dependency wiring.
- Exclude cron persistence.
### Step 2: Harden Storage Backend Compatibility
- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add dialect-aware JSON metadata filtering.
### Step 3: Add App Infra Adapters
- Add `backend/app/infra/storage`.
- Implement app-facing repositories that own session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add legacy compatibility adapters for existing app/gateway method shapes.
- Keep app/gateway imports out of `packages/storage`.
### Step 4: Switch FastAPI Runtime Injection
- Initialize storage persistence in FastAPI startup/lifespan.
- Attach `persistence`, `checkpointer`, and `session_factory` to `app.state`.
- Preserve existing external state names:
- `run_store`
- `feedback_repo`
- `thread_store`
- `run_event_store`
- `checkpointer`
- `session_factory`
- Start with user/auth provider construction, then migrate run/thread/feedback/run_event.
### Step 5: Router and Auth Compatibility
- Ensure routers consume app-facing adapters, not storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add focused auth/admin/router regression tests.
### Step 6: Cleanup Legacy Persistence
- Compare old persistence usage after app/gateway migration.
- Remove unused old repository implementations only after all call sites move.
- Keep compatibility shims only where needed for a transition window.
- Delete memory backend paths from storage-owned durable persistence.
## Testing Strategy
Unit tests should cover:
- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion
E2E tests should cover:
- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior across all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup
E2E tests may remain local-only if CI does not provide PostgreSQL/MySQL services.
-401
View File
@@ -1,401 +0,0 @@
# Storage Package 设计文档
## 背景
DeerFlow 当前有多类持久化职责分散在 app、gateway、runtime 和旧 persistence 模块中。这会带来几个问题:
- routers 和 runtime services 容易依赖具体 persistence 实现,而不是稳定契约。
- user/auth、run metadata、thread metadata、feedback、run events、checkpointer setup 的初始化路径不统一。
- memory、SQLite、PostgreSQL 相关路径中存在部分重复逻辑。
- app 层代码和 storage 层代码耦合,导致增量迁移困难。
- 增加或验证新的 SQL backend 时,需要改动 app/runtime,而不是只改 storage package。
引入 storage package 的目标,是把应用数据持久化抽象成 package 级能力,并提供明确契约、清晰边界和 SQL backend 兼容性。
## 目标
- 新增独立的 `packages/storage`,负责 durable application data。
- 通过统一 persistence 构造流程支持 SQLite、PostgreSQL、MySQL。
- 保持 LangGraph checkpointer 与同一个数据库 backend 兼容。
- 将 repository contracts 作为 package 对外唯一数据访问边界。
- app 层通过 `app.infra.storage` 适配 storage,而不是直接依赖 storage DB 实现类。
- 支持 app/gateway 后续小步迁移,避免一次性大重构。
## 非目标
- 第一阶段不删除旧 persistence。
- 不让 routers 直接依赖 storage package models。
- 不让 app routers 管理 SQLAlchemy sessions。
- cron persistence 不属于 storage package 基础迁移范围。
- memory backend 不属于 durable storage package。若 app runtime 仍需要 memory 兼容,应放在 `packages/storage` 之外。
## Storage 设计理念
### Package 自己负责 Durable Storage
`packages/storage` 负责应用数据的 durable persistence,包括:
- storage 持久化配置
- SQLAlchemy models
- repository contracts 和 DTOs
- SQL repository 实现
- persistence factory functions
- 面向现有 config 的兼容初始化入口
该 package 不应该 import `app.gateway`、routers、auth providers 或 runtime 中的 gateway 对象。
### SQL Backend 兼容
该 package 支持三种 SQL backend
- SQLite:本地或单节点部署
- PostgreSQL:生产多节点部署
- MySQL:使用 MySQL 作为标准数据库的部署
backend 差异在 storage package 内部处理:
- SQLAlchemy async engine URL 构造
- LangGraph checkpointer 连接串兼容
- SQLite/PostgreSQL/MySQL 的 JSON metadata filter
- 不同 SQL 方言在 locking、aggregation、JSON 类型语义上的差异
### 统一 Persistence Bundle
Storage 初始化返回 `AppPersistence` bundle
```python
@dataclass(slots=True)
class AppPersistence:
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: Callable[[], Awaitable[None]]
aclose: Callable[[], Awaitable[None]]
```
app runtime 只需要初始化一次 persistence,调用 `setup()`,然后注入:
- `checkpointer`
- `session_factory`
- repository adapters
这样 checkpointer 和应用数据可以对齐到同一个 backend,同时 routers 不需要理解数据库配置。
## Package 结构
```text
backend/packages/storage/
store/
config/
storage_config.py
app_config.py
persistence/
factory.py
types.py
base_model.py
json_compat.py
drivers/
sqlite.py
postgres.py
mysql.py
repositories/
contracts/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
models/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
db/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
factory.py
```
## Persistence 构造
storage 的主要入口:
```python
from store.persistence import create_persistence_from_storage_config
persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```
为了兼容现有 app database config,也提供:
```python
from store.persistence import create_persistence_from_database_config
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```
预期 app startup 流程:
```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
预期 app shutdown 流程:
```python
await app.state.persistence.aclose()
```
## Repository 契约设计
Repository contracts 是 storage package 对外公开的数据访问边界。它们位于 `store.repositories.contracts`,并通过 `store.repositories` re-export。
主要契约包括:
- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`
每组契约包含:
- 输入 DTO,例如 `UserCreate``RunCreate``ThreadMetaCreate`
- 输出 DTO,例如 `User``Run``ThreadMeta`
- repository protocol methods
- 必要的领域异常,例如 `InvalidMetadataFilterError`
Repository 通过 session 构造:
```python
from store.repositories import build_run_repository
async with persistence.session_factory() as session:
repo = build_run_repository(session)
run = await repo.get_run(run_id)
```
这样可以让 transaction ownership 保持明确。storage package 不通过全局 singleton 隐式隐藏 commit 或 session 生命周期。
## App/Infra 调用契约
app 层不应该直接调用 `store.repositories.db.*`。预期的 app 边界是 `app.infra.storage`
`app.infra.storage` 负责:
- 从 FastAPI runtime 初始化中接收 `session_factory`
- 为 app-facing repository methods 管理 session 生命周期
- 在必要时将 storage DTOs 转成 app/gateway DTOs
- 迁移期间保留现有 app-facing 名称
- 依赖 storage repository protocols,而不是具体 DB classes
预期 adapter 模式:
```python
class StorageRunRepository(RunRepositoryProtocol):
def __init__(self, session_factory):
self._session_factory = session_factory
async def get_run(self, run_id: str):
async with self._session_factory() as session:
repo = build_run_repository(session)
return await repo.get_run(run_id)
```
为了兼容 gatewayapp state 可以暂时保持现有名字,只替换内部实现:
```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
app-facing objects 可以在迁移期间保留旧方法名,但内部数据访问必须经过 storage contracts。
## 边界规则
### 允许调用的范围
storage package 调用方可以使用:
```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```
app 层应该使用:
```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```
### 禁止调用的范围
app/gateway/router/auth 代码不应该 import
```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```
routers 禁止:
- 创建 SQLAlchemy engines
- 直接创建 SQLAlchemy sessions
- 直接调用 storage DB repository classes
- 直接 commit/rollback storage transactions,除非这是 infra adapter 明确管理的范围
- 依赖 storage SQLAlchemy model classes
storage package 禁止 import
```python
import app.gateway
import app.infra
import deerflow.runtime
```
依赖方向必须是:
```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```
禁止反向依赖。
## Checkpointer 兼容
storage persistence bundle 会同时初始化 LangGraph checkpointer 和应用数据持久化。
backend 说明:
- SQLite 使用 `langgraph-checkpoint-sqlite`
- PostgreSQL 使用 `langgraph-checkpoint-postgres`,需要字符串形式的 `postgresql://...` 连接串。
- MySQL 使用 `langgraph-checkpoint-mysql`,需要字符串形式的 MySQL 连接串。
SQLAlchemy 可以使用 `postgresql+asyncpg://...``mysql+aiomysql://...` 这类 async driver URL,但 LangGraph checkpointer 构造函数需要普通字符串连接串。这个转换应该封装在 storage driver implementation 内部。
## JSON Metadata Filtering
Thread metadata search 通过 `store.persistence.json_compat` 支持跨方言 JSON filtering。
支持的 filter value 类型:
- `None`
- `bool`
- `int`
- `float`
- `str`
拒绝:
- unsafe keys
- nested JSON path expressions
- dict/list values
- 超出 signed 64-bit 范围的整数
这样可以避免 SQL/JSON path injection,避免 compiled-cache 类型漂移,并保留类型语义,例如 `True != 1`,显式 JSON `null` 不等于 missing key。
## 分步实现方案
### 第 1 步:新增 Storage Package 基础
- 新增 `backend/packages/storage`
- 增加 storage config models。
- 增加 `AppPersistence`
- 增加 SQLite/PostgreSQL/MySQL persistence drivers。
- 增加 repository contracts、models、DB implementations 和 factory helpers。
- 接入 package dependency。
- 排除 cron persistence。
### 第 2 步:补齐 Storage Backend 兼容性
- 验证 SQLite setup 和 repository 行为。
- 使用本地 E2E 验证 PostgreSQL 和 MySQL。
- 修复 checkpointer 连接串兼容。
- 修复 PostgreSQL locking 和 aggregation 差异。
- 增加跨方言 JSON metadata filtering。
### 第 3 步:新增 App Infra Adapters
- 新增 `backend/app/infra/storage`
- 实现 app-facing repositories,由它们管理 session 生命周期。
- 保持 storage contracts 作为唯一数据访问边界。
- 为现有 app/gateway method shape 增加兼容 adapters。
- 避免 `packages/storage` import app/gateway。
### 第 4 步:切换 FastAPI Runtime 注入
- 在 FastAPI startup/lifespan 中初始化 storage persistence。
- 将 `persistence``checkpointer``session_factory` 注入 `app.state`
- 暂时保留现有对外 state 名称:
- `run_store`
- `feedback_repo`
- `thread_store`
- `run_event_store`
- `checkpointer`
- `session_factory`
- 先切 user/auth provider 构造,再逐步迁移 run/thread/feedback/run_event。
### 第 5 步:Router 和 Auth 兼容
- 确保 routers 消费 app-facing adapters,而不是 storage DB classes。
- 确保 auth providers 依赖 user repository contracts。
- 保持 router response shapes 不变。
- 增加 auth/admin/router regression tests。
### 第 6 步:清理旧 Persistence
- app/gateway 迁移完成后,再比较旧 persistence usage。
- 所有 call sites 迁移完成后,再删除未使用的旧 repository implementations。
- 只在必要时保留短期 compatibility shims。
- 从 storage-owned durable persistence 中移除 memory backend 路径。
## 测试策略
单测应覆盖:
- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion
E2E 应覆盖:
- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- 所有支持 SQL backend 下的 repository contract 行为
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup
如果 CI 暂时没有 PostgreSQL/MySQL servicesE2E 可以先作为 local-only 验证保留。
@@ -40,15 +40,6 @@ class MemoryUpdateQueue:
self._timer: threading.Timer | None = None
self._processing = False
@staticmethod
def _queue_key(
thread_id: str,
user_id: str | None,
agent_name: str | None,
) -> tuple[str, str | None, str | None]:
"""Return the debounce identity for a memory update target."""
return (thread_id, user_id, agent_name)
def add(
self,
thread_id: str,
@@ -124,9 +115,8 @@ class MemoryUpdateQueue:
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
queue_key = self._queue_key(thread_id, user_id, agent_name)
existing_context = next(
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
@@ -140,7 +130,7 @@ class MemoryUpdateQueue:
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
def _reset_timer(self) -> None:
@@ -6,7 +6,6 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import resolve_runtime_user_id
def memory_flush_hook(event: SummarizationEvent) -> None:
@@ -22,13 +21,11 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
user_id = resolve_runtime_user_id(event.runtime)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -36,73 +36,42 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
@staticmethod
def _message_tool_calls(msg) -> list[dict]:
"""Return normalized tool calls from structured fields or raw provider payloads.
LangChain stores malformed provider function calls in ``invalid_tool_calls``.
They do not execute, but provider adapters may still serialize enough of
the call id/name back into the next request that strict OpenAI-compatible
validators expect a matching ToolMessage. Treat them as dangling calls so
the next model request stays well-formed and the model sees a recoverable
tool error instead of another provider 400.
"""
normalized: list[dict] = []
"""Return normalized tool calls from structured fields or raw provider payloads."""
tool_calls = getattr(msg, "tool_calls", None) or []
normalized.extend(list(tool_calls))
if tool_calls:
return list(tool_calls)
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
if not tool_calls:
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
for invalid_tc in getattr(msg, "invalid_tool_calls", None) or []:
if not isinstance(invalid_tc, dict):
normalized: list[dict] = []
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": invalid_tc.get("id"),
"name": invalid_tc.get("name") or "unknown",
"args": {},
"invalid": True,
"error": invalid_tc.get("error"),
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
return normalized
@staticmethod
def _synthetic_tool_message_content(tool_call: dict) -> str:
if tool_call.get("invalid"):
error = tool_call.get("error")
if isinstance(error, str) and error:
return f"[Tool call could not be executed because its arguments were invalid: {error}]"
return "[Tool call could not be executed because its arguments were invalid.]"
return "[Tool call was interrupted and did not return a result.]"
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
@@ -145,7 +114,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content=self._synthetic_tool_message_content(tc),
content="[Tool call was interrupted and did not return a result.]",
tool_call_id=tc_id,
name=tc.get("name", "unknown"),
status="error",
@@ -9,7 +9,7 @@ from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@@ -217,17 +217,6 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
return "thinking"
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
"""Return True if the AIMessage contains a tool_call with the given id."""
for tc in message.tool_calls or []:
if isinstance(tc, dict):
if tc.get("id") == tool_call_id:
return True
elif hasattr(tc, "id") and tc.id == tool_call_id:
return True
return False
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
@@ -272,51 +261,8 @@ class TokenUsageMiddleware(AgentMiddleware):
if not messages:
return None
# Annotate subagent token usage onto the AIMessage that dispatched it.
# When a task tool completes, its usage is cached by tool_call_id. Detect
# the ToolMessage → search backward for the corresponding AIMessage → merge.
# Walk backward through consecutive ToolMessages before the new AIMessage
# so that multiple concurrent task tool calls all get their subagent tokens
# written back to the same dispatch message (merging into one update).
state_updates: dict[int, AIMessage] = {}
if len(messages) >= 2:
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
idx = len(messages) - 2
while idx >= 0:
tool_msg = messages[idx]
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
break
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
if subagent_usage:
# Search backward from the ToolMessage to find the AIMessage
# that dispatched it. A single model response can dispatch
# multiple task tool calls, so we can't assume a fixed offset.
dispatch_idx = idx - 1
while dispatch_idx >= 0:
candidate = messages[dispatch_idx]
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
# Accumulate into an existing update for the same
# AIMessage (multiple task calls in one response),
# or merge fresh from the original message.
existing_update = state_updates.get(dispatch_idx)
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
merged = {
**prev,
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
}
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
break
dispatch_idx -= 1
idx -= 1
last = messages[-1]
if not isinstance(last, AIMessage):
if state_updates:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return None
usage = getattr(last, "usage_metadata", None)
@@ -342,12 +288,11 @@ class TokenUsageMiddleware(AgentMiddleware):
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
state_updates[len(messages) - 1] = updated_msg
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
@@ -80,6 +80,7 @@ class AioSandboxProvider(SandboxProvider):
port: 8080 # Base port for local containers
container_prefix: deer-flow-sandbox
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
auto_restart: true # Restart crashed containers automatically
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
mounts: # Volume mounts for local containers
- host_path: /path/on/host
@@ -164,12 +165,14 @@ class AioSandboxProvider(SandboxProvider):
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
auto_restart = getattr(sandbox_config, "auto_restart", True)
return {
"image": sandbox_config.image or DEFAULT_IMAGE,
"port": sandbox_config.port or DEFAULT_PORT,
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
"auto_restart": auto_restart,
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
"mounts": sandbox_config.mounts or [],
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
@@ -608,18 +611,58 @@ class AioSandboxProvider(SandboxProvider):
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp.
When ``auto_restart`` is enabled (the default), the container's liveness
is verified on each lookup. If the underlying container has crashed, the
sandbox is evicted from all caches so that the next ``acquire()`` call will
transparently create a fresh container.
Args:
sandbox_id: The ID of the sandbox.
Returns:
The sandbox instance if found, None otherwise.
The sandbox instance if found and alive, None otherwise.
"""
with self._lock:
sandbox = self._sandboxes.get(sandbox_id)
if sandbox is not None:
self._last_activity[sandbox_id] = time.time()
if sandbox is None:
return None
self._last_activity[sandbox_id] = time.time()
auto_restart = self._config.get("auto_restart", True)
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
if not info:
return sandbox
if self._backend.is_alive(info):
return sandbox
info_to_destroy = None
with self._lock:
current_sandbox = self._sandboxes.get(sandbox_id)
current_info = self._sandbox_infos.get(sandbox_id)
if current_sandbox is None:
return None
if current_info is not info:
self._last_activity[sandbox_id] = time.time()
return current_sandbox
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
self._sandboxes.pop(sandbox_id, None)
self._sandbox_infos.pop(sandbox_id, None)
self._last_activity.pop(sandbox_id, None)
self._warm_pool.pop(sandbox_id, None)
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids:
del self._thread_sandboxes[tid]
info_to_destroy = info
if info_to_destroy:
try:
self._backend.destroy(info_to_destroy)
except Exception as e:
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
return None
def release(self, sandbox_id: str) -> None:
"""Release a sandbox from active use into the warm pool.
@@ -23,6 +23,9 @@ class SandboxConfig(BaseModel):
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
container_prefix: Prefix for container names (default: deer-flow-sandbox)
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
on the next acquire. Set to false to disable.
mounts: List of volume mounts to share directories with the container
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
"""
@@ -55,6 +58,10 @@ class SandboxConfig(BaseModel):
default=None,
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
)
auto_restart: bool = Field(
default=True,
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
)
mounts: list[VolumeMountConfig] = Field(
default_factory=list,
description="List of volume mounts to share directories between host and container",
+43 -2
View File
@@ -1,6 +1,11 @@
"""Load MCP tools using langchain-mcp-adapters."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
from langchain_core.tools import BaseTool
@@ -8,10 +13,46 @@ from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
# Global thread pool for sync tool invocation in async environments
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
# Register shutdown hook for the global executor
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine.
Args:
coro: The tool's asynchronous coroutine.
tool_name: Name of the tool (for logging).
Returns:
A synchronous function that correctly handles nested event loops.
"""
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
try:
if loop is not None and loop.is_running():
# Use global executor to avoid nested loop issues and improve performance
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
return future.result()
else:
return asyncio.run(coro(*args, **kwargs))
except Exception as e:
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
raise
return sync_wrapper
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
@@ -85,7 +126,7 @@ async def get_mcp_tools() -> list[BaseTool]:
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
@@ -1,195 +0,0 @@
"""Dialect-aware JSON value matching for SQLAlchemy (SQLite + PostgreSQL)."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
from sqlalchemy import BigInteger, Float, String, bindparam
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import Boolean, TypeEngine
# Key is interpolated into compiled SQL; restrict charset to prevent injection.
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
# Allowed value types for metadata filter values (same set accepted by JsonMatch).
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
# SQLite raises an overflow when binding values outside signed 64-bit range;
# PostgreSQL overflows during BIGINT cast. Reject at validation time instead.
_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1
def validate_metadata_filter_key(key: object) -> bool:
"""Return True if *key* is safe for use as a JSON metadata filter key.
A key is "safe" when it is a string matching ``[A-Za-z0-9_-]+``. The
charset is restricted because the key is interpolated into the
compiled SQL path expression (``$."<key>"`` / ``->`` literal), so any
laxer pattern would open a SQL/JSONPath injection surface.
"""
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
def validate_metadata_filter_value(value: object) -> bool:
"""Return True if *value* is an allowed type for a JSON metadata filter.
Matches the set of types ``_build_clause`` knows how to compile into
a dialect-portable predicate. Anything else (list/dict/bytes/...) is
intentionally rejected rather than silently coerced via ``str()``
silent coercion would (a) produce wrong matches and (b) break
SQLAlchemy's ``inherit_cache`` invariant when ``value`` is unhashable.
Integer values are additionally restricted to the signed 64-bit range
``[-2**63, 2**63 - 1]``: SQLite overflows when binding larger values
and PostgreSQL overflows during the ``BIGINT`` cast.
"""
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
return False
if isinstance(value, int) and not isinstance(value, bool):
if not (_INT64_MIN <= value <= _INT64_MAX):
return False
return True
class JsonMatch(ColumnElement):
"""Dialect-portable ``column[key] == value`` for JSON columns.
Compiles to ``json_type``/``json_extract`` on SQLite and
``json_typeof``/``->>`` on PostgreSQL, with type-safe comparison
that distinguishes bool vs int and NULL vs missing key.
*key* must be a single literal key matching ``[A-Za-z0-9_-]+``.
*value* must be one of: ``None``, ``bool``, ``int`` (signed 64-bit), ``float``, ``str``.
"""
inherit_cache = True
type = Boolean()
_is_implicitly_boolean = True
_traverse_internals = [
("column", InternalTraversal.dp_clauseelement),
("key", InternalTraversal.dp_string),
("value", InternalTraversal.dp_plain_obj),
]
def __init__(self, column: ColumnElement, key: str, value: object) -> None:
if not validate_metadata_filter_key(key):
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
if not validate_metadata_filter_value(value):
if isinstance(value, int) and not isinstance(value, bool):
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
self.column = column
self.key = key
self.value = value
super().__init__()
@dataclass(frozen=True)
class _Dialect:
"""Per-dialect names used when emitting JSON type/value comparisons."""
null_type: str
num_types: tuple[str, ...]
num_cast: str
int_types: tuple[str, ...]
int_cast: str
# None for SQLite where json_type already returns 'integer'/'real';
# regex literal for PostgreSQL where json_typeof returns 'number' for
# both ints and floats, so an extra guard prevents CAST errors on floats.
int_guard: str | None
string_type: str
bool_type: str | None
_SQLITE = _Dialect(
null_type="null",
num_types=("integer", "real"),
num_cast="REAL",
int_types=("integer",),
int_cast="INTEGER",
int_guard=None,
string_type="text",
bool_type=None,
)
_PG = _Dialect(
null_type="null",
num_types=("number",),
num_cast="DOUBLE PRECISION",
int_types=("number",),
int_cast="BIGINT",
int_guard="'^-?[0-9]+$'",
string_type="string",
bool_type="boolean",
)
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
param = bindparam(None, value, type_=sa_type)
return compiler.process(param, **kw)
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
if len(types) == 1:
return f"{typeof} = '{types[0]}'"
quoted = ", ".join(f"'{t}'" for t in types)
return f"{typeof} IN ({quoted})"
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
if value is None:
return f"{typeof} = '{dialect.null_type}'"
if isinstance(value, bool):
# bool check must precede int check — bool is a subclass of int in Python
bool_str = "true" if value else "false"
if dialect.bool_type is None:
return f"{typeof} = '{bool_str}'"
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
if isinstance(value, int):
bp = _bind(compiler, value, BigInteger(), **kw)
if dialect.int_guard:
# CASE prevents CAST error when json_typeof = 'number' also matches floats
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
if isinstance(value, float):
bp = _bind(compiler, value, Float(), **kw)
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
bp = _bind(compiler, str(value), String(), **kw)
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"json_type({col}, '{path}')"
extract = f"json_extract({col}, '{path}')"
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
@compiles(JsonMatch, "postgresql")
def _compile_pg(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
typeof = f"json_typeof({col} -> '{element.key}')"
extract = f"({col} ->> '{element.key}')"
return _build_clause(compiler, typeof, extract, element.value, _PG, **kw)
@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
raise NotImplementedError(f"JsonMatch supports only sqlite and postgresql; got dialect: {compiler.dialect.name}")
def json_match(column: ColumnElement, key: str, value: object) -> JsonMatch:
return JsonMatch(column, key, value)
@@ -23,18 +23,6 @@ class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _normalize_model_name(model_name: str | None) -> str | None:
"""Normalize model_name for storage: strip whitespace, truncate to 128 chars."""
if model_name is None:
return None
if not isinstance(model_name, str):
model_name = str(model_name)
normalized = model_name.strip()
if len(normalized) > 128:
normalized = normalized[:128]
return normalized
@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
@@ -82,7 +70,6 @@ class RunRepository(RunStore):
thread_id,
assistant_id=None,
user_id: str | None | _AutoSentinel = AUTO,
model_name: str | None = None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -98,7 +85,6 @@ class RunRepository(RunStore):
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
model_name=self._normalize_model_name(model_name),
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
@@ -223,11 +209,10 @@ class RunRepository(RunStore):
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
stmt = (
select(
model_name.label("model"),
func.coalesce(RunRow.model_name, "unknown").label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
@@ -237,7 +222,7 @@ class RunRepository(RunStore):
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
)
.where(_thread, _completed)
.group_by(model_name)
.group_by(func.coalesce(RunRow.model_name, "unknown"))
)
async with self._sf() as session:
@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
__all__ = [
"InvalidMetadataFilterError",
"MemoryThreadMetaStore",
"ThreadMetaRepository",
"ThreadMetaRow",
@@ -15,15 +15,10 @@ three-state semantics (see :mod:`deerflow.runtime.user_context`):
from __future__ import annotations
import abc
from typing import Any
from deerflow.runtime.user_context import AUTO, _AutoSentinel
class InvalidMetadataFilterError(ValueError):
"""Raised when all client-supplied metadata filter keys are rejected."""
class ThreadMetaStore(abc.ABC):
@abc.abstractmethod
async def create(
@@ -45,12 +40,12 @@ class ThreadMetaStore(abc.ABC):
async def search(
self,
*,
metadata: dict[str, Any] | None = None,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict[str, Any]]:
) -> list[dict]:
pass
@abc.abstractmethod
@@ -69,12 +69,12 @@ class MemoryThreadMetaStore(ThreadMetaStore):
async def search(
self,
*,
metadata: dict[str, Any] | None = None,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict[str, Any]]:
) -> list[dict]:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
filter_dict: dict[str, Any] = {}
if metadata:
@@ -2,20 +2,16 @@
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.json_compat import json_match
from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
logger = logging.getLogger(__name__)
class ThreadMetaRepository(ThreadMetaStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -24,7 +20,7 @@ class ThreadMetaRepository(ThreadMetaStore):
@staticmethod
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
d = row.to_dict()
d["metadata"] = d.pop("metadata_json", None) or {}
d["metadata"] = d.pop("metadata_json", {})
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
@@ -108,43 +104,39 @@ class ThreadMetaRepository(ThreadMetaStore):
async def search(
self,
*,
metadata: dict[str, Any] | None = None,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict[str, Any]]:
) -> list[dict]:
"""Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user
context. Pass ``user_id=None`` to bypass (migration/CLI).
"""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc(), ThreadMetaRow.thread_id.desc())
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_user_id is not None:
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
if status:
stmt = stmt.where(ThreadMetaRow.status == status)
if metadata:
applied = 0
for key, value in metadata.items():
try:
stmt = stmt.where(json_match(ThreadMetaRow.metadata_json, key, value))
applied += 1
except (ValueError, TypeError) as exc:
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
if applied == 0:
# Comma-separated plain string (no list repr / nested
# quoting) so the 400 detail surfaced by the Gateway is
# easy for clients to read. Sorted for determinism.
rejected_keys = ", ".join(sorted(str(k) for k in metadata))
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
stmt = stmt.limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
# When metadata filter is active, fetch a larger window and filter
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
# SQLite json_extract) for server-side filtering.
stmt = stmt.limit(limit * 5 + offset)
async with self._sf() as session:
result = await session.execute(stmt)
rows = [self._row_to_dict(r) for r in result.scalars()]
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
return rows[offset : offset + limit]
else:
stmt = stmt.limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
"""Return True if the row exists and is owned (or filter bypassed)."""
@@ -11,7 +11,7 @@ import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select, text
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
@@ -86,28 +86,6 @@ class DbRunEventStore(RunEventStore):
user = get_current_user()
return str(user.id) if user is not None else None
@staticmethod
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
"""Return the current max seq while serializing writers per thread.
PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
results are not lockable rows. As a release-safe workaround, take a
transaction-level advisory lock keyed by thread_id before reading the
aggregate. Other dialects keep the existing row-locking statement.
"""
stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
bind = session.get_bind()
dialect_name = bind.dialect.name if bind is not None else ""
if dialect_name == "postgresql":
await session.execute(
text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
{"thread_id": thread_id},
)
return await session.scalar(stmt)
return await session.scalar(stmt.with_for_update())
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
@@ -122,7 +100,10 @@ class DbRunEventStore(RunEventStore):
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
max_seq = await self._max_seq_for_thread(session, thread_id)
# Use FOR UPDATE to serialize seq assignment within a thread.
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
@@ -145,8 +126,10 @@ class DbRunEventStore(RunEventStore):
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
thread_id = events[0]["thread_id"]
max_seq = await self._max_seq_for_thread(session, thread_id)
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
seq = max_seq or 0
rows = []
for e in events:
@@ -20,13 +20,12 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Mapping
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.types import Command
if TYPE_CHECKING:
@@ -64,16 +63,6 @@ class RunJournal(BaseCallbackHandler):
self._total_tokens = 0
self._llm_call_count = 0
# Caller-bucketed token accumulators
self._lead_agent_tokens = 0
self._subagent_tokens = 0
self._middleware_tokens = 0
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
self._counted_llm_run_ids: set[str] = set()
self._counted_external_source_ids: set[str] = set()
self._counted_message_llm_run_ids: set[str] = set()
# Convenience fields
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
@@ -88,50 +77,6 @@ class RunJournal(BaseCallbackHandler):
# -- Lifecycle callbacks --
@staticmethod
def _message_text(message: BaseMessage) -> str:
"""Extract displayable text from a message's mixed content shape."""
content = getattr(message, "content", None)
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, Mapping):
text = block.get("text")
if isinstance(text, str):
parts.append(text)
else:
nested = block.get("content")
if isinstance(nested, str):
parts.append(nested)
return "".join(parts)
if isinstance(content, Mapping):
for key in ("text", "content"):
value = content.get(key)
if isinstance(value, str):
return value
text = getattr(message, "text", None)
if isinstance(text, str):
return text
return ""
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
"""Update run-level convenience fields for persisted run rows."""
self._msg_count += 1
# ``last_ai_message`` should represent the lead agent's user-facing
# answer. Middleware/subagent model calls and empty tool-call-only
# AI messages must not overwrite the last useful assistant text.
is_ai_message = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"
if is_ai_message and (caller is None or caller == "lead_agent"):
text = self._message_text(message).strip()
if text:
self._last_ai_msg = text[:2000]
def on_chain_start(
self,
serialized: dict[str, Any],
@@ -210,7 +155,6 @@ class RunJournal(BaseCallbackHandler):
content=m.model_dump(),
metadata={"caller": caller},
)
self._record_message_summary(m, caller=caller)
break
if self._first_human_msg:
break
@@ -269,34 +213,20 @@ class RunJournal(BaseCallbackHandler):
"llm_call_index": call_index,
},
)
if rid not in self._counted_message_llm_run_ids:
self._record_message_summary(message, caller=caller)
# Token accumulation (dedup by langchain run_id to avoid double-counting
# when the callback fires more than once for the same response)
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0 and rid not in self._counted_llm_run_ids:
self._counted_llm_run_ids.add(rid)
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
if messages:
self._counted_message_llm_run_ids.add(str(run_id))
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm.error", category="trace", content=str(error))
@@ -312,14 +242,12 @@ class RunJournal(BaseCallbackHandler):
if isinstance(output, ToolMessage):
msg = cast(ToolMessage, output)
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
self._record_message_summary(msg)
elif isinstance(output, Command):
cmd = cast(Command, output)
messages = cmd.update.get("messages", [])
for message in messages:
if isinstance(message, BaseMessage):
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
self._record_message_summary(message)
else:
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
else:
@@ -402,49 +330,6 @@ class RunJournal(BaseCallbackHandler):
# -- Public methods (called by worker) --
def record_external_llm_usage_records(
self,
records: list[dict[str, int | str]],
) -> None:
"""Record token usage from external sources (e.g., subagents).
Each record should contain:
source_run_id: Unique identifier to prevent double-counting
caller: Caller tag (e.g. "subagent:general-purpose")
input_tokens: Input token count
output_tokens: Output token count
total_tokens: Total token count (computed from input+output if 0/missing)
"""
if not self._track_tokens:
return
for record in records:
source_id = str(record.get("source_run_id", ""))
if not source_id:
continue
if source_id in self._counted_external_source_ids:
continue
total_tk = record.get("total_tokens", 0) or 0
if total_tk <= 0:
input_tk = record.get("input_tokens", 0) or 0
output_tk = record.get("output_tokens", 0) or 0
total_tk = input_tk + output_tk
if total_tk <= 0:
continue
self._counted_external_source_ids.add(source_id)
self._total_input_tokens += record.get("input_tokens", 0) or 0
self._total_output_tokens += record.get("output_tokens", 0) or 0
self._total_tokens += total_tk
caller = str(record.get("caller", ""))
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
@@ -491,9 +376,6 @@ class RunJournal(BaseCallbackHandler):
"total_output_tokens": self._total_output_tokens,
"total_tokens": self._total_tokens,
"llm_call_count": self._llm_call_count,
"lead_agent_tokens": self._lead_agent_tokens,
"subagent_tokens": self._subagent_tokens,
"middleware_tokens": self._middleware_tokens,
"message_count": self._msg_count,
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
@@ -36,7 +36,6 @@ class RunRecord:
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
abort_action: str = "interrupt"
error: str | None = None
model_name: str | None = None
class RunManager:
@@ -66,7 +65,6 @@ class RunManager:
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
model_name=record.model_name,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -139,18 +137,6 @@ class RunManager:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def update_model_name(self, run_id: str, model_name: str | None) -> None:
"""Update the model name for a run."""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
logger.warning("update_model_name called for unknown run %s", run_id)
return
record.model_name = model_name
record.updated_at = _now_iso()
await self._persist_to_store(record)
logger.info("Run %s model_name=%s", run_id, model_name)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
"""Request cancellation of a run.
@@ -185,7 +171,6 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
model_name: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -236,7 +221,6 @@ class RunManager:
kwargs=kwargs or {},
created_at=now,
updated_at=now,
model_name=model_name,
)
self._runs[run_id] = record
@@ -23,7 +23,6 @@ class RunStore(abc.ABC):
thread_id: str,
assistant_id: str | None = None,
user_id: str | None = None,
model_name: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
@@ -22,7 +22,6 @@ class MemoryRunStore(RunStore):
thread_id,
assistant_id=None,
user_id=None,
model_name=None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -36,7 +35,6 @@ class MemoryRunStore(RunStore):
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": user_id,
"model_name": model_name,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
@@ -230,17 +230,6 @@ async def run_agent(
else:
agent = agent_factory(config=runnable_config)
# Capture the effective (resolved) model name from the agent's metadata.
# _resolve_model_name in agent.py may return the default model if the
# requested name is not in the allowlist — this update ensures the
# persisted model_name reflects the actual model used.
if record.model_name is not None:
resolved = getattr(agent, "metadata", {}) or {}
if isinstance(resolved, dict):
effective = resolved.get("model_name")
if effective and effective != record.model_name:
await run_manager.update_model_name(record.run_id, effective)
# 4. Attach checkpointer and store
if checkpointer is not None:
agent.checkpointer = checkpointer
@@ -109,34 +109,6 @@ def get_effective_user_id() -> str:
return str(user.id)
def resolve_runtime_user_id(runtime: object | None) -> str:
"""Single source of truth for a tool/middleware's effective user_id.
Resolution order (most authoritative first):
1. ``runtime.context["user_id"]`` set by ``inject_authenticated_user_context``
in the gateway from the auth-validated ``request.state.user``. This is
the only source that survives boundaries where the contextvar may have
been lost (background tasks scheduled outside the request task,
worker pools that don't copy_context, future cross-process drivers).
2. The ``_current_user`` ContextVar set by the auth middleware at
request entry. Reliable for in-task work; copied by ``asyncio``
child tasks and by ``ContextThreadPoolExecutor``.
3. ``DEFAULT_USER_ID`` last-resort fallback so unauthenticated
CLI / migration / test paths keep working without raising.
Tools that persist user-scoped state (custom agents, memory, uploads)
MUST call this instead of ``get_effective_user_id()`` directly so they
benefit from the runtime.context channel that ``setup_agent`` already
relies on.
"""
context = getattr(runtime, "context", None)
if isinstance(context, dict):
ctx_user_id = context.get("user_id")
if ctx_user_id:
return str(ctx_user_id)
return get_effective_user_id()
# ---------------------------------------------------------------------------
# Sentinel-based user_id resolution
# ---------------------------------------------------------------------------
@@ -119,13 +119,3 @@ class LocalSandboxProvider(SandboxProvider):
# For Docker-based providers (e.g., AioSandboxProvider), cleanup
# happens at application shutdown via the shutdown() method.
pass
def reset(self) -> None:
# reset_sandbox_provider() must also clear the module singleton.
global _singleton
_singleton = None
def shutdown(self) -> None:
# LocalSandboxProvider has no extra resources beyond the shared
# singleton, so shutdown uses the same cleanup path as reset.
self.reset()
@@ -37,10 +37,6 @@ class SandboxProvider(ABC):
"""
pass
def reset(self) -> None:
"""Clear cached state that survives provider instance replacement."""
pass
_default_sandbox_provider: SandboxProvider | None = None
@@ -69,18 +65,11 @@ def reset_sandbox_provider() -> None:
The next call to `get_sandbox_provider()` will create a new instance.
Useful for testing or when switching configurations.
Providers can override `reset()` to clear any module-level state they keep
alive across instances (for example, `LocalSandboxProvider`'s cached
`LocalSandbox` singleton). Without it, config/mount changes would not take
effect on the next acquire().
Note: If the provider has active sandboxes, they will be orphaned.
Use `shutdown_sandbox_provider()` for proper cleanup.
"""
global _default_sandbox_provider
if _default_sandbox_provider is not None:
_default_sandbox_provider.reset()
_default_sandbox_provider = None
_default_sandbox_provider = None
def shutdown_sandbox_provider() -> None:
@@ -1499,13 +1499,12 @@ def write_file_tool(
content: str,
append: bool = False,
) -> str:
"""Write text content to a file. By default this overwrites the target file; set append to true to add content to the end without replacing existing content.
"""Write text content to a file.
Args:
description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
append: Whether to append content to the end of the file instead of overwriting it. Defaults to false.
"""
try:
sandbox = ensure_sandbox_initialized(runtime)
@@ -26,7 +26,7 @@ class SubagentConfig:
name: str
description: str
system_prompt: str | None = None
system_prompt: str
tools: list[str] | None = None
disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"])
skills: list[str] | None = None
@@ -26,7 +26,6 @@ from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
from deerflow.subagents.token_collector import SubagentTokenCollector
logger = logging.getLogger(__name__)
@@ -71,8 +70,6 @@ class SubagentResult:
started_at: datetime | None = None
completed_at: datetime | None = None
ai_messages: list[dict[str, Any]] | None = None
token_usage_records: list[dict[str, int | str]] = field(default_factory=list)
usage_reported: bool = False
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
def __post_init__(self):
@@ -286,13 +283,11 @@ class SubagentExecutor:
# Reuse shared middleware composition with lead agent.
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True)
# system_prompt is included in initial state messages (see _build_initial_state)
# to avoid multiple SystemMessages which some LLM APIs don't support.
return create_agent(
model=model,
tools=tools if tools is not None else self.tools,
middleware=middlewares,
system_prompt=None,
system_prompt=self.config.system_prompt,
state_schema=ThreadState,
)
@@ -367,25 +362,14 @@ class SubagentExecutor:
Returns:
Initial state dictionary and tools filtered by loaded skill metadata.
"""
# Load skills as conversation items (Codex pattern)
skills = await self._load_skills()
filtered_tools = self._apply_skill_allowed_tools(skills)
skill_messages = await self._load_skill_messages(skills)
# Combine system_prompt and skills into a single SystemMessage.
# Some LLM APIs reject multiple SystemMessages with
# "System message must be at the beginning."
system_parts: list[str] = []
if self.config.system_prompt:
system_parts.append(self.config.system_prompt)
for skill_msg in skill_messages:
system_parts.append(skill_msg.content)
messages: list[Any] = []
if system_parts:
messages.append(SystemMessage(content="\n\n".join(system_parts)))
# Skill content injected as developer/system messages before the task
messages.extend(skill_messages)
# Then the actual task
messages.append(HumanMessage(content=task))
@@ -428,20 +412,13 @@ class SubagentExecutor:
ai_messages = []
result.ai_messages = ai_messages
collector: SubagentTokenCollector | None = None
try:
state, filtered_tools = await self._build_initial_state(task)
agent = self._create_agent(filtered_tools)
# Token collector for subagent LLM calls
collector_caller = f"subagent:{self.config.name}"
collector = SubagentTokenCollector(caller=collector_caller)
# Build config with thread_id for sandbox access and recursion limit
run_config: RunnableConfig = {
"recursion_limit": self.config.max_turns,
"callbacks": [collector],
"tags": [collector_caller],
}
context: dict[str, Any] = {}
if self.thread_id:
@@ -464,8 +441,6 @@ class SubagentExecutor:
result.status = SubagentStatus.CANCELLED
result.error = "Cancelled by user"
result.completed_at = datetime.now()
if collector is not None:
result.token_usage_records = collector.snapshot_records()
return result
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
@@ -480,7 +455,6 @@ class SubagentExecutor:
result.status = SubagentStatus.CANCELLED
result.error = "Cancelled by user"
result.completed_at = datetime.now()
result.token_usage_records = collector.snapshot_records()
return result
final_state = chunk
@@ -507,7 +481,6 @@ class SubagentExecutor:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
result.token_usage_records = collector.snapshot_records()
if final_state is None:
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
@@ -587,8 +560,6 @@ class SubagentExecutor:
result.status = SubagentStatus.FAILED
result.error = str(e)
result.completed_at = datetime.now()
if collector is not None:
result.token_usage_records = collector.snapshot_records()
return result
@@ -1,63 +0,0 @@
"""Callback handler that collects LLM token usage within a subagent.
Each subagent execution creates its own collector. After the subagent
finishes, the collected records are transferred to the parent RunJournal
via :meth:`RunJournal.record_external_llm_usage_records`.
"""
from __future__ import annotations
from typing import Any
from langchain_core.callbacks import BaseCallbackHandler
class SubagentTokenCollector(BaseCallbackHandler):
"""Lightweight callback handler that collects LLM token usage within a subagent."""
def __init__(self, caller: str):
super().__init__()
self.caller = caller
self._records: list[dict[str, int | str]] = []
self._counted_run_ids: set[str] = set()
def on_llm_end(
self,
response: Any,
*,
run_id: Any,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
rid = str(run_id)
if rid in self._counted_run_ids:
return
for generation in response.generations:
for gen in generation:
if not hasattr(gen, "message"):
continue
usage = getattr(gen.message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk <= 0:
total_tk = input_tk + output_tk
if total_tk <= 0:
continue
self._counted_run_ids.add(rid)
self._records.append(
{
"source_run_id": rid,
"caller": self.caller,
"input_tokens": input_tk,
"output_tokens": output_tk,
"total_tokens": total_tk,
}
)
return
def snapshot_records(self) -> list[dict[str, int | str]]:
"""Return a copy of the accumulated usage records."""
return list(self._records)
@@ -7,13 +7,20 @@ from langgraph.types import Command
from deerflow.config.agents_config import validate_agent_name
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import resolve_runtime_user_id
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@tool(parse_docstring=True)
def _get_runtime_user_id(runtime: Runtime) -> str:
context_user_id = runtime.context.get("user_id") if runtime.context else None
if context_user_id:
return str(context_user_id)
return get_effective_user_id()
@tool
def setup_agent(
soul: str,
description: str,
@@ -38,7 +45,7 @@ def setup_agent(
if agent_name:
# Custom agents are persisted under the current user's bucket so
# different users do not see each other's agents.
user_id = resolve_runtime_user_id(runtime)
user_id = _get_runtime_user_id(runtime)
agent_dir = paths.user_agent_dir(user_id, agent_name)
else:
# Default agent (no agent_name): SOUL.md lives at the global base dir.
@@ -26,125 +26,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
# write it back to the triggering AIMessage's usage_metadata.
_subagent_usage_cache: dict[str, dict[str, int]] = {}
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
if app_config is None:
try:
app_config = get_app_config()
except (FileNotFoundError, ValueError):
return False
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
if enabled and usage:
_subagent_usage_cache[tool_call_id] = usage
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
return _subagent_usage_cache.pop(tool_call_id, None)
def _is_subagent_terminal(result: Any) -> bool:
"""Return whether a background subagent result is safe to clean up."""
return result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None
async def _await_subagent_terminal(task_id: str, max_polls: int) -> Any | None:
"""Poll until the background subagent reaches a terminal status or we run out of polls."""
for _ in range(max_polls):
result = get_background_task_result(task_id)
if result is None:
return None
if _is_subagent_terminal(result):
return result
await asyncio.sleep(5)
return None
async def _deferred_cleanup_subagent_task(task_id: str, trace_id: str, max_polls: int) -> None:
"""Keep polling a cancelled subagent until it can be safely removed."""
cleanup_poll_count = 0
while True:
result = get_background_task_result(task_id)
if result is None:
return
if _is_subagent_terminal(result):
cleanup_background_task(task_id)
return
if cleanup_poll_count >= max_polls:
logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
return
await asyncio.sleep(5)
cleanup_poll_count += 1
def _log_cleanup_failure(cleanup_task: asyncio.Task[None], *, trace_id: str, task_id: str) -> None:
if cleanup_task.cancelled():
return
exc = cleanup_task.exception()
if exc is not None:
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: int) -> None:
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
cleanup_task = asyncio.create_task(_deferred_cleanup_subagent_task(task_id, trace_id, max_polls))
cleanup_task.add_done_callback(lambda task: _log_cleanup_failure(task, trace_id=trace_id, task_id=task_id))
def _find_usage_recorder(runtime: Any) -> Any | None:
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
callbacks = config.get("callbacks", [])
if not callbacks:
return None
for cb in callbacks:
if hasattr(cb, "record_external_llm_usage_records"):
return cb
return None
def _summarize_usage(records: list[dict] | None) -> dict | None:
"""Summarize token usage records into a compact dict for SSE events."""
if not records:
return None
return {
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
}
def _report_subagent_usage(runtime: Any, result: Any) -> None:
"""Report subagent token usage to the parent RunJournal, if available.
Each subagent task must be reported only once (guarded by usage_reported).
"""
if getattr(result, "usage_reported", True):
return
records = getattr(result, "token_usage_records", None) or []
if not records:
return
journal = _find_usage_recorder(runtime)
if journal is None:
logger.debug("No usage recorder found in runtime callbacks — subagent token usage not recorded")
return
try:
journal.record_external_llm_usage_records(records)
result.usage_reported = True
except Exception:
logger.warning("Failed to report subagent token usage", exc_info=True)
def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
context = getattr(runtime, "context", None)
@@ -210,7 +91,6 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
runtime_app_config = _get_runtime_app_config(runtime)
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration
@@ -346,32 +226,23 @@ async def task_tool(
last_message_count = current_message_count
# Check if task completed, failed, or timed out
usage = _summarize_usage(getattr(result, "token_usage_records", None))
if result.status == SubagentStatus.COMPLETED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id)
return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}"
@@ -389,34 +260,43 @@ async def task_tool(
if poll_count > max_poll_count:
timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
_report_subagent_usage(runtime, result)
usage = _summarize_usage(getattr(result, "token_usage_records", None))
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
writer({"type": "task_timed_out", "task_id": task_id})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively.
# Without this, the thread (running in ThreadPoolExecutor with its
# own event loop via asyncio.run) would continue executing even
# after the parent task is cancelled.
request_cancel_background_task(task_id)
# Wait (shielded) for the subagent to reach a terminal state so the
# final token usage snapshot is reported to the parent RunJournal
# before the parent worker persists get_completion_data().
terminal_result = None
try:
terminal_result = await asyncio.shield(_await_subagent_terminal(task_id, max_poll_count))
except asyncio.CancelledError:
pass
async def cleanup_when_done() -> None:
max_cleanup_polls = max_poll_count
cleanup_poll_count = 0
# Report whatever the subagent collected (even if we timed out).
final_result = terminal_result or get_background_task_result(task_id)
if final_result is not None:
_report_subagent_usage(runtime, final_result)
if final_result is not None and _is_subagent_terminal(final_result):
cleanup_background_task(task_id)
else:
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
_subagent_usage_cache.pop(tool_call_id, None)
raise
except Exception:
_subagent_usage_cache.pop(tool_call_id, None)
while True:
result = get_background_task_result(task_id)
if result is None:
return
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
cleanup_background_task(task_id)
return
if cleanup_poll_count > max_cleanup_polls:
logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
return
await asyncio.sleep(5)
cleanup_poll_count += 1
def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
if cleanup_task.cancelled():
return
exc = cleanup_task.exception()
if exc is not None:
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
raise
@@ -27,7 +27,7 @@ from langgraph.types import Command
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import resolve_runtime_user_id
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ def _cleanup_temps(temps: list[Path]) -> None:
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
@tool(parse_docstring=True)
@tool
def update_agent(
runtime: Runtime,
soul: str | None = None,
@@ -118,13 +118,9 @@ def update_agent(
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
# Resolve the active user so that updates only affect this user's agent.
# ``resolve_runtime_user_id`` prefers ``runtime.context["user_id"]`` (set by
# the gateway from the auth-validated request) and falls back to the
# contextvar, then DEFAULT_USER_ID. This matches setup_agent so a user
# creating an agent and later refining it always touches the same files,
# even if the contextvar gets lost across an async/thread boundary
# (issue #2782 / #2862 class of bugs).
user_id = resolve_runtime_user_id(runtime)
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
# is set (matching how memory and thread storage behave).
user_id = get_effective_user_id()
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
# ``_resolve_model_name`` silently falls back to the default at runtime
@@ -10,11 +10,11 @@ from weakref import WeakValueDictionary
from langchain.tools import tool
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
from deerflow.tools.sync import make_sync_tool_wrapper
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@@ -235,4 +235,4 @@ async def skill_manage_tool(
)
skill_manage_tool.func = make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
skill_manage_tool.func = _make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
@@ -1,36 +0,0 @@
"""Utilities for invoking async tools from synchronous agent paths."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
logger = logging.getLogger(__name__)
# Shared thread pool for sync tool invocation in async environments.
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="tool-sync")
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
try:
if loop is not None and loop.is_running():
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
return future.result()
return asyncio.run(coro(*args, **kwargs))
except Exception as e:
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
raise
return sync_wrapper
@@ -7,8 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.sync import make_sync_tool_wrapper
from deerflow.tools.builtins.tool_search import reset_deferred_registry
logger = logging.getLogger(__name__)
@@ -34,13 +33,6 @@ def _is_host_bash_tool(tool: object) -> bool:
return False
def _ensure_sync_invocable_tool(tool: BaseTool) -> BaseTool:
"""Attach a sync wrapper to async-only tools used by sync agent callers."""
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tool
def get_available_tools(
groups: list[str] | None = None,
include_mcp: bool = True,
@@ -85,7 +77,7 @@ def get_available_tools(
cfg.use,
)
loaded_tools = [_ensure_sync_invocable_tool(t) for _, t in loaded_tools_raw]
loaded_tools = [t for _, t in loaded_tools_raw]
# Conditionally add tools based on config
builtin_tools = BUILTIN_TOOLS.copy()
@@ -116,6 +108,8 @@ def get_available_tools(
# made through the Gateway API (which runs in a separate process) are immediately
# reflected when loading MCP tools.
mcp_tools = []
# Reset deferred registry upfront to prevent stale state from previous calls
reset_deferred_registry()
if include_mcp:
try:
from deerflow.config.extensions_config import ExtensionsConfig
@@ -133,51 +127,12 @@ def get_available_tools(
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
# Reuse the existing registry if one is already set for
# this async context. ``get_available_tools`` is
# re-entered whenever a subagent is spawned
# (``task_tool`` calls it to build the child agent's
# toolset), and previously we used to unconditionally
# rebuild the registry — wiping out the parent agent's
# tool_search promotions. The
# ``DeferredToolFilterMiddleware`` then re-hid those
# tools from subsequent model calls, leaving the agent
# able to see a tool's name but unable to invoke it
# (issue #2884). ``contextvars`` already gives us the
# lifetime semantics we want: a fresh request / graph
# run starts in a new asyncio task with the
# ContextVar at its default of ``None``, so reuse is
# only triggered for re-entrant calls inside one run.
#
# Intentionally NOT reconciling against the current
# ``mcp_tools`` snapshot. The MCP cache only refreshes
# on ``extensions_config.json`` mtime changes, which
# in practice happens between graph runs — not inside
# one. And even if a refresh did happen mid-run, the
# already-built lead agent's ``ToolNode`` still holds
# the *previous* tool set (LangGraph binds tools at
# graph construction time), so a brand-new MCP tool
# couldn't actually be invoked anyway. The
# ``DeferredToolRegistry`` doesn't retain the names
# of previously-promoted tools (``promote()`` drops
# the entry entirely), so re-syncing the registry
# against a fresh ``mcp_tools`` list would
# mis-classify those promotions as new tools and
# re-register them as deferred — exactly the bug
# this fix exists to prevent.
existing_registry = get_deferred_registry()
if existing_registry is None:
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
else:
mcp_tool_names = {t.name for t in mcp_tools}
still_deferred = len(existing_registry)
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
builtin_tools.append(tool_search_tool)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e:
-35
View File
@@ -1,35 +0,0 @@
[project]
name = "deerflow-storage"
version = "0.1.0"
description = "DeerFlow storage framework"
requires-python = ">=3.12"
dependencies = [
"dotenv>=0.9.9",
"pydantic>=2.12.5",
"pyyaml>=6.0.3",
"sqlalchemy[asyncio]>=2.0,<3.0",
"alembic>=1.13",
"langgraph>=1.1.9",
]
[project.optional-dependencies]
postgres = [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
mysql = [
"aiomysql>=0.2",
"langgraph-checkpoint-mysql>=3.0.0",
]
sqlite = [
"aiosqlite>=0.22.1",
"langgraph-checkpoint-sqlite>=3.0.3"
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["store"]
@@ -1,5 +0,0 @@
from .enums import DataBaseType
__all__ = [
"DataBaseType",
]
@@ -1,41 +0,0 @@
from enum import Enum
from enum import IntEnum as SourceIntEnum
from enum import StrEnum as SourceStrEnum
from typing import Any, TypeVar
T = TypeVar("T", bound=Enum)
class _EnumBase:
"""Base enum class with common utility methods."""
@classmethod
def get_member_keys(cls) -> list[str]:
"""Return a list of enum member names."""
return list(cls.__members__.keys())
@classmethod
def get_member_values(cls) -> list:
"""Return a list of enum member values."""
return [item.value for item in cls.__members__.values()]
@classmethod
def get_member_dict(cls) -> dict[str, Any]:
"""Return a dict mapping member names to values."""
return {name: item.value for name, item in cls.__members__.items()}
class IntEnum(_EnumBase, SourceIntEnum):
"""Integer enum base class."""
class StrEnum(_EnumBase, SourceStrEnum):
"""String enum base class."""
class DataBaseType(StrEnum):
"""Database type."""
sqlite = "sqlite"
mysql = "mysql"
postgresql = "postgresql"
@@ -1,286 +0,0 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from store.config.storage_config import StorageConfig
load_dotenv()
logger = logging.getLogger(__name__)
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
cwd = Path.cwd().resolve()
candidates = (
cwd / "config.yaml",
backend_dir / "config.yaml",
repo_root / "config.yaml",
)
return tuple(dict.fromkeys(candidates))
def _storage_from_database_config(config_data: dict[str, Any]) -> None:
"""Keep the existing public `database:` config compatible with storage."""
if "storage" in config_data:
return
database = config_data.get("database")
if not isinstance(database, dict):
return
backend = database.get("backend")
if backend == "memory":
raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")
storage: dict[str, Any] = {
"driver": "postgres" if backend == "postgres" else backend,
"sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"),
"echo_sql": database.get("echo_sql", False),
"pool_size": database.get("pool_size", 5),
}
postgres_url = database.get("postgres_url")
if backend == "postgres" and isinstance(postgres_url, str) and postgres_url:
from sqlalchemy.engine.url import make_url
parsed = make_url(postgres_url)
storage["database_url"] = postgres_url
storage.update(
{
"username": parsed.username or "",
"password": parsed.password or "",
"host": parsed.host or "localhost",
"port": parsed.port or 5432,
"db_name": parsed.database or "deerflow",
}
)
config_data["storage"] = storage
class AppConfig(BaseModel):
"""DeerFlow application configuration."""
timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
storage: StorageConfig = Field(default=StorageConfig())
model_config = ConfigDict(extra="allow", frozen=False)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
"""Resolve the config file path.
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
"""
if config_path:
path = Path(config_path)
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
return path
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
for path in _default_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
"""Load and validate config from YAML. See `resolve_config_path` for path resolution."""
resolved_path = cls.resolve_config_path(config_path)
with open(resolved_path, encoding="utf-8") as f:
config_data = yaml.safe_load(f) or {}
cls._check_config_version(config_data, resolved_path)
config_data = cls.resolve_env_variables(config_data)
_storage_from_database_config(config_data)
if os.getenv("TIMEZONE"):
config_data["timezone"] = os.getenv("TIMEZONE")
result = cls.model_validate(config_data)
return result
@classmethod
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
Emits a warning if the user's config_version is lower than the example's.
Missing config_version is treated as version 0 (pre-versioning).
"""
try:
user_version = int(config_data.get("config_version", 0))
except (TypeError, ValueError):
user_version = 0
# Find config.example.yaml by searching config.yaml's directory and its parents
example_path = None
search_dir = config_path.parent
for _ in range(5): # search up to 5 levels
candidate = search_dir / "config.example.yaml"
if candidate.exists():
example_path = candidate
break
parent = search_dir.parent
if parent == search_dir:
break
search_dir = parent
if example_path is None:
return
try:
with open(example_path, encoding="utf-8") as f:
example_data = yaml.safe_load(f)
raw = example_data.get("config_version", 0) if example_data else 0
try:
example_version = int(raw)
except (TypeError, ValueError):
example_version = 0
except Exception:
return
if user_version < example_version:
logger.warning(
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
user_version,
example_version,
)
@classmethod
def resolve_env_variables(cls, config: Any) -> Any:
"""Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
if isinstance(config, str):
if config.startswith("$"):
env_value = os.getenv(config[1:])
if env_value is None:
raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
return env_value
return config
elif isinstance(config, dict):
return {k: cls.resolve_env_variables(v) for k, v in config.items()}
elif isinstance(config, list):
return [cls.resolve_env_variables(item) for item in config]
return config
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Force reload from file and update the cache."""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Clear the cache so the next `get_app_config()` reloads from file."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Inject a config instance directly, bypassing file loading (for testing)."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -1,69 +0,0 @@
"""Unified storage backend configuration for checkpointer and application data.
SQLite: checkpointer {sqlite_dir}/checkpoints.db, app {sqlite_dir}/deerflow.db
(separate files to avoid write-lock contention)
Postgres: shared URL, independent connection pools per layer.
Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, Field
def _strip_legacy_state_prefix(path: str) -> str:
"""Keep old .deer-flow/* config values compatible with Paths.base_dir."""
prefix = ".deer-flow/"
if path == ".deer-flow":
return "."
if path.startswith(prefix):
return path[len(prefix) :]
return path
class StorageConfig(BaseModel):
driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
default="sqlite",
description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description="Directory for SQLite .db files (sqlite driver only).",
)
username: str = Field(default="", description="db username ")
password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
host: str = Field(default="localhost", description="db host.")
port: int = Field(default=5432, description="db port.")
db_name: str = Field(default="deerflow", description="db database name.")
database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
pool_size: int = Field(default=5, description="Connection pool size per layer.")
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path under DeerFlow's base dir."""
from pathlib import Path
path = Path(self.sqlite_dir)
if path.is_absolute():
return str(path.resolve())
try:
from deerflow.config.paths import resolve_path
return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))
except ImportError:
return str(path.resolve())
@property
def sqlite_storage_path(self) -> str:
"""SQLite file path for storage-owned app data and checkpointer."""
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
@@ -1,32 +0,0 @@
from store.persistence.base_model import (
Base,
DataClassBase,
DateTimeMixin,
MappedBase,
TimeZone,
UniversalText,
id_key,
)
from .factory import (
create_persistence,
create_persistence_from_database_config,
create_persistence_from_storage_config,
storage_config_from_database_config,
)
from .types import AppPersistence
__all__ = [
"Base",
"DataClassBase",
"DateTimeMixin",
"MappedBase",
"TimeZone",
"UniversalText",
"id_key",
"create_persistence",
"create_persistence_from_database_config",
"create_persistence_from_storage_config",
"storage_config_from_database_config",
"AppPersistence",
]
@@ -1,111 +0,0 @@
from datetime import datetime
from typing import Annotated
from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
from store.utils import get_timezone
def current_time() -> datetime:
return get_timezone().now()
id_key = Annotated[
int,
mapped_column(
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
unique=True,
index=True,
autoincrement=True,
sort_order=-999,
comment="Primary key ID",
),
]
class UniversalText(TypeDecorator[str]):
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect): # noqa: ANN001
if dialect.name == "mysql":
return dialect.type_descriptor(LONGTEXT())
return dialect.type_descriptor(Text())
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
def process_result_value(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
class TimeZone(TypeDecorator[datetime]):
"""Timezone-aware datetime type compatible with PostgreSQL and MySQL."""
impl = DateTime(timezone=True)
cache_ok = True
@property
def python_type(self) -> type[datetime]:
return datetime
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.utcoffset() != timezone.now().utcoffset():
value = timezone.from_datetime(value)
return value
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.tzinfo is None:
value = value.replace(tzinfo=timezone.tz_info)
return value
class DateTimeMixin(MappedAsDataclass):
"""Mixin that adds created_time / updated_time columns."""
created_time: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
TimeZone,
init=False,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
class MappedBase(AsyncAttrs, DeclarativeBase):
"""Async-capable declarative base for all ORM models."""
@declared_attr.directive
def __tablename__(self) -> str:
return self.__name__.lower()
@declared_attr.directive
def __table_args__(self) -> dict:
return {"comment": self.__doc__ or ""}
class DataClassBase(MappedAsDataclass, MappedBase):
"""Declarative base with native dataclass integration."""
__abstract__ = True
class Base(DataClassBase, DateTimeMixin):
"""Declarative dataclass base with created_time / updated_time columns."""
__abstract__ = True
@@ -1,9 +0,0 @@
from .mysql import build_mysql_persistence
from .postgres import build_postgres_persistence
from .sqlite import build_sqlite_persistence
__all__ = [
"build_postgres_persistence",
"build_mysql_persistence",
"build_sqlite_persistence",
]
@@ -1,76 +0,0 @@
from __future__ import annotations
import json
from sqlalchemy import URL
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
def _validate_mysql_driver(db_url: URL) -> str:
url = make_url(db_url)
driver = url.get_driver_name()
if driver not in {"aiomysql", "asyncmy"}:
raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
return driver
def _checkpoint_conn_string(db_url: URL) -> str:
return db_url.render_as_string(hide_password=False)
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
_validate_mysql_driver(db_url)
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
pool_pre_ping=True,
pool_size=pool_size,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables / migrations
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -1,64 +0,0 @@
from __future__ import annotations
import json
from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
def _checkpoint_conn_string(db_url: URL) -> str:
return db_url.set(drivername="postgresql").render_as_string(hide_password=False)
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
pool_pre_ping=True,
pool_size=pool_size,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables / migrations
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -1,68 +0,0 @@
from __future__ import annotations
import json
from sqlalchemy import URL, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
@event.listens_for(engine.sync_engine, "connect")
def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001
cursor = dbapi_conn.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL;")
cursor.execute("PRAGMA synchronous=NORMAL;")
cursor.execute("PRAGMA foreign_keys=ON;")
finally:
cursor.close()
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -1,123 +0,0 @@
from typing import Any
from sqlalchemy import URL
from sqlalchemy.engine.url import make_url
from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.config.storage_config import StorageConfig
from store.persistence.types import AppPersistence
def storage_config_from_database_config(database_config: Any) -> StorageConfig:
"""Convert the existing public DatabaseConfig shape to StorageConfig.
Storage only owns durable database-backed persistence. The app bridge
should handle memory mode before calling into this package.
"""
backend = getattr(database_config, "backend", None)
if backend == "sqlite":
return StorageConfig(
driver="sqlite",
sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
echo_sql=getattr(database_config, "echo_sql", False),
pool_size=getattr(database_config, "pool_size", 5),
)
if backend == "postgres":
postgres_url = getattr(database_config, "postgres_url", "")
if not postgres_url:
raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
parsed = make_url(postgres_url)
return StorageConfig(
driver="postgres",
database_url=postgres_url,
username=parsed.username or "",
password=parsed.password or "",
host=parsed.host or "localhost",
port=parsed.port or 5432,
db_name=parsed.database or "deerflow",
echo_sql=getattr(database_config, "echo_sql", False),
pool_size=getattr(database_config, "pool_size", 5),
)
raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")
def _create_database_url(storage_config: StorageConfig) -> URL:
"""Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""
if storage_config.driver == DataBaseType.sqlite:
driver = "sqlite+aiosqlite"
elif storage_config.driver == DataBaseType.mysql:
driver = "mysql+aiomysql"
elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
driver = "postgresql+asyncpg"
else:
raise ValueError(f"Unsupported database driver: {storage_config.driver}")
if storage_config.driver == DataBaseType.sqlite:
import os
db_path = storage_config.sqlite_storage_path
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = URL.create(
drivername=driver,
database=db_path,
)
elif storage_config.database_url:
url = make_url(storage_config.database_url)
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
url = url.set(drivername="postgresql+asyncpg")
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
url = url.set(drivername="mysql+aiomysql")
else:
url = URL.create(
drivername=driver,
username=storage_config.username,
password=storage_config.password,
host=storage_config.host,
port=storage_config.port,
database=storage_config.db_name or "deerflow",
)
return url
async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
from .drivers.mysql import build_mysql_persistence
from .drivers.postgres import build_postgres_persistence
from .drivers.sqlite import build_sqlite_persistence
driver = storage_config.driver
db_url = _create_database_url(storage_config)
if driver in ("postgres", "postgresql"):
return await build_postgres_persistence(
db_url,
echo=storage_config.echo_sql,
pool_size=storage_config.pool_size,
)
if driver == "mysql":
return await build_mysql_persistence(
db_url,
echo=storage_config.echo_sql,
pool_size=storage_config.pool_size,
)
if driver == "sqlite":
return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)
raise ValueError(f"Unsupported database driver: {driver}")
async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
storage_config = storage_config_from_database_config(database_config)
return await create_persistence_from_storage_config(storage_config)
async def create_persistence() -> AppPersistence:
app_config = get_app_config()
return await create_persistence_from_storage_config(app_config.storage)
@@ -1,189 +0,0 @@
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
from sqlalchemy import BigInteger, Float, String, bindparam
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import Boolean, TypeEngine
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1
def validate_metadata_filter_key(key: object) -> bool:
"""Return True when *key* is safe for JSON metadata filter SQL paths."""
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
def validate_metadata_filter_value(value: object) -> bool:
"""Return True when *value* can be compiled into a portable JSON predicate."""
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
return False
if isinstance(value, int) and not isinstance(value, bool):
return _INT64_MIN <= value <= _INT64_MAX
return True
class JsonMatch(ColumnElement[bool]):
"""Dialect-portable ``column[key] == value`` for JSON columns."""
inherit_cache = True
type = Boolean()
_is_implicitly_boolean = True
_traverse_internals = [
("column", InternalTraversal.dp_clauseelement),
("key", InternalTraversal.dp_string),
("value", InternalTraversal.dp_plain_obj),
("value_type", InternalTraversal.dp_string),
]
def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
if not validate_metadata_filter_key(key):
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
if not validate_metadata_filter_value(value):
if isinstance(value, int) and not isinstance(value, bool):
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
self.column = column
self.key = key
self.value = value
self.value_type = type(value).__qualname__
super().__init__()
@dataclass(frozen=True)
class _Dialect:
null_type: str
num_types: tuple[str, ...]
num_cast: str
int_types: tuple[str, ...]
int_cast: str
int_guard: str | None
string_type: str
bool_type: str | None
true_value: str
false_value: str
_SQLITE = _Dialect(
null_type="null",
num_types=("integer", "real"),
num_cast="REAL",
int_types=("integer",),
int_cast="INTEGER",
int_guard=None,
string_type="text",
bool_type=None,
true_value="true",
false_value="false",
)
_POSTGRES = _Dialect(
null_type="null",
num_types=("number",),
num_cast="DOUBLE PRECISION",
int_types=("number",),
int_cast="BIGINT",
int_guard="'^-?[0-9]+$'",
string_type="string",
bool_type="boolean",
true_value="true",
false_value="false",
)
_MYSQL = _Dialect(
null_type="NULL",
num_types=("INTEGER", "DOUBLE", "DECIMAL"),
num_cast="DOUBLE",
int_types=("INTEGER",),
int_cast="SIGNED",
int_guard=None,
string_type="STRING",
bool_type="BOOLEAN",
true_value="true",
false_value="false",
)
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
param = bindparam(None, value, type_=sa_type)
return compiler.process(param, **kw)
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
if len(types) == 1:
return f"{typeof} = '{types[0]}'"
quoted = ", ".join(f"'{type_name}'" for type_name in types)
return f"{typeof} IN ({quoted})"
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
if value is None:
return f"{typeof} = '{dialect.null_type}'"
if isinstance(value, bool):
bool_str = dialect.true_value if value else dialect.false_value
if dialect.bool_type is None:
return f"{typeof} = '{bool_str}'"
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
if isinstance(value, int):
bp = _bind(compiler, value, BigInteger(), **kw)
if dialect.int_guard:
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
if isinstance(value, float):
bp = _bind(compiler, value, Float(), **kw)
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
bp = _bind(compiler, str(value), String(), **kw)
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"json_type({col}, '{path}')"
extract = f"json_extract({col}, '{path}')"
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
typeof = f"json_typeof({col} -> '{element.key}')"
extract = f"({col} ->> '{element.key}')"
return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)
@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)
@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
return JsonMatch(column, key, value)
@@ -1,3 +0,0 @@
from .close import close_in_order
__all__ = ["close_in_order"]
@@ -1,28 +0,0 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
AsyncCloser = Callable[[], Awaitable[None]]
async def close_in_order(*closers: AsyncCloser) -> None:
"""
Run async closers in order and raise the first error, if any.
Notes
-----
- Used to keep driver-specific close logic readable.
- We intentionally do not stop at first failure, so later resources
still get a chance to close.
"""
first_error: Exception | None = None
for closer in closers:
try:
await closer()
except Exception as exc:
if first_error is None:
first_error = exc
if first_error is not None:
raise first_error
@@ -1,23 +0,0 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
AsyncSetup = Callable[[], Awaitable[None]]
AsyncClose = Callable[[], Awaitable[None]]
@dataclass(slots=True)
class AppPersistence:
"""
Unified runtime persistence bundle.
"""
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: AsyncSetup
aclose: AsyncClose
@@ -1,53 +0,0 @@
from store.repositories.contracts import (
Feedback,
FeedbackAggregate,
FeedbackCreate,
FeedbackRepositoryProtocol,
InvalidMetadataFilterError,
Run,
RunCreate,
RunEvent,
RunEventCreate,
RunEventRepositoryProtocol,
RunRepositoryProtocol,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
User,
UserCreate,
UserNotFoundError,
UserRepositoryProtocol,
)
from store.repositories.factory import (
build_feedback_repository,
build_run_event_repository,
build_run_repository,
build_thread_meta_repository,
build_user_repository,
)
__all__ = [
"Feedback",
"FeedbackAggregate",
"FeedbackCreate",
"FeedbackRepositoryProtocol",
"InvalidMetadataFilterError",
"Run",
"RunCreate",
"RunEvent",
"RunEventCreate",
"RunEventRepositoryProtocol",
"RunRepositoryProtocol",
"ThreadMeta",
"ThreadMetaCreate",
"ThreadMetaRepositoryProtocol",
"User",
"UserCreate",
"UserNotFoundError",
"UserRepositoryProtocol",
"build_run_repository",
"build_run_event_repository",
"build_thread_meta_repository",
"build_feedback_repository",
"build_user_repository",
]
@@ -1,49 +0,0 @@
from store.repositories.contracts.feedback import (
Feedback,
FeedbackAggregate,
FeedbackCreate,
FeedbackRepositoryProtocol,
)
from store.repositories.contracts.run import (
Run,
RunCreate,
RunRepositoryProtocol,
)
from store.repositories.contracts.run_event import (
RunEvent,
RunEventCreate,
RunEventRepositoryProtocol,
)
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.contracts.user import (
User,
UserCreate,
UserNotFoundError,
UserRepositoryProtocol,
)
__all__ = [
"Feedback",
"FeedbackAggregate",
"FeedbackCreate",
"FeedbackRepositoryProtocol",
"Run",
"RunCreate",
"RunEvent",
"RunEventCreate",
"RunEventRepositoryProtocol",
"RunRepositoryProtocol",
"InvalidMetadataFilterError",
"ThreadMeta",
"ThreadMetaCreate",
"ThreadMetaRepositoryProtocol",
"User",
"UserCreate",
"UserNotFoundError",
"UserRepositoryProtocol",
]
@@ -1,77 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Protocol, TypedDict
from pydantic import BaseModel, ConfigDict
class FeedbackCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
feedback_id: str
run_id: str
thread_id: str
rating: int
user_id: str | None = None
message_id: str | None = None
comment: str | None = None
class Feedback(BaseModel):
model_config = ConfigDict(frozen=True)
feedback_id: str
run_id: str
thread_id: str
rating: int
user_id: str | None
message_id: str | None
comment: str | None
created_time: datetime
class FeedbackAggregate(TypedDict):
run_id: str
total: int
positive: int
negative: int
class FeedbackRepositoryProtocol(Protocol):
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
pass
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
pass
async def get_feedback(self, feedback_id: str) -> Feedback | None:
pass
async def list_feedback_by_run(
self,
run_id: str,
*,
thread_id: str | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
pass
async def list_feedback_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
pass
async def delete_feedback(self, feedback_id: str) -> bool:
pass
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
pass
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
pass
@@ -1,100 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class RunCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
run_id: str
thread_id: str
assistant_id: str | None = None
user_id: str | None = None
status: str = "pending"
model_name: str | None = None
multitask_strategy: str = "reject"
error: str | None = None
follow_up_to_run_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
created_time: datetime | None = None
class Run(BaseModel):
model_config = ConfigDict(frozen=True)
run_id: str
thread_id: str
assistant_id: str | None
user_id: str | None
status: str
model_name: str | None
multitask_strategy: str
error: str | None
follow_up_to_run_id: str | None
metadata: dict[str, Any]
kwargs: dict[str, Any]
total_input_tokens: int
total_output_tokens: int
total_tokens: int
llm_call_count: int
lead_agent_tokens: int
subagent_tokens: int
middleware_tokens: int
message_count: int
first_human_message: str | None
last_ai_message: str | None
created_time: datetime
updated_time: datetime | None
class RunRepositoryProtocol(Protocol):
async def create_run(self, data: RunCreate) -> Run:
pass
async def get_run(self, run_id: str) -> Run | None:
pass
async def list_runs_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[Run]:
pass
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
pass
async def delete_run(self, run_id: str) -> None:
pass
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
pass
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
first_human_message: str | None = None,
last_ai_message: str | None = None,
error: str | None = None,
) -> None:
pass
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
pass
@@ -1,83 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class RunEventCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str
run_id: str
user_id: str | None = None
event_type: str
category: str
content: Any = ""
metadata: dict[str, Any] = Field(default_factory=dict)
created_at: datetime | None = None
class RunEvent(BaseModel):
model_config = ConfigDict(frozen=True)
thread_id: str
run_id: str
user_id: str | None
event_type: str
category: str
content: Any
metadata: dict[str, Any]
seq: int
created_at: datetime
class RunEventRepositoryProtocol(Protocol):
# Sequence values are time-ordered integer cursors. The application layer
# owns the single-writer invariant for a thread while a run is active.
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
pass
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
pass
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
pass
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
pass
@@ -1,67 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class InvalidMetadataFilterError(ValueError):
"""Raised when all client-supplied metadata filters are rejected."""
class ThreadMetaCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str
assistant_id: str | None = None
user_id: str | None = None
display_name: str | None = None
status: str = "idle"
metadata: dict[str, Any] = Field(default_factory=dict)
class ThreadMeta(BaseModel):
model_config = ConfigDict(frozen=True)
thread_id: str
assistant_id: str | None
user_id: str | None
display_name: str | None
status: str
metadata: dict[str, Any]
created_time: datetime
updated_time: datetime | None
class ThreadMetaRepositoryProtocol(Protocol):
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
pass
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
pass
async def update_thread_meta(
self,
thread_id: str,
*,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
pass
async def delete_thread(self, thread_id: str) -> None:
pass
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
pass
@@ -1,64 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal, Protocol
from pydantic import BaseModel, ConfigDict
class UserNotFoundError(LookupError):
"""Raised when an update targets a user row that no longer exists."""
class UserCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str
email: str
password_hash: str | None = None
system_role: Literal["admin", "user"] = "user"
created_at: datetime | None = None
oauth_provider: str | None = None
oauth_id: str | None = None
needs_setup: bool = False
token_version: int = 0
class User(BaseModel):
model_config = ConfigDict(frozen=True)
id: str
email: str
password_hash: str | None
system_role: Literal["admin", "user"]
created_at: datetime
oauth_provider: str | None
oauth_id: str | None
needs_setup: bool
token_version: int
class UserRepositoryProtocol(Protocol):
async def create_user(self, data: UserCreate) -> User:
pass
async def get_user_by_id(self, user_id: str) -> User | None:
pass
async def get_user_by_email(self, email: str) -> User | None:
pass
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
pass
async def get_first_admin(self) -> User | None:
pass
async def update_user(self, data: User) -> User:
pass
async def count_users(self) -> int:
pass
async def count_admin_users(self) -> int:
pass
@@ -1,13 +0,0 @@
from store.repositories.db.feedback import DbFeedbackRepository
from store.repositories.db.run import DbRunRepository
from store.repositories.db.run_event import DbRunEventRepository
from store.repositories.db.thread_meta import DbThreadMetaRepository
from store.repositories.db.user import DbUserRepository
__all__ = [
"DbFeedbackRepository",
"DbRunRepository",
"DbRunEventRepository",
"DbThreadMetaRepository",
"DbUserRepository",
]
@@ -1,142 +0,0 @@
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import case, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
from store.repositories.models.feedback import Feedback as FeedbackModel
def _to_feedback(m: FeedbackModel) -> Feedback:
return Feedback(
feedback_id=m.feedback_id,
run_id=m.run_id,
thread_id=m.thread_id,
rating=m.rating,
user_id=m.user_id,
message_id=m.message_id,
comment=m.comment,
created_time=m.created_time,
)
class DbFeedbackRepository(FeedbackRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
if data.rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
model = FeedbackModel(
feedback_id=data.feedback_id,
run_id=data.run_id,
thread_id=data.thread_id,
rating=data.rating,
user_id=data.user_id,
message_id=data.message_id,
comment=data.comment,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_feedback(model)
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
if data.rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
result = await self._session.execute(
select(FeedbackModel).where(
FeedbackModel.thread_id == data.thread_id,
FeedbackModel.run_id == data.run_id,
FeedbackModel.user_id == data.user_id,
)
)
model = result.scalar_one_or_none()
if model is None:
return await self.create_feedback(data)
model.rating = data.rating
model.message_id = data.message_id
model.comment = data.comment
model.created_time = datetime.now(UTC)
await self._session.flush()
await self._session.refresh(model)
return _to_feedback(model)
async def get_feedback(self, feedback_id: str) -> Feedback | None:
result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
model = result.scalar_one_or_none()
return _to_feedback(model) if model else None
async def list_feedback_by_run(
self,
run_id: str,
*,
thread_id: str | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
if thread_id is not None:
stmt = stmt.where(FeedbackModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
stmt = stmt.order_by(FeedbackModel.created_time.desc())
if limit is not None:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
return [_to_feedback(m) for m in result.scalars().all()]
async def list_feedback_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
stmt = stmt.order_by(FeedbackModel.created_time.desc())
if limit is not None:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
return [_to_feedback(m) for m in result.scalars().all()]
async def delete_feedback(self, feedback_id: str) -> bool:
existing = await self.get_feedback(feedback_id)
if existing is None:
return False
await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
return True
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
stmt = select(FeedbackModel).where(
FeedbackModel.thread_id == thread_id,
FeedbackModel.run_id == run_id,
)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
result = await self._session.execute(stmt)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
return True
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
stmt = select(
func.count().label("total"),
func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
row = (await self._session.execute(stmt)).one()
return {
"run_id": run_id,
"total": int(row.total),
"positive": int(row.positive),
"negative": int(row.negative),
}
@@ -1,185 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
from store.repositories.models.run import Run as RunModel
def _to_run(m: RunModel) -> Run:
return Run(
run_id=m.run_id,
thread_id=m.thread_id,
assistant_id=m.assistant_id,
user_id=m.user_id,
status=m.status,
model_name=m.model_name,
multitask_strategy=m.multitask_strategy,
error=m.error,
follow_up_to_run_id=m.follow_up_to_run_id,
metadata=dict(m.meta or {}),
kwargs=dict(m.kwargs or {}),
total_input_tokens=m.total_input_tokens,
total_output_tokens=m.total_output_tokens,
total_tokens=m.total_tokens,
llm_call_count=m.llm_call_count,
lead_agent_tokens=m.lead_agent_tokens,
subagent_tokens=m.subagent_tokens,
middleware_tokens=m.middleware_tokens,
message_count=m.message_count,
first_human_message=m.first_human_message,
last_ai_message=m.last_ai_message,
created_time=m.created_time,
updated_time=m.updated_time,
)
class DbRunRepository(RunRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_run(self, data: RunCreate) -> Run:
model = RunModel(
run_id=data.run_id,
thread_id=data.thread_id,
assistant_id=data.assistant_id,
user_id=data.user_id,
status=data.status,
model_name=data.model_name,
multitask_strategy=data.multitask_strategy,
error=data.error,
follow_up_to_run_id=data.follow_up_to_run_id,
meta=dict(data.metadata),
kwargs=dict(data.kwargs),
)
if data.created_time is not None:
model.created_time = data.created_time
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_run(model)
async def get_run(self, run_id: str) -> Run | None:
result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
model = result.scalar_one_or_none()
return _to_run(model) if model else None
async def list_runs_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[Run]:
stmt = select(RunModel).where(RunModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(RunModel.user_id == user_id)
stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
result = await self._session.execute(stmt)
return [_to_run(m) for m in result.scalars().all()]
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
values: dict = {"status": status}
if error is not None:
values["error"] = error
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
async def delete_run(self, run_id: str) -> None:
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
if before is None:
before_dt = datetime.now().astimezone()
elif isinstance(before, datetime):
before_dt = before
else:
before_dt = datetime.fromisoformat(before)
result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
return [_to_run(m) for m in result.scalars().all()]
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
first_human_message: str | None = None,
last_ai_message: str | None = None,
error: str | None = None,
) -> None:
values = {
"status": status,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if error is not None:
values["error"] = error
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = RunModel.status.in_(("success", "error"))
model_expr = func.coalesce(RunModel.model_name, "unknown")
stmt = (
select(
model_expr.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
)
.where(RunModel.thread_id == thread_id, completed)
.group_by(model_expr)
)
rows = (await self._session.execute(stmt)).all()
total_tokens = total_input = total_output = total_runs = 0
lead_agent = subagent = middleware = 0
by_model: dict[str, dict] = {}
for row in rows:
by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
total_tokens += row.total_tokens
total_input += row.total_input_tokens
total_output += row.total_output_tokens
total_runs += row.runs
lead_agent += row.lead_agent
subagent += row.subagent
middleware += row.middleware
return {
"total_tokens": total_tokens,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_runs": total_runs,
"by_model": by_model,
"by_caller": {
"lead_agent": lead_agent,
"subagent": subagent,
"middleware": middleware,
},
}
@@ -1,207 +0,0 @@
from __future__ import annotations
import json
import secrets
import threading
import time
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel
_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
class _SequenceAllocator:
def __init__(self) -> None:
self._last_millis = 0
self._lock = threading.Lock()
def allocate_base(self, batch_size: int) -> int:
if batch_size >= _SEQ_COUNTER_LIMIT:
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
now_ms = time.time_ns() // 1_000_000
with self._lock:
seq_ms = max(now_ms, self._last_millis + 1)
self._last_millis = seq_ms
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
_sequence_allocator = _SequenceAllocator()
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str):
next_metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
next_metadata["content_is_dict"] = True
return json.dumps(content, default=str, ensure_ascii=False), next_metadata
return content, metadata
def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
return content
try:
return json.loads(content)
except json.JSONDecodeError:
return content
def _to_run_event(model: RunEventModel) -> RunEvent:
raw_metadata = dict(model.meta or {})
metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
return RunEvent(
thread_id=model.thread_id,
run_id=model.run_id,
user_id=model.user_id,
event_type=model.event_type,
category=model.category,
content=_deserialize_content(model.content, raw_metadata),
metadata=metadata,
seq=model.seq,
created_at=model.created_at,
)
class DbRunEventRepository(RunEventRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
if not events:
return []
seq_base = _sequence_allocator.allocate_base(len(events))
rows: list[RunEventModel] = []
for index, event in enumerate(events, start=1):
content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel(
thread_id=event.thread_id,
run_id=event.run_id,
user_id=event.user_id,
seq=seq_base + index,
event_type=event.event_type,
category=event.category,
content=content,
meta=metadata,
)
if event.created_at is not None:
row.created_at = event.created_at
self._session.add(row)
rows.append(row)
await self._session.flush()
return [_to_run_event(row) for row in rows]
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.category == "message",
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if before_seq is not None:
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
if after_seq is not None:
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.run_id == run_id,
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if event_types is not None:
stmt = stmt.where(RunEventModel.event_type.in_(event_types))
stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.run_id == run_id,
RunEventModel.category == "message",
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if before_seq is not None:
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
if after_seq is not None:
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
count = await self._session.scalar(stmt)
return int(count or 0)
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
conditions = [RunEventModel.thread_id == thread_id]
if user_id is not None:
conditions.append(RunEventModel.user_id == user_id)
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
await self._session.execute(delete(RunEventModel).where(*conditions))
return int(count or 0)
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id]
if user_id is not None:
conditions.append(RunEventModel.user_id == user_id)
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
await self._session.execute(delete(RunEventModel).where(*conditions))
return int(count or 0)
@@ -1,113 +0,0 @@
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.persistence.json_compat import json_match
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
logger = logging.getLogger(__name__)
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
return ThreadMeta(
thread_id=m.thread_id,
assistant_id=m.assistant_id,
user_id=m.user_id,
display_name=m.display_name,
status=m.status,
metadata=dict(m.meta or {}),
created_time=m.created_time,
updated_time=m.updated_time,
)
class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
model = ThreadMetaModel(
thread_id=data.thread_id,
assistant_id=data.assistant_id,
user_id=data.user_id,
display_name=data.display_name,
status=data.status,
meta=dict(data.metadata),
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_thread_meta(model)
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
model = result.scalar_one_or_none()
return _to_thread_meta(model) if model else None
async def update_thread_meta(
self,
thread_id: str,
*,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
values: dict = {}
if display_name is not None:
values["display_name"] = display_name
if status is not None:
values["status"] = status
if metadata is not None:
values["meta"] = dict(metadata)
if not values:
return
await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
async def delete_thread(self, thread_id: str) -> None:
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
stmt = select(ThreadMetaModel)
if status is not None:
stmt = stmt.where(ThreadMetaModel.status == status)
if user_id is not None:
stmt = stmt.where(ThreadMetaModel.user_id == user_id)
if assistant_id is not None:
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
if metadata:
applied = 0
for key, value in metadata.items():
try:
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
applied += 1
except (ValueError, TypeError) as exc:
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
if applied == 0:
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
stmt = stmt.limit(limit).offset(offset)
result = await self._session.execute(stmt)
return [_to_thread_meta(m) for m in result.scalars().all()]
@@ -1,98 +0,0 @@
from __future__ import annotations
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
from store.repositories.models.user import User as UserModel
def _to_user(model: UserModel) -> User:
return User(
id=model.id,
email=model.email,
password_hash=model.password_hash,
system_role=model.system_role, # type: ignore[arg-type]
created_at=model.created_at,
oauth_provider=model.oauth_provider,
oauth_id=model.oauth_id,
needs_setup=model.needs_setup,
token_version=model.token_version,
)
class DbUserRepository(UserRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_user(self, data: UserCreate) -> User:
model = UserModel(
id=data.id,
email=data.email,
system_role=data.system_role,
password_hash=data.password_hash,
oauth_provider=data.oauth_provider,
oauth_id=data.oauth_id,
needs_setup=data.needs_setup,
token_version=data.token_version,
)
if data.created_at is not None:
model.created_at = data.created_at
self._session.add(model)
try:
await self._session.flush()
except IntegrityError as exc:
await self._session.rollback()
raise ValueError(f"Email already registered: {data.email}") from exc
await self._session.refresh(model)
return _to_user(model)
async def get_user_by_id(self, user_id: str) -> User | None:
model = await self._session.get(UserModel, user_id)
return _to_user(model) if model is not None else None
async def get_user_by_email(self, email: str) -> User | None:
result = await self._session.execute(select(UserModel).where(UserModel.email == email))
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
result = await self._session.execute(
select(UserModel).where(
UserModel.oauth_provider == provider,
UserModel.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def get_first_admin(self) -> User | None:
result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1))
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def update_user(self, data: User) -> User:
model = await self._session.get(UserModel, data.id)
if model is None:
raise UserNotFoundError(f"User {data.id} no longer exists")
model.email = data.email
model.password_hash = data.password_hash
model.system_role = data.system_role
model.oauth_provider = data.oauth_provider
model.oauth_id = data.oauth_id
model.needs_setup = data.needs_setup
model.token_version = data.token_version
await self._session.flush()
await self._session.refresh(model)
return _to_user(model)
async def count_users(self) -> int:
count = await self._session.scalar(select(func.count()).select_from(UserModel))
return int(count or 0)
async def count_admin_users(self) -> int:
count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin"))
return int(count or 0)
@@ -1,36 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories import (
FeedbackRepositoryProtocol,
RunEventRepositoryProtocol,
RunRepositoryProtocol,
ThreadMetaRepositoryProtocol,
UserRepositoryProtocol,
)
from store.repositories.db import (
DbFeedbackRepository,
DbRunEventRepository,
DbRunRepository,
DbThreadMetaRepository,
DbUserRepository,
)
def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
return DbThreadMetaRepository(session)
def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
return DbRunRepository(session)
def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
return DbFeedbackRepository(session)
def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
return DbRunEventRepository(session)
def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
return DbUserRepository(session)
@@ -1,7 +0,0 @@
from store.repositories.models.feedback import Feedback
from store.repositories.models.run import Run
from store.repositories.models.run_event import RunEvent
from store.repositories.models.thread_meta import ThreadMeta
from store.repositories.models.user import User
__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
@@ -1,36 +0,0 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Feedback(DataClassBase):
"""Feedback table (create-only, no updated_time)."""
__tablename__ = "feedback"
__table_args__ = (
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
{"comment": "Feedback table."},
)
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), index=True)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
rating: Mapped[int] = mapped_column(Integer)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
message_id: Mapped[str | None] = mapped_column(String(64), default=None)
comment: Mapped[str | None] = mapped_column(UniversalText, default=None)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -1,63 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Run(DataClassBase):
"""Run metadata table."""
__tablename__ = "runs"
__table_args__ = (
Index("ix_runs_thread_status", "thread_id", "status"),
{"comment": "Run metadata table."},
)
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
model_name: Mapped[str | None] = mapped_column(String(128), default=None)
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
error: Mapped[str | None] = mapped_column(UniversalText, default=None)
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)
total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)
message_count: Mapped[int] = mapped_column(Integer, default=0)
first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
"updated_at",
TimeZone,
init=False,
default=None,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -1,46 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import (
DataClassBase,
TimeZone,
UniversalText,
current_time,
id_key,
)
class RunEvent(DataClassBase):
"""Run event table."""
__tablename__ = "run_events"
__table_args__ = (
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
{"comment": "Run event table."},
)
id: Mapped[id_key] = mapped_column(init=False)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
run_id: Mapped[str] = mapped_column(String(64), index=True)
event_type: Mapped[str] = mapped_column(String(32), index=True)
category: Mapped[str] = mapped_column(String(16), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
content: Mapped[str] = mapped_column(UniversalText, default="")
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Event timestamp",
)
@@ -1,43 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class ThreadMeta(DataClassBase):
"""Thread metadata table."""
__tablename__ = "threads_meta"
__table_args__ = {"comment": "Thread metadata table."}
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
display_name: Mapped[str | None] = mapped_column(String(256), default=None)
status: Mapped[str] = mapped_column(String(20), default="idle", index=True)
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
"updated_at",
TimeZone,
init=False,
default=None,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -1,42 +0,0 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class User(DataClassBase):
"""User account table."""
__tablename__ = "users"
__table_args__ = (
Index(
"idx_users_oauth_identity",
"oauth_provider",
"oauth_id",
unique=True,
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
),
{"comment": "User account table."},
)
id: Mapped[str] = mapped_column(String(36), primary_key=True)
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
system_role: Mapped[str] = mapped_column(String(16), default="user")
password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
token_version: Mapped[int] = mapped_column(default=0)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -1,3 +0,0 @@
from .timezone import get_timezone
__all__ = ["get_timezone"]
@@ -1,51 +0,0 @@
import zoneinfo
from datetime import UTC, datetime
from store.config.app_config import get_app_config
# IANA identifiers that map to UTC — see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})
class TimeZone:
def __init__(self) -> None:
app_config = get_app_config()
if app_config.timezone in _UTC_IDENTIFIERS:
self.tz_info = UTC
else:
self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)
def now(self) -> datetime:
"""Return the current time in the configured timezone."""
return datetime.now(self.tz_info)
def from_datetime(self, t: datetime) -> datetime:
"""Convert a datetime to the configured timezone."""
return t.astimezone(self.tz_info)
def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
"""Parse a time string and attach the configured timezone."""
return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)
@staticmethod
def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""Format a datetime to string."""
return t.strftime(format_str)
@staticmethod
def to_utc(t: datetime | int) -> datetime:
"""Convert a datetime or Unix timestamp to UTC."""
if isinstance(t, datetime):
return t.astimezone(UTC)
return datetime.fromtimestamp(t, tz=UTC)
_timezone = None
def get_timezone() -> TimeZone:
"""Return the global TimeZone singleton (lazy-initialized)."""
global _timezone
if _timezone is None:
_timezone = TimeZone()
return _timezone
+2 -5
View File
@@ -6,7 +6,6 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"deerflow-harness",
"deerflow-storage",
"fastapi>=0.115.0",
"httpx>=0.28.0",
"python-multipart>=0.0.27",
@@ -25,8 +24,7 @@ dependencies = [
]
[project.optional-dependencies]
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
mysql = ["deerflow-storage[mysql]"]
postgres = ["deerflow-harness[postgres]"]
[dependency-groups]
dev = [
@@ -45,8 +43,7 @@ markers = [
index-url = "https://pypi.org/simple"
[tool.uv.workspace]
members = ["packages/harness", "packages/storage"]
members = ["packages/harness"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
deerflow-storage = { workspace = true }
-68
View File
@@ -1,68 +0,0 @@
"""Shared helpers for user-isolation e2e tests on the custom-agent tooling.
Centralises the small fake-LLM shim and a few test-data builders that the
three e2e files in this PR (``test_setup_agent_e2e_user_isolation``,
``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``)
all need. The shim is what lets a real ``langchain.agents.create_agent``
graph run without an API key every other layer in those tests is real
production code, which is the entire point of the test design.
"""
from __future__ import annotations
from typing import Any
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import Runnable
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent.
``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to
expose the tool schemas to the model; the upstream fake raises
``NotImplementedError`` there. We just return ``self`` because we
drive deterministic tool_call output via ``responses=...``, no schema
handling needed.
"""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
def build_single_tool_call_model(
*,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str = "call_e2e_1",
final_text: str = "done",
) -> FakeToolCallingModel:
"""Build a fake model that emits exactly one tool_call then finishes.
Two-turn behaviour, identical across our e2e tests:
turn 1 AIMessage with a single tool_call for *tool_name*
turn 2 AIMessage with *final_text* (terminates the agent loop)
"""
return FakeToolCallingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": tool_name,
"args": tool_args,
"id": tool_call_id,
"type": "tool_call",
}
],
),
AIMessage(content=final_text),
]
)

Some files were not shown because too many files have changed in this diff Show More