Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 27b66d6753 | |||
| 636053fb6d | |||
| f56d0b4869 | |||
| a2cb38f62b | |||
| 3aab2445a6 | |||
| f8fb8d6fb1 | |||
| 2d1f90d5dc | |||
| 3a672b39c7 | |||
| df5339b5d0 | |||
| 0eb6550cf4 | |||
| 0a379602b8 | |||
| 1fb5acee39 | |||
| 82c3dbbc6b | |||
| e97c8c9943 | |||
| 68d44f6755 | |||
| c2ff59a5b1 | |||
| 2f3744f807 | |||
| 52c8c06cf2 | |||
| 0cdecf7b30 |
@@ -6,6 +6,11 @@ JINA_API_KEY=your-jina-api-key
|
||||
|
||||
# InfoQuest API Key
|
||||
INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
# Authentication — JWT secret for session signing
|
||||
# If not set, an ephemeral secret is auto-generated (sessions lost on restart)
|
||||
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
# AUTH_JWT_SECRET=your-secure-jwt-secret-here
|
||||
|
||||
# CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001
|
||||
# CORS_ORIGINS=http://localhost:3000
|
||||
|
||||
@@ -32,3 +37,5 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
|
||||
# GitHub API Token
|
||||
# GITHUB_TOKEN=your-github-token
|
||||
# WECOM_BOT_ID=your-wecom-bot-id
|
||||
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
||||
|
||||
@@ -54,3 +54,4 @@ web/
|
||||
# Deployment artifacts
|
||||
backend/Dockerfile.langgraph
|
||||
config.yaml.bak
|
||||
.gstack/
|
||||
|
||||
+1
-1
@@ -310,7 +310,7 @@ Every pull request runs the backend regression workflow at [.github/workflows/ba
|
||||
|
||||
- [Configuration Guide](backend/docs/CONFIGURATION.md) - Setup and configuration
|
||||
- [Architecture Overview](backend/CLAUDE.md) - Technical architecture
|
||||
- [MCP Setup Guide](MCP_SETUP.md) - Model Context Protocol configuration
|
||||
- [MCP Setup Guide](backend/docs/MCP_SERVER.md) - Model Context Protocol configuration
|
||||
|
||||
## Need Help?
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# DeerFlow - Unified Development Environment
|
||||
|
||||
.PHONY: help config config-upgrade check install dev dev-daemon start stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
||||
.PHONY: help config config-upgrade check install dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
||||
|
||||
PYTHON ?= python
|
||||
BASH ?= bash
|
||||
|
||||
# Detect OS for Windows compatibility
|
||||
ifeq ($(OS),Windows_NT)
|
||||
SHELL := cmd.exe
|
||||
PYTHON ?= python
|
||||
else
|
||||
PYTHON ?= python3
|
||||
endif
|
||||
|
||||
help:
|
||||
@@ -18,18 +20,25 @@ help:
|
||||
@echo " make install - Install all dependencies (frontend + backend)"
|
||||
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
|
||||
@echo " make dev - Start all services in development mode (with hot-reloading)"
|
||||
@echo " make dev-daemon - Start all services in background (daemon mode)"
|
||||
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
|
||||
@echo " make dev-daemon - Start dev services in background (daemon mode)"
|
||||
@echo " make dev-daemon-pro - Start dev daemon + Gateway mode (experimental)"
|
||||
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
|
||||
@echo " make start-pro - Start in prod + Gateway mode (experimental)"
|
||||
@echo " make start-daemon - Start prod services in background (daemon mode)"
|
||||
@echo " make start-daemon-pro - Start prod daemon + Gateway mode (experimental)"
|
||||
@echo " make stop - Stop all running services"
|
||||
@echo " make clean - Clean up processes and temporary files"
|
||||
@echo ""
|
||||
@echo "Docker Production Commands:"
|
||||
@echo " make up - Build and start production Docker services (localhost:2026)"
|
||||
@echo " make up-pro - Build and start production Docker in Gateway mode (experimental)"
|
||||
@echo " make down - Stop and remove production Docker containers"
|
||||
@echo ""
|
||||
@echo "Docker Development Commands:"
|
||||
@echo " make docker-init - Pull the sandbox image"
|
||||
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
|
||||
@echo " make docker-start-pro - Start Docker in Gateway mode (experimental, no LangGraph container)"
|
||||
@echo " make docker-stop - Stop Docker development services"
|
||||
@echo " make docker-logs - View Docker development logs"
|
||||
@echo " make docker-logs-frontend - View Docker frontend logs"
|
||||
@@ -96,39 +105,79 @@ setup-sandbox:
|
||||
|
||||
# Start all services in development mode (with hot-reloading)
|
||||
dev:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev
|
||||
else
|
||||
@./scripts/serve.sh --dev
|
||||
endif
|
||||
|
||||
# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
|
||||
dev-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway
|
||||
else
|
||||
@./scripts/serve.sh --dev --gateway
|
||||
endif
|
||||
|
||||
# Start all services in production mode (with optimizations)
|
||||
start:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod
|
||||
else
|
||||
@./scripts/serve.sh --prod
|
||||
endif
|
||||
|
||||
# Start all services in prod + Gateway mode (experimental)
|
||||
start-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway
|
||||
else
|
||||
@./scripts/serve.sh --prod --gateway
|
||||
endif
|
||||
|
||||
# Start all services in daemon mode (background)
|
||||
dev-daemon:
|
||||
@./scripts/start-daemon.sh
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --daemon
|
||||
else
|
||||
@./scripts/serve.sh --dev --daemon
|
||||
endif
|
||||
|
||||
# Start daemon + Gateway mode (experimental)
|
||||
dev-daemon-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway --daemon
|
||||
else
|
||||
@./scripts/serve.sh --dev --gateway --daemon
|
||||
endif
|
||||
|
||||
# Start prod services in daemon mode (background)
|
||||
start-daemon:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --daemon
|
||||
else
|
||||
@./scripts/serve.sh --prod --daemon
|
||||
endif
|
||||
|
||||
# Start prod daemon + Gateway mode (experimental)
|
||||
start-daemon-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway --daemon
|
||||
else
|
||||
@./scripts/serve.sh --prod --gateway --daemon
|
||||
endif
|
||||
|
||||
# Stop all services
|
||||
stop:
|
||||
@echo "Stopping all services..."
|
||||
@-pkill -f "langgraph dev" 2>/dev/null || true
|
||||
@-pkill -f "uvicorn app.gateway.app:app" 2>/dev/null || true
|
||||
@-pkill -f "next dev" 2>/dev/null || true
|
||||
@-pkill -f "next start" 2>/dev/null || true
|
||||
@-pkill -f "next-server" 2>/dev/null || true
|
||||
@-pkill -f "next-server" 2>/dev/null || true
|
||||
@-nginx -c $(PWD)/docker/nginx/nginx.local.conf -p $(PWD) -s quit 2>/dev/null || true
|
||||
@sleep 1
|
||||
@-pkill -9 nginx 2>/dev/null || true
|
||||
@echo "Cleaning up sandbox containers..."
|
||||
@-./scripts/cleanup-containers.sh deer-flow-sandbox 2>/dev/null || true
|
||||
@echo "✓ All services stopped"
|
||||
@./scripts/serve.sh --stop
|
||||
|
||||
# Clean up
|
||||
clean: stop
|
||||
@@ -150,6 +199,10 @@ docker-init:
|
||||
docker-start:
|
||||
@./scripts/docker.sh start
|
||||
|
||||
# Start Docker in Gateway mode (experimental)
|
||||
docker-start-pro:
|
||||
@./scripts/docker.sh start --gateway
|
||||
|
||||
# Stop Docker development environment
|
||||
docker-stop:
|
||||
@./scripts/docker.sh stop
|
||||
@@ -172,6 +225,10 @@ docker-logs-gateway:
|
||||
up:
|
||||
@./scripts/deploy.sh
|
||||
|
||||
# Build and start production services in Gateway mode
|
||||
up-pro:
|
||||
@./scripts/deploy.sh --gateway
|
||||
|
||||
# Stop and remove production containers
|
||||
down:
|
||||
@./scripts/deploy.sh down
|
||||
|
||||
@@ -46,6 +46,7 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe
|
||||
|
||||
- [🦌 DeerFlow - 2.0](#-deerflow---20)
|
||||
- [Official Website](#official-website)
|
||||
- [Coding Plan from ByteDance Volcengine](#coding-plan-from-bytedance-volcengine)
|
||||
- [InfoQuest](#infoquest)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [One-Line Agent Setup](#one-line-agent-setup)
|
||||
@@ -59,6 +60,8 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe
|
||||
- [MCP Server](#mcp-server)
|
||||
- [IM Channels](#im-channels)
|
||||
- [LangSmith Tracing](#langsmith-tracing)
|
||||
- [Langfuse Tracing](#langfuse-tracing)
|
||||
- [Using Both Providers](#using-both-providers)
|
||||
- [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness)
|
||||
- [Core Features](#core-features)
|
||||
- [Skills \& Tools](#skills--tools)
|
||||
@@ -71,6 +74,8 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe
|
||||
- [Embedded Python Client](#embedded-python-client)
|
||||
- [Documentation](#documentation)
|
||||
- [⚠️ Security Notice](#️-security-notice)
|
||||
- [Improper Deployment May Introduce Security Risks](#improper-deployment-may-introduce-security-risks)
|
||||
- [Security Recommendations](#security-recommendations)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
- [Acknowledgments](#acknowledgments)
|
||||
@@ -243,6 +248,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||
If you prefer running services locally:
|
||||
|
||||
Prerequisite: complete the "Configuration" steps above first (`make config` and model API keys). `make dev` requires a valid configuration file (defaults to `config.yaml` in the project root; can be overridden via `DEER_FLOW_CONFIG_PATH`).
|
||||
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
|
||||
|
||||
1. **Check prerequisites**:
|
||||
```bash
|
||||
@@ -274,6 +280,60 @@ Prerequisite: complete the "Configuration" steps above first (`make config` and
|
||||
|
||||
6. **Access**: http://localhost:2026
|
||||
|
||||
#### Startup Modes
|
||||
|
||||
DeerFlow supports multiple startup modes across two dimensions:
|
||||
|
||||
- **Dev / Prod** — dev enables hot-reload; prod uses pre-built frontend
|
||||
- **Standard / Gateway** — standard uses a separate LangGraph server (4 processes); Gateway mode (experimental) embeds the agent runtime in the Gateway API (3 processes)
|
||||
|
||||
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
||||
|---|---|---|---|---|
|
||||
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
||||
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
|
||||
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
||||
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
|
||||
|
||||
| Action | Local | Docker Dev | Docker Prod |
|
||||
|---|---|---|---|
|
||||
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
||||
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
||||
|
||||
> **Gateway mode** eliminates the LangGraph server process — the Gateway API handles agent execution directly via async tasks, managing its own concurrency.
|
||||
|
||||
#### Why Gateway Mode?
|
||||
|
||||
In standard mode, DeerFlow runs a dedicated [LangGraph Platform](https://langchain-ai.github.io/langgraph/) server alongside the Gateway API. This architecture works well but has trade-offs:
|
||||
|
||||
| | Standard Mode | Gateway Mode |
|
||||
|---|---|---|
|
||||
| **Architecture** | Gateway (REST API) + LangGraph (agent runtime) | Gateway embeds agent runtime |
|
||||
| **Concurrency** | `--n-jobs-per-worker` per worker (requires license) | `--workers` × async tasks (no per-worker cap) |
|
||||
| **Containers / Processes** | 4 (frontend, gateway, langgraph, nginx) | 3 (frontend, gateway, nginx) |
|
||||
| **Resource usage** | Higher (two Python runtimes) | Lower (single Python runtime) |
|
||||
| **LangGraph Platform license** | Required for production images | Not required |
|
||||
| **Cold start** | Slower (two services to initialize) | Faster |
|
||||
|
||||
Both modes are functionally equivalent — the same agents, tools, and skills work in either mode.
|
||||
|
||||
#### Docker Production Deployment
|
||||
|
||||
`deploy.sh` supports building and starting separately. Images are mode-agnostic — runtime mode is selected at start time:
|
||||
|
||||
```bash
|
||||
# One-step (build + start)
|
||||
deploy.sh # standard mode (default)
|
||||
deploy.sh --gateway # gateway mode
|
||||
|
||||
# Two-step (build once, start with any mode)
|
||||
deploy.sh build # build all images
|
||||
deploy.sh start # start in standard mode
|
||||
deploy.sh start --gateway # start in gateway mode
|
||||
|
||||
# Stop
|
||||
deploy.sh down
|
||||
```
|
||||
|
||||
### Advanced
|
||||
#### Sandbox Mode
|
||||
|
||||
@@ -301,6 +361,7 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
|
||||
| Telegram | Bot API (long-polling) | Easy |
|
||||
| Slack | Socket Mode | Moderate |
|
||||
| Feishu / Lark | WebSocket | Moderate |
|
||||
| WeCom | WebSocket | Moderate |
|
||||
|
||||
**Configuration in `config.yaml`:**
|
||||
|
||||
@@ -328,6 +389,11 @@ channels:
|
||||
# domain: https://open.feishu.cn # China (default)
|
||||
# domain: https://open.larksuite.com # International
|
||||
|
||||
wecom:
|
||||
enabled: true
|
||||
bot_id: $WECOM_BOT_ID
|
||||
bot_secret: $WECOM_BOT_SECRET
|
||||
|
||||
slack:
|
||||
enabled: true
|
||||
bot_token: $SLACK_BOT_TOKEN # xoxb-...
|
||||
@@ -371,6 +437,10 @@ SLACK_APP_TOKEN=xapp-...
|
||||
# Feishu / Lark
|
||||
FEISHU_APP_ID=cli_xxxx
|
||||
FEISHU_APP_SECRET=your_app_secret
|
||||
|
||||
# WeCom
|
||||
WECOM_BOT_ID=your_bot_id
|
||||
WECOM_BOT_SECRET=your_bot_secret
|
||||
```
|
||||
|
||||
**Telegram Setup**
|
||||
@@ -393,6 +463,14 @@ FEISHU_APP_SECRET=your_app_secret
|
||||
3. Under **Events**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
|
||||
4. Copy the App ID and App Secret. Set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env` and enable the channel in `config.yaml`.
|
||||
|
||||
**WeCom Setup**
|
||||
|
||||
1. Create a bot on the WeCom AI Bot platform and obtain the `bot_id` and `bot_secret`.
|
||||
2. Enable `channels.wecom` in `config.yaml` and fill in `bot_id` / `bot_secret`.
|
||||
3. Set `WECOM_BOT_ID` and `WECOM_BOT_SECRET` in `.env`.
|
||||
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
|
||||
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
|
||||
|
||||
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://langgraph:2024` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||
|
||||
**Commands**
|
||||
@@ -422,6 +500,27 @@ LANGSMITH_API_KEY=lsv2_pt_xxxxxxxxxxxxxxxx
|
||||
LANGSMITH_PROJECT=xxx
|
||||
```
|
||||
|
||||
#### Langfuse Tracing
|
||||
|
||||
DeerFlow also supports [Langfuse](https://langfuse.com) observability for LangChain-compatible runs.
|
||||
|
||||
Add the following to your `.env` file:
|
||||
|
||||
```bash
|
||||
LANGFUSE_TRACING=true
|
||||
LANGFUSE_PUBLIC_KEY=pk-lf-xxxxxxxxxxxxxxxx
|
||||
LANGFUSE_SECRET_KEY=sk-lf-xxxxxxxxxxxxxxxx
|
||||
LANGFUSE_BASE_URL=https://cloud.langfuse.com
|
||||
```
|
||||
|
||||
If you are using a self-hosted Langfuse instance, set `LANGFUSE_BASE_URL` to your deployment URL.
|
||||
|
||||
#### Using Both Providers
|
||||
|
||||
If both LangSmith and Langfuse are enabled, DeerFlow attaches both tracing callbacks and reports the same model activity to both systems.
|
||||
|
||||
If a provider is explicitly enabled but missing required credentials, or if its callback fails to initialize, DeerFlow fails fast when tracing is initialized during model creation and the error message names the provider that caused the failure.
|
||||
|
||||
For Docker deployments, tracing is disabled by default. Set `LANGSMITH_TRACING=true` and `LANGSMITH_API_KEY` in your `.env` to enable it.
|
||||
|
||||
## From Deep Research to Super Agent Harness
|
||||
|
||||
@@ -180,6 +180,7 @@ make down # 停止并移除容器
|
||||
如果你更希望直接在本地启动各个服务:
|
||||
|
||||
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`,也可以通过 `DEER_FLOW_CONFIG_PATH` 覆盖。
|
||||
在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。
|
||||
|
||||
1. **检查依赖环境**:
|
||||
```bash
|
||||
@@ -231,6 +232,7 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应
|
||||
| Telegram | Bot API(long-polling) | 简单 |
|
||||
| Slack | Socket Mode | 中等 |
|
||||
| Feishu / Lark | WebSocket | 中等 |
|
||||
| 企业微信智能机器人 | WebSocket | 中等 |
|
||||
|
||||
**`config.yaml` 中的配置示例:**
|
||||
|
||||
@@ -258,6 +260,11 @@ channels:
|
||||
# domain: https://open.feishu.cn # 国内版(默认)
|
||||
# domain: https://open.larksuite.com # 国际版
|
||||
|
||||
wecom:
|
||||
enabled: true
|
||||
bot_id: $WECOM_BOT_ID
|
||||
bot_secret: $WECOM_BOT_SECRET
|
||||
|
||||
slack:
|
||||
enabled: true
|
||||
bot_token: $SLACK_BOT_TOKEN # xoxb-...
|
||||
@@ -301,6 +308,10 @@ SLACK_APP_TOKEN=xapp-...
|
||||
# Feishu / Lark
|
||||
FEISHU_APP_ID=cli_xxxx
|
||||
FEISHU_APP_SECRET=your_app_secret
|
||||
|
||||
# 企业微信智能机器人
|
||||
WECOM_BOT_ID=your_bot_id
|
||||
WECOM_BOT_SECRET=your_bot_secret
|
||||
```
|
||||
|
||||
**Telegram 配置**
|
||||
@@ -323,6 +334,14 @@ FEISHU_APP_SECRET=your_app_secret
|
||||
3. 在 **事件订阅** 中订阅 `im.message.receive_v1`,连接方式选择 **长连接**。
|
||||
4. 复制 App ID 和 App Secret,在 `.env` 中设置 `FEISHU_APP_ID` 和 `FEISHU_APP_SECRET`,并在 `config.yaml` 中启用该渠道。
|
||||
|
||||
**企业微信智能机器人配置**
|
||||
|
||||
1. 在企业微信智能机器人平台创建机器人,获取 `bot_id` 和 `bot_secret`。
|
||||
2. 在 `config.yaml` 中启用 `channels.wecom`,并填入 `bot_id` / `bot_secret`。
|
||||
3. 在 `.env` 中设置 `WECOM_BOT_ID` 和 `WECOM_BOT_SECRET`。
|
||||
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
|
||||
5. 当前支持文本、图片和文件入站消息;agent 生成的最终图片/文件也会回传到企业微信会话中。
|
||||
|
||||
**命令**
|
||||
|
||||
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:
|
||||
|
||||
+25
-2
@@ -13,6 +13,10 @@ DeerFlow is a LangGraph-based AI super agent system with a full-stack architectu
|
||||
- **Nginx** (port 2026): Unified reverse proxy entry point
|
||||
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
|
||||
|
||||
**Runtime Modes**:
|
||||
- **Standard mode** (`make dev`): LangGraph Server handles agent execution as a separate process. 4 processes total.
|
||||
- **Gateway mode** (`make dev-pro`, experimental): Agent runtime embedded in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Service manages its own concurrency via async tasks. 3 processes total, no LangGraph Server.
|
||||
|
||||
**Project Structure**:
|
||||
```
|
||||
deer-flow/
|
||||
@@ -80,6 +84,8 @@ When making code changes, you MUST update the relevant documentation:
|
||||
make check # Check system requirements
|
||||
make install # Install all dependencies (frontend + backend)
|
||||
make dev # Start all services (LangGraph + Gateway + Frontend + Nginx), with config.yaml preflight
|
||||
make dev-pro # Gateway mode (experimental): skip LangGraph, agent runtime embedded in Gateway
|
||||
make start-pro # Production + Gateway mode (experimental)
|
||||
make stop # Stop all services
|
||||
```
|
||||
|
||||
@@ -232,7 +238,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `ls` - Directory listing (tree format, max 2 levels)
|
||||
- `read_file` - Read file contents with optional line range
|
||||
- `write_file` - Write/append to files, creates directories
|
||||
- `str_replace` - Substring replacement (single or all occurrences)
|
||||
- `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process
|
||||
|
||||
### Subagent System (`packages/harness/deerflow/subagents/`)
|
||||
|
||||
@@ -436,8 +442,25 @@ make dev
|
||||
|
||||
This starts all services and makes the application available at `http://localhost:2026`.
|
||||
|
||||
**All startup modes:**
|
||||
|
||||
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
||||
|---|---|---|---|---|
|
||||
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
||||
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
|
||||
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
||||
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
|
||||
|
||||
| Action | Local | Docker Dev | Docker Prod |
|
||||
|---|---|---|---|
|
||||
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
||||
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
||||
|
||||
Gateway mode embeds the agent runtime in Gateway, no LangGraph server.
|
||||
|
||||
**Nginx routing**:
|
||||
- `/api/langgraph/*` → LangGraph Server (2024)
|
||||
- Standard mode: `/api/langgraph/*` → LangGraph Server (2024)
|
||||
- Gateway mode: `/api/langgraph/*` → Gateway embedded runtime (8001) (via envsubst)
|
||||
- `/api/*` (other) → Gateway API (8001)
|
||||
- `/` (non-API) → Frontend (3000)
|
||||
|
||||
|
||||
+63
-15
@@ -1,34 +1,86 @@
|
||||
# Backend Development Dockerfile
|
||||
# Backend Dockerfile — multi-stage build
|
||||
# Stage 1 (builder): compiles native Python extensions with build-essential
|
||||
# Stage 2 (dev): retains toolchain for dev containers (uv sync at startup)
|
||||
# Stage 3 (runtime): clean image without compiler toolchain for production
|
||||
|
||||
# UV source image (override for restricted networks that cannot reach ghcr.io)
|
||||
ARG UV_IMAGE=ghcr.io/astral-sh/uv:0.7.20
|
||||
FROM ${UV_IMAGE} AS uv-source
|
||||
|
||||
FROM python:3.12-slim-bookworm
|
||||
# ── Stage 1: Builder ──────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim-bookworm AS builder
|
||||
|
||||
ARG NODE_MAJOR=22
|
||||
ARG NODE_VERSION=22.16.0
|
||||
ARG APT_MIRROR
|
||||
ARG UV_INDEX_URL
|
||||
ARG NODE_DIST_URL
|
||||
|
||||
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
|
||||
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.byted.org)
|
||||
RUN if [ -n "${APT_MIRROR}" ]; then \
|
||||
sed -i "s|deb.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list.d/debian.sources 2>/dev/null || true; \
|
||||
sed -i "s|deb.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list 2>/dev/null || true; \
|
||||
fi
|
||||
|
||||
# Install system dependencies + Node.js (provides npx for MCP servers)
|
||||
# Install build tools + Node.js (build-essential needed for native Python extensions)
|
||||
# NODE_DIST_URL: base URL for Node.js binary tarballs in restricted networks.
|
||||
# npmmirror: https://registry.npmmirror.com/-/binary/node
|
||||
# official: https://nodejs.org/dist (default, via nodesource apt)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
build-essential \
|
||||
gnupg \
|
||||
ca-certificates \
|
||||
&& mkdir -p /etc/apt/keyrings \
|
||||
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
|
||||
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" > /etc/apt/sources.list.d/nodesource.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y nodejs \
|
||||
xz-utils \
|
||||
&& if [ -n "${NODE_DIST_URL}" ]; then \
|
||||
curl -fsSL "${NODE_DIST_URL}/v${NODE_VERSION}/node-v${NODE_VERSION}-linux-x64.tar.xz" \
|
||||
| tar -xJ --strip-components=1 -C /usr/local \
|
||||
&& ln -sf /usr/local/bin/node /usr/bin/node \
|
||||
&& ln -sf /usr/local/lib/node_modules /usr/lib/node_modules; \
|
||||
else \
|
||||
mkdir -p /etc/apt/keyrings \
|
||||
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
|
||||
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" > /etc/apt/sources.list.d/nodesource.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y nodejs; \
|
||||
fi \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv (source image overridable via UV_IMAGE build arg)
|
||||
COPY --from=uv-source /uv /uvx /usr/local/bin/
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy backend source code
|
||||
COPY backend ./backend
|
||||
|
||||
# Install dependencies with cache mount
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
|
||||
|
||||
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
||||
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
||||
# source distributions in development containers.
|
||||
FROM builder AS dev
|
||||
|
||||
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
|
||||
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
|
||||
|
||||
EXPOSE 8001 2024
|
||||
|
||||
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
||||
|
||||
# ── Stage 3: Runtime ──────────────────────────────────────────────────────────
|
||||
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
||||
FROM python:3.12-slim-bookworm
|
||||
|
||||
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
|
||||
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
|
||||
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
|
||||
|
||||
@@ -38,12 +90,8 @@ COPY --from=uv-source /uv /uvx /usr/local/bin/
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy frontend source code
|
||||
COPY backend ./backend
|
||||
|
||||
# Install dependencies with cache mount
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
|
||||
# Copy backend with pre-built virtualenv from builder
|
||||
COPY --from=builder /app/backend ./backend
|
||||
|
||||
# Expose ports (gateway: 8001, langgraph: 2024)
|
||||
EXPOSE 8001 2024
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@ install:
|
||||
uv sync
|
||||
|
||||
dev:
|
||||
uv run langgraph dev --no-browser --allow-blocking --no-reload
|
||||
uv run langgraph dev --no-browser --no-reload --n-jobs-per-worker 10
|
||||
|
||||
gateway:
|
||||
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||
|
||||
+23
-1
@@ -78,6 +78,7 @@ Per-thread isolated execution with virtual path translation:
|
||||
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||
- **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match
|
||||
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
|
||||
|
||||
### Subagent System
|
||||
@@ -330,7 +331,28 @@ LANGSMITH_PROJECT=xxx
|
||||
|
||||
**Legacy variables:** The `LANGCHAIN_TRACING_V2`, `LANGCHAIN_API_KEY`, `LANGCHAIN_PROJECT`, and `LANGCHAIN_ENDPOINT` variables are also supported for backward compatibility. `LANGSMITH_*` variables take precedence when both are set.
|
||||
|
||||
**Docker:** In `docker-compose.yaml`, tracing is disabled by default (`LANGSMITH_TRACING=false`). Set `LANGSMITH_TRACING=true` and provide `LANGSMITH_API_KEY` in your `.env` to enable it in containerized deployments.
|
||||
### Langfuse Tracing
|
||||
|
||||
DeerFlow also supports [Langfuse](https://langfuse.com) observability for LangChain-compatible runs.
|
||||
|
||||
Add the following to your `.env` file:
|
||||
|
||||
```bash
|
||||
LANGFUSE_TRACING=true
|
||||
LANGFUSE_PUBLIC_KEY=pk-lf-xxxxxxxxxxxxxxxx
|
||||
LANGFUSE_SECRET_KEY=sk-lf-xxxxxxxxxxxxxxxx
|
||||
LANGFUSE_BASE_URL=https://cloud.langfuse.com
|
||||
```
|
||||
|
||||
If you are using a self-hosted Langfuse deployment, set `LANGFUSE_BASE_URL` to your Langfuse host.
|
||||
|
||||
### Dual Provider Behavior
|
||||
|
||||
If both LangSmith and Langfuse are enabled, DeerFlow initializes and attaches both callbacks so the same run data is reported to both systems.
|
||||
|
||||
If a provider is explicitly enabled but required credentials are missing, or the provider callback cannot be initialized, DeerFlow raises an error when tracing is initialized during model creation instead of silently disabling tracing.
|
||||
|
||||
**Docker:** In `docker-compose.yaml`, tracing is disabled by default (`LANGSMITH_TRACING=false`). Set `LANGSMITH_TRACING=true` and/or `LANGFUSE_TRACING=true` in your `.env`, together with the required credentials, to enable tracing in containerized deployments.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Shared command definitions used by all channel implementations.
|
||||
|
||||
Keeping the authoritative command set in one place ensures that channel
|
||||
parsers (e.g. Feishu) and the ChannelManager dispatcher stay in sync
|
||||
automatically — adding or removing a command here is the single edit
|
||||
required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/bootstrap",
|
||||
"/new",
|
||||
"/status",
|
||||
"/models",
|
||||
"/memory",
|
||||
"/help",
|
||||
}
|
||||
)
|
||||
@@ -9,11 +9,18 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_feishu_command(text: str) -> bool:
|
||||
if not text.startswith("/"):
|
||||
return False
|
||||
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
||||
|
||||
|
||||
class FeishuChannel(Channel):
|
||||
"""Feishu/Lark IM channel using the ``lark-oapi`` WebSocket client.
|
||||
|
||||
@@ -199,7 +206,9 @@ class FeishuChannel(Channel):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
|
||||
raise last_exc # type: ignore[misc]
|
||||
if last_exc is None:
|
||||
raise RuntimeError("Feishu send failed without an exception from any attempt")
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._api_client:
|
||||
@@ -509,8 +518,9 @@ class FeishuChannel(Channel):
|
||||
logger.info("[Feishu] empty text, ignoring message")
|
||||
return
|
||||
|
||||
# Check if it's a command
|
||||
if text.startswith("/"):
|
||||
# Only treat known slash commands as commands; absolute paths and
|
||||
# other slash-prefixed text should be handled as normal chat.
|
||||
if _is_feishu_command(text):
|
||||
msg_type = InboundMessageType.COMMAND
|
||||
else:
|
||||
msg_type = InboundMessageType.CHAT
|
||||
|
||||
@@ -7,11 +7,13 @@ import logging
|
||||
import mimetypes
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langgraph_sdk.errors import ConflictError
|
||||
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.store import ChannelStore
|
||||
|
||||
@@ -35,8 +37,49 @@ CHANNEL_CAPABILITIES = {
|
||||
"feishu": {"supports_streaming": True},
|
||||
"slack": {"supports_streaming": False},
|
||||
"telegram": {"supports_streaming": False},
|
||||
"wecom": {"supports_streaming": True},
|
||||
}
|
||||
|
||||
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
|
||||
|
||||
|
||||
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
|
||||
|
||||
|
||||
def register_inbound_file_reader(channel_name: str, reader: InboundFileReader) -> None:
|
||||
INBOUND_FILE_READERS[channel_name] = reader
|
||||
|
||||
|
||||
async def _read_http_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
|
||||
url = file_info.get("url")
|
||||
if not isinstance(url, str) or not url:
|
||||
return None
|
||||
|
||||
resp = await client.get(url)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
|
||||
async def _read_wecom_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
|
||||
data = await _read_http_inbound_file(file_info, client)
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
aeskey = file_info.get("aeskey") if isinstance(file_info.get("aeskey"), str) else None
|
||||
if not aeskey:
|
||||
return data
|
||||
|
||||
try:
|
||||
from aibot.crypto_utils import decrypt_file
|
||||
except Exception:
|
||||
logger.exception("[Manager] failed to import WeCom decrypt_file")
|
||||
return None
|
||||
|
||||
return decrypt_file(data, aeskey)
|
||||
|
||||
|
||||
register_inbound_file_reader("wecom", _read_wecom_inbound_file)
|
||||
|
||||
|
||||
class InvalidChannelSessionConfigError(ValueError):
|
||||
"""Raised when IM channel session overrides contain invalid agent config."""
|
||||
@@ -341,6 +384,105 @@ def _prepare_artifact_delivery(
|
||||
return response_text, attachments
|
||||
|
||||
|
||||
async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dict[str, Any]]:
|
||||
if not msg.files:
|
||||
return []
|
||||
|
||||
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||
|
||||
created: list[dict[str, Any]] = []
|
||||
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
|
||||
for idx, f in enumerate(msg.files):
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
|
||||
ftype = f.get("type") if isinstance(f.get("type"), str) else "file"
|
||||
filename = f.get("filename") if isinstance(f.get("filename"), str) else ""
|
||||
|
||||
try:
|
||||
data = await file_reader(f, client)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[Manager] failed to read inbound file: channel=%s, file=%s",
|
||||
msg.channel_name,
|
||||
f.get("url") or filename or idx,
|
||||
)
|
||||
continue
|
||||
|
||||
if data is None:
|
||||
logger.warning(
|
||||
"[Manager] inbound file reader returned no data: channel=%s, file=%s",
|
||||
msg.channel_name,
|
||||
f.get("url") or filename or idx,
|
||||
)
|
||||
continue
|
||||
|
||||
if not filename:
|
||||
ext = ".bin"
|
||||
if ftype == "image":
|
||||
ext = ".png"
|
||||
filename = f"{msg.thread_ts or 'msg'}_{idx}{ext}"
|
||||
|
||||
try:
|
||||
safe_name = claim_unique_filename(normalize_filename(filename), seen_names)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[Manager] skipping inbound file with unsafe filename: channel=%s, file=%r",
|
||||
msg.channel_name,
|
||||
filename,
|
||||
)
|
||||
continue
|
||||
|
||||
dest = uploads_dir / safe_name
|
||||
try:
|
||||
dest.write_bytes(data)
|
||||
except Exception:
|
||||
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
||||
continue
|
||||
|
||||
created.append(
|
||||
{
|
||||
"filename": safe_name,
|
||||
"size": len(data),
|
||||
"path": f"/mnt/user-data/uploads/{safe_name}",
|
||||
"is_image": ftype == "image",
|
||||
}
|
||||
)
|
||||
|
||||
return created
|
||||
|
||||
|
||||
def _format_uploaded_files_block(files: list[dict[str, Any]]) -> str:
|
||||
lines = [
|
||||
"<uploaded_files>",
|
||||
"The following files were uploaded in this message:",
|
||||
"",
|
||||
]
|
||||
if not files:
|
||||
lines.append("(empty)")
|
||||
else:
|
||||
for f in files:
|
||||
filename = f.get("filename", "")
|
||||
size = int(f.get("size") or 0)
|
||||
size_kb = size / 1024 if size else 0
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
path = f.get("path", "")
|
||||
is_image = bool(f.get("is_image"))
|
||||
file_kind = "image" if is_image else "file"
|
||||
lines.append(f"- {filename} ({size_str})")
|
||||
lines.append(f" Type: {file_kind}")
|
||||
lines.append(f" Path: {path}")
|
||||
lines.append("")
|
||||
lines.append("Use `read_file` for text-based files and documents.")
|
||||
lines.append("Use `view_image` for image files (jpg, jpeg, png, webp) so the model can inspect the image content.")
|
||||
lines.append("</uploaded_files>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
|
||||
|
||||
@@ -535,6 +677,11 @@ class ChannelManager:
|
||||
assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id)
|
||||
if extra_context:
|
||||
run_context.update(extra_context)
|
||||
|
||||
uploaded = await _ingest_inbound_files(thread_id, msg)
|
||||
if uploaded:
|
||||
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
|
||||
|
||||
if self._channel_supports_streaming(msg.channel_name):
|
||||
await self._handle_streaming_chat(
|
||||
client,
|
||||
@@ -735,7 +882,8 @@ class ChannelManager:
|
||||
"/help — Show this help"
|
||||
)
|
||||
else:
|
||||
reply = f"Unknown command: /{command}. Type /help for available commands."
|
||||
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
|
||||
reply = f"Unknown command: /{command}. Available commands: {available}"
|
||||
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
|
||||
@@ -17,6 +17,7 @@ _CHANNEL_REGISTRY: dict[str, str] = {
|
||||
"feishu": "app.channels.feishu:FeishuChannel",
|
||||
"slack": "app.channels.slack:SlackChannel",
|
||||
"telegram": "app.channels.telegram:TelegramChannel",
|
||||
"wecom": "app.channels.wecom:WeComChannel",
|
||||
}
|
||||
|
||||
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
||||
|
||||
@@ -30,7 +30,7 @@ class SlackChannel(Channel):
|
||||
self._socket_client = None
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users: set[str] = set(config.get("allowed_users", []))
|
||||
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -126,7 +126,9 @@ class SlackChannel(Channel):
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise last_exc # type: ignore[misc]
|
||||
if last_exc is None:
|
||||
raise RuntimeError("Slack send failed without an exception from any attempt")
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._web_client:
|
||||
|
||||
@@ -125,7 +125,9 @@ class TelegramChannel(Channel):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc)
|
||||
raise last_exc # type: ignore[misc]
|
||||
if last_exc is None:
|
||||
raise RuntimeError("Telegram send failed without an exception from any attempt")
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._application:
|
||||
|
||||
@@ -0,0 +1,394 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import (
|
||||
InboundMessageType,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
ResolvedAttachment,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeComChannel(Channel):
|
||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
||||
super().__init__(name="wecom", bus=bus, config=config)
|
||||
self._bot_id: str | None = None
|
||||
self._bot_secret: str | None = None
|
||||
self._ws_client = None
|
||||
self._ws_task: asyncio.Task | None = None
|
||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||
self._ws_stream_ids: dict[str, str] = {}
|
||||
self._working_message = "Working on it..."
|
||||
|
||||
def _clear_ws_context(self, thread_ts: str | None) -> None:
|
||||
if not thread_ts:
|
||||
return
|
||||
self._ws_frames.pop(thread_ts, None)
|
||||
self._ws_stream_ids.pop(thread_ts, None)
|
||||
|
||||
async def _send_ws_upload_command(self, req_id: str, body: dict[str, Any], cmd: str) -> dict[str, Any]:
|
||||
if not self._ws_client:
|
||||
raise RuntimeError("WeCom WebSocket client is not available")
|
||||
|
||||
ws_manager = getattr(self._ws_client, "_ws_manager", None)
|
||||
send_reply = getattr(ws_manager, "send_reply", None)
|
||||
if not callable(send_reply):
|
||||
raise RuntimeError("Installed wecom-aibot-python-sdk does not expose the WebSocket media upload API expected by DeerFlow. Use wecom-aibot-python-sdk==0.1.6 or update the adapter.")
|
||||
|
||||
send_reply_async = cast(Callable[[str, dict[str, Any], str], Awaitable[dict[str, Any]]], send_reply)
|
||||
return await send_reply_async(req_id, body, cmd)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
return
|
||||
|
||||
bot_id = self.config.get("bot_id")
|
||||
bot_secret = self.config.get("bot_secret")
|
||||
working_message = self.config.get("working_message")
|
||||
|
||||
self._bot_id = bot_id if isinstance(bot_id, str) and bot_id else None
|
||||
self._bot_secret = bot_secret if isinstance(bot_secret, str) and bot_secret else None
|
||||
self._working_message = working_message if isinstance(working_message, str) and working_message else "Working on it..."
|
||||
|
||||
if not self._bot_id or not self._bot_secret:
|
||||
logger.error("WeCom channel requires bot_id and bot_secret")
|
||||
return
|
||||
|
||||
try:
|
||||
from aibot import WSClient, WSClientOptions
|
||||
except ImportError:
|
||||
logger.error("wecom-aibot-python-sdk is not installed. Install it with: uv add wecom-aibot-python-sdk")
|
||||
return
|
||||
else:
|
||||
self._ws_client = WSClient(WSClientOptions(bot_id=self._bot_id, secret=self._bot_secret, logger=logger))
|
||||
self._ws_client.on("message.text", self._on_ws_text)
|
||||
self._ws_client.on("message.mixed", self._on_ws_mixed)
|
||||
self._ws_client.on("message.image", self._on_ws_image)
|
||||
self._ws_client.on("message.file", self._on_ws_file)
|
||||
self._ws_task = asyncio.create_task(self._ws_client.connect())
|
||||
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("WeCom channel started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
||||
if self._ws_task:
|
||||
try:
|
||||
self._ws_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws_task = None
|
||||
if self._ws_client:
|
||||
try:
|
||||
self._ws_client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws_client = None
|
||||
self._ws_frames.clear()
|
||||
self._ws_stream_ids.clear()
|
||||
logger.info("WeCom channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if self._ws_client:
|
||||
await self._send_ws(msg, _max_retries=_max_retries)
|
||||
return
|
||||
logger.warning("[WeCom] send called but WebSocket client is not available")
|
||||
|
||||
async def _on_outbound(self, msg: OutboundMessage) -> None:
|
||||
if msg.channel_name != self.name:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.send(msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to send outbound message on channel %s", self.name)
|
||||
if msg.is_final:
|
||||
self._clear_ws_context(msg.thread_ts)
|
||||
return
|
||||
|
||||
for attachment in msg.attachments:
|
||||
try:
|
||||
success = await self.send_file(msg, attachment)
|
||||
if not success:
|
||||
logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
|
||||
except Exception:
|
||||
logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)
|
||||
|
||||
if msg.is_final:
|
||||
self._clear_ws_context(msg.thread_ts)
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not msg.is_final:
|
||||
return True
|
||||
if not self._ws_client:
|
||||
return False
|
||||
if not msg.thread_ts:
|
||||
return False
|
||||
frame = self._ws_frames.get(msg.thread_ts)
|
||||
if not frame:
|
||||
return False
|
||||
|
||||
media_type = "image" if attachment.is_image else "file"
|
||||
size_limit = 2 * 1024 * 1024 if attachment.is_image else 20 * 1024 * 1024
|
||||
if attachment.size > size_limit:
|
||||
logger.warning(
|
||||
"[WeCom] %s too large (%d bytes), skipping: %s",
|
||||
media_type,
|
||||
attachment.size,
|
||||
attachment.filename,
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
media_id = await self._upload_media_ws(
|
||||
media_type=media_type,
|
||||
filename=attachment.filename,
|
||||
path=str(attachment.actual_path),
|
||||
size=attachment.size,
|
||||
)
|
||||
if not media_id:
|
||||
return False
|
||||
|
||||
body = {media_type: {"media_id": media_id}, "msgtype": media_type}
|
||||
await self._ws_client.reply(frame, body)
|
||||
logger.debug("[WeCom] %s sent via ws: %s", media_type, attachment.filename)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("[WeCom] failed to upload/send file via ws: %s", attachment.filename)
|
||||
return False
|
||||
|
||||
async def _on_ws_text(self, frame: dict[str, Any]) -> None:
|
||||
body = frame.get("body", {}) or {}
|
||||
text = ((body.get("text") or {}).get("content") or "").strip()
|
||||
quote = body.get("quote", {}).get("text", {}).get("content", "").strip()
|
||||
if not text and not quote:
|
||||
return
|
||||
await self._publish_ws_inbound(frame, text + (f"\nQuote message: {quote}" if quote else ""))
|
||||
|
||||
async def _on_ws_mixed(self, frame: dict[str, Any]) -> None:
|
||||
body = frame.get("body", {}) or {}
|
||||
mixed = body.get("mixed") or {}
|
||||
items = mixed.get("msg_item") or []
|
||||
parts: list[str] = []
|
||||
files: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
item_type = (item or {}).get("msgtype")
|
||||
if item_type == "text":
|
||||
content = (((item or {}).get("text") or {}).get("content") or "").strip()
|
||||
if content:
|
||||
parts.append(content)
|
||||
elif item_type in ("image", "file"):
|
||||
payload = (item or {}).get(item_type) or {}
|
||||
url = payload.get("url")
|
||||
aeskey = payload.get("aeskey")
|
||||
if isinstance(url, str) and url:
|
||||
files.append(
|
||||
{
|
||||
"type": item_type,
|
||||
"url": url,
|
||||
"aeskey": (aeskey if isinstance(aeskey, str) and aeskey else None),
|
||||
}
|
||||
)
|
||||
text = "\n\n".join(parts).strip()
|
||||
if not text and not files:
|
||||
return
|
||||
if not text:
|
||||
text = "(receive image/file)"
|
||||
await self._publish_ws_inbound(frame, text, files=files)
|
||||
|
||||
async def _on_ws_image(self, frame: dict[str, Any]) -> None:
|
||||
body = frame.get("body", {}) or {}
|
||||
image = body.get("image") or {}
|
||||
url = image.get("url")
|
||||
aeskey = image.get("aeskey")
|
||||
if not isinstance(url, str) or not url:
|
||||
return
|
||||
await self._publish_ws_inbound(
|
||||
frame,
|
||||
"(receive image )",
|
||||
files=[
|
||||
{
|
||||
"type": "image",
|
||||
"url": url,
|
||||
"aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
async def _on_ws_file(self, frame: dict[str, Any]) -> None:
|
||||
body = frame.get("body", {}) or {}
|
||||
file_obj = body.get("file") or {}
|
||||
url = file_obj.get("url")
|
||||
aeskey = file_obj.get("aeskey")
|
||||
if not isinstance(url, str) or not url:
|
||||
return
|
||||
await self._publish_ws_inbound(
|
||||
frame,
|
||||
"(receive file)",
|
||||
files=[
|
||||
{
|
||||
"type": "file",
|
||||
"url": url,
|
||||
"aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
async def _publish_ws_inbound(
|
||||
self,
|
||||
frame: dict[str, Any],
|
||||
text: str,
|
||||
*,
|
||||
files: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
try:
|
||||
from aibot import generate_req_id
|
||||
except Exception:
|
||||
return
|
||||
|
||||
body = frame.get("body", {}) or {}
|
||||
msg_id = body.get("msgid")
|
||||
if not msg_id:
|
||||
return
|
||||
|
||||
user_id = (body.get("from") or {}).get("userid")
|
||||
|
||||
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=user_id, # keep user's conversation in memory
|
||||
user_id=user_id,
|
||||
text=text,
|
||||
msg_type=inbound_type,
|
||||
thread_ts=msg_id,
|
||||
files=files or [],
|
||||
metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")},
|
||||
)
|
||||
inbound.topic_id = user_id # keep the same thread
|
||||
|
||||
stream_id = generate_req_id("stream")
|
||||
self._ws_frames[msg_id] = frame
|
||||
self._ws_stream_ids[msg_id] = stream_id
|
||||
|
||||
try:
|
||||
await self._ws_client.reply_stream(frame, stream_id, self._working_message, False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
try:
|
||||
from aibot import generate_req_id
|
||||
except Exception:
|
||||
generate_req_id = None
|
||||
|
||||
if msg.thread_ts and msg.thread_ts in self._ws_frames:
|
||||
frame = self._ws_frames[msg.thread_ts]
|
||||
stream_id = self._ws_stream_ids.get(msg.thread_ts)
|
||||
if not stream_id and generate_req_id:
|
||||
stream_id = generate_req_id("stream")
|
||||
self._ws_stream_ids[msg.thread_ts] = stream_id
|
||||
if not stream_id:
|
||||
return
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
await self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final))
|
||||
return
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < _max_retries - 1:
|
||||
await asyncio.sleep(2**attempt)
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
|
||||
body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
|
||||
last_exc = None
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
await self._ws_client.send_message(msg.chat_id, body)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < _max_retries - 1:
|
||||
await asyncio.sleep(2**attempt)
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
|
||||
async def _upload_media_ws(
|
||||
self,
|
||||
*,
|
||||
media_type: str,
|
||||
filename: str,
|
||||
path: str,
|
||||
size: int,
|
||||
) -> str | None:
|
||||
if not self._ws_client:
|
||||
return None
|
||||
try:
|
||||
from aibot import generate_req_id
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
chunk_size = 512 * 1024
|
||||
total_chunks = (size + chunk_size - 1) // chunk_size
|
||||
if total_chunks < 1 or total_chunks > 100:
|
||||
logger.warning("[WeCom] invalid total_chunks=%d for %s", total_chunks, filename)
|
||||
return None
|
||||
|
||||
md5_hasher = hashlib.md5()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(1024 * 1024), b""):
|
||||
md5_hasher.update(chunk)
|
||||
md5 = md5_hasher.hexdigest()
|
||||
|
||||
init_req_id = generate_req_id("aibot_upload_media_init")
|
||||
init_body = {
|
||||
"type": media_type,
|
||||
"filename": filename,
|
||||
"total_size": int(size),
|
||||
"total_chunks": int(total_chunks),
|
||||
"md5": md5,
|
||||
}
|
||||
init_ack = await self._send_ws_upload_command(init_req_id, init_body, "aibot_upload_media_init")
|
||||
upload_id = (init_ack.get("body") or {}).get("upload_id")
|
||||
if not upload_id:
|
||||
logger.warning("[WeCom] upload init returned no upload_id: %s", init_ack)
|
||||
return None
|
||||
|
||||
with open(path, "rb") as f:
|
||||
for idx in range(total_chunks):
|
||||
data = f.read(chunk_size)
|
||||
if not data:
|
||||
break
|
||||
chunk_req_id = generate_req_id("aibot_upload_media_chunk")
|
||||
chunk_body = {
|
||||
"upload_id": upload_id,
|
||||
"chunk_index": int(idx),
|
||||
"base64_data": base64.b64encode(data).decode("utf-8"),
|
||||
}
|
||||
await self._send_ws_upload_command(chunk_req_id, chunk_body, "aibot_upload_media_chunk")
|
||||
|
||||
finish_req_id = generate_req_id("aibot_upload_media_finish")
|
||||
finish_ack = await self._send_ws_upload_command(finish_req_id, {"upload_id": upload_id}, "aibot_upload_media_finish")
|
||||
media_id = (finish_ack.get("body") or {}).get("media_id")
|
||||
if not media_id:
|
||||
logger.warning("[WeCom] upload finish returned no media_id: %s", finish_ack)
|
||||
return None
|
||||
return media_id
|
||||
+119
-1
@@ -1,15 +1,21 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
from app.gateway.routers import (
|
||||
agents,
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channels,
|
||||
mcp,
|
||||
memory,
|
||||
@@ -33,6 +39,88 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
"""Auto-create the admin user on first boot if no users exist.
|
||||
|
||||
Prints the generated password to stdout so the operator can log in.
|
||||
On subsequent boots, warns if any user still needs setup.
|
||||
|
||||
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races.
|
||||
Only the worker that successfully creates/updates the admin prints the
|
||||
password; losers silently skip.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
provider = get_local_provider()
|
||||
user_count = await provider.count_users()
|
||||
|
||||
if user_count == 0:
|
||||
password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
||||
except ValueError:
|
||||
return # Another worker already created the admin.
|
||||
|
||||
# Migrate orphaned threads (no user_id) to this admin
|
||||
store = getattr(app.state, "store", None)
|
||||
if store is not None:
|
||||
await _migrate_orphaned_threads(store, str(admin.id))
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(" Admin account created on first boot")
|
||||
logger.info(" Email: %s", admin.email)
|
||||
logger.info(" Password: %s", password)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
return
|
||||
|
||||
# Admin exists but setup never completed — reset password so operator
|
||||
# can always find it in the console without needing the CLI.
|
||||
# Multi-worker guard: if admin was created less than 5s ago, another
|
||||
# worker just created it and will print the password — skip reset.
|
||||
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
||||
if admin and admin.needs_setup:
|
||||
import time
|
||||
|
||||
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
||||
if age < 30:
|
||||
return # Just created by another worker in this startup; its password is still valid.
|
||||
|
||||
from app.gateway.auth.password import hash_password_async
|
||||
|
||||
password = secrets.token_urlsafe(16)
|
||||
admin.password_hash = await hash_password_async(password)
|
||||
admin.token_version += 1
|
||||
await provider.update_user(admin)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(" Admin account setup incomplete — password reset")
|
||||
logger.info(" Email: %s", admin.email)
|
||||
logger.info(" Password: %s", password)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> None:
|
||||
"""Migrate threads with no user_id to the given admin."""
|
||||
try:
|
||||
migrated = 0
|
||||
results = await store.asearch(("threads",), limit=1000)
|
||||
for item in results:
|
||||
metadata = item.value.get("metadata", {})
|
||||
if not metadata.get("user_id"):
|
||||
metadata["user_id"] = admin_user_id
|
||||
item.value["metadata"] = metadata
|
||||
await store.aput(("threads",), item.key, item.value)
|
||||
migrated += 1
|
||||
if migrated:
|
||||
logger.info("Migrated %d orphaned thread(s) to admin", migrated)
|
||||
except Exception:
|
||||
logger.exception("Thread migration failed (non-fatal)")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan handler."""
|
||||
@@ -52,6 +140,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async with langgraph_runtime(app):
|
||||
logger.info("LangGraph runtime initialised")
|
||||
|
||||
# Ensure admin user exists (auto-create on first boot)
|
||||
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
|
||||
await _ensure_admin_user(app)
|
||||
|
||||
# Start IM channel service if any channels are configured
|
||||
try:
|
||||
from app.channels.service import start_channel_service
|
||||
@@ -163,7 +255,30 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
||||
],
|
||||
)
|
||||
|
||||
# CORS is handled by nginx - no need for FastAPI middleware
|
||||
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# CSRF: Double Submit Cookie pattern for state-changing requests
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware
|
||||
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
|
||||
if cors_origins_env:
|
||||
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
|
||||
# Validate: wildcard origin with credentials is a security misconfiguration
|
||||
for origin in cors_origins:
|
||||
if origin == "*":
|
||||
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
|
||||
cors_origins = [o for o in cors_origins if o != "*"]
|
||||
break
|
||||
if cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
# Models API is mounted at /api/models
|
||||
@@ -199,6 +314,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
||||
# Assistants compatibility API (LangGraph Platform stub)
|
||||
app.include_router(assistants_compat.router)
|
||||
|
||||
# Auth API is mounted at /api/v1/auth
|
||||
app.include_router(auth.router)
|
||||
|
||||
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
|
||||
app.include_router(thread_runs.router)
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Authentication module for DeerFlow.
|
||||
|
||||
This module provides:
|
||||
- JWT-based authentication
|
||||
- Provider Factory pattern for extensible auth methods
|
||||
- UserRepository interface for storage backends (SQLite)
|
||||
"""
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.models import User, UserResponse
|
||||
from app.gateway.auth.password import hash_password, verify_password
|
||||
from app.gateway.auth.providers import AuthProvider
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"AuthConfig",
|
||||
"get_auth_config",
|
||||
"set_auth_config",
|
||||
# Errors
|
||||
"AuthErrorCode",
|
||||
"AuthErrorResponse",
|
||||
"TokenError",
|
||||
# JWT
|
||||
"TokenPayload",
|
||||
"create_access_token",
|
||||
"decode_token",
|
||||
# Password
|
||||
"hash_password",
|
||||
"verify_password",
|
||||
# Models
|
||||
"User",
|
||||
"UserResponse",
|
||||
# Providers
|
||||
"AuthProvider",
|
||||
"LocalAuthProvider",
|
||||
# Repository
|
||||
"UserRepository",
|
||||
]
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Authentication configuration for DeerFlow."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration. Parsed once at startup."""
|
||||
|
||||
jwt_secret: str = Field(
|
||||
...,
|
||||
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
|
||||
)
|
||||
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
||||
users_db_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to users SQLite DB. Defaults to .deer-flow/users.db",
|
||||
)
|
||||
oauth_github_client_id: str | None = Field(default=None)
|
||||
oauth_github_client_secret: str | None = Field(default=None)
|
||||
|
||||
|
||||
_auth_config: AuthConfig | None = None
|
||||
|
||||
|
||||
def get_auth_config() -> AuthConfig:
|
||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||
global _auth_config
|
||||
if _auth_config is None:
|
||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||
if not jwt_secret:
|
||||
jwt_secret = secrets.token_urlsafe(32)
|
||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||
logger.warning(
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||
"Sessions will be invalidated on restart. "
|
||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
_auth_config = AuthConfig(jwt_secret=jwt_secret)
|
||||
return _auth_config
|
||||
|
||||
|
||||
def set_auth_config(config: AuthConfig) -> None:
|
||||
"""Set the global AuthConfig instance (for testing)."""
|
||||
global _auth_config
|
||||
_auth_config = config
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Typed error definitions for auth module.
|
||||
|
||||
AuthErrorCode: exhaustive enum of all auth failure conditions.
|
||||
TokenError: exhaustive enum of JWT decode failures.
|
||||
AuthErrorResponse: structured error payload for HTTP responses.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthErrorCode(StrEnum):
|
||||
"""Exhaustive list of auth error conditions."""
|
||||
|
||||
INVALID_CREDENTIALS = "invalid_credentials"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
TOKEN_INVALID = "token_invalid"
|
||||
USER_NOT_FOUND = "user_not_found"
|
||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||
NOT_AUTHENTICATED = "not_authenticated"
|
||||
|
||||
|
||||
class TokenError(StrEnum):
|
||||
"""Exhaustive list of JWT decode failure reasons."""
|
||||
|
||||
EXPIRED = "expired"
|
||||
INVALID_SIGNATURE = "invalid_signature"
|
||||
MALFORMED = "malformed"
|
||||
|
||||
|
||||
class AuthErrorResponse(BaseModel):
|
||||
"""Structured error response — replaces bare `detail` strings."""
|
||||
|
||||
code: AuthErrorCode
|
||||
message: str
|
||||
|
||||
|
||||
def token_error_to_code(err: TokenError) -> AuthErrorCode:
|
||||
"""Map TokenError to AuthErrorCode — single source of truth."""
|
||||
if err == TokenError.EXPIRED:
|
||||
return AuthErrorCode.TOKEN_EXPIRED
|
||||
return AuthErrorCode.TOKEN_INVALID
|
||||
@@ -0,0 +1,55 @@
|
||||
"""JWT token creation and verification."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT token payload."""
|
||||
|
||||
sub: str # user_id
|
||||
exp: datetime
|
||||
iat: datetime | None = None
|
||||
ver: int = 0 # token_version — must match User.token_version
|
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID as string
|
||||
expires_delta: Optional custom expiry, defaults to 7 days
|
||||
token_version: User's current token_version for invalidation
|
||||
|
||||
Returns:
|
||||
Encoded JWT string
|
||||
"""
|
||||
config = get_auth_config()
|
||||
expiry = expires_delta or timedelta(days=config.token_expiry_days)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
|
||||
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload | TokenError:
|
||||
"""Decode and validate a JWT token.
|
||||
|
||||
Returns:
|
||||
TokenPayload if valid, or a specific TokenError variant.
|
||||
"""
|
||||
config = get_auth_config()
|
||||
try:
|
||||
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
|
||||
return TokenPayload(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
return TokenError.EXPIRED
|
||||
except jwt.InvalidSignatureError:
|
||||
return TokenError.INVALID_SIGNATURE
|
||||
except jwt.PyJWTError:
|
||||
return TokenError.MALFORMED
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Local email/password authentication provider."""
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
from app.gateway.auth.providers import AuthProvider
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
|
||||
class LocalAuthProvider(AuthProvider):
|
||||
"""Email/password authentication provider using local database."""
|
||||
|
||||
def __init__(self, repository: UserRepository):
|
||||
"""Initialize with a UserRepository.
|
||||
|
||||
Args:
|
||||
repository: UserRepository implementation (SQLite)
|
||||
"""
|
||||
self._repo = repository
|
||||
|
||||
async def authenticate(self, credentials: dict) -> User | None:
|
||||
"""Authenticate with email and password.
|
||||
|
||||
Args:
|
||||
credentials: dict with 'email' and 'password' keys
|
||||
|
||||
Returns:
|
||||
User if authentication succeeds, None otherwise
|
||||
"""
|
||||
email = credentials.get("email")
|
||||
password = credentials.get("password")
|
||||
|
||||
if not email or not password:
|
||||
return None
|
||||
|
||||
user = await self._repo.get_user_by_email(email)
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
if user.password_hash is None:
|
||||
# OAuth user without local password
|
||||
return None
|
||||
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
async def get_user(self, user_id: str) -> User | None:
|
||||
"""Get user by ID."""
|
||||
return await self._repo.get_user_by_id(user_id)
|
||||
|
||||
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
|
||||
"""Create a new local user.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
password: Plain text password (will be hashed)
|
||||
system_role: Role to assign ("admin" or "user")
|
||||
needs_setup: If True, user must complete setup on first login
|
||||
|
||||
Returns:
|
||||
Created User instance
|
||||
"""
|
||||
password_hash = await hash_password_async(password) if password else None
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
system_role=system_role,
|
||||
needs_setup=needs_setup,
|
||||
)
|
||||
return await self._repo.create_user(user)
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID."""
|
||||
return await self._repo.get_user_by_oauth(provider, oauth_id)
|
||||
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
return await self._repo.count_users()
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update_user(user)
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email."""
|
||||
return await self._repo.get_user_by_email(email)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""User Pydantic models for authentication."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""Return current UTC time (timezone-aware)."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""Internal user representation."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, description="Primary key")
|
||||
email: EmailStr = Field(..., description="Unique email address")
|
||||
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
|
||||
system_role: Literal["admin", "user"] = Field(default="user")
|
||||
created_at: datetime = Field(default_factory=_utc_now)
|
||||
|
||||
# OAuth linkage (optional)
|
||||
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
|
||||
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
||||
|
||||
# Auth lifecycle
|
||||
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
||||
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for user info endpoint."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
system_role: Literal["admin", "user"]
|
||||
needs_setup: bool = False
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Password hashing utilities using bcrypt directly."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
|
||||
async def hash_password_async(password: str) -> str:
|
||||
"""Hash a password using bcrypt (non-blocking).
|
||||
|
||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop during password hashing.
|
||||
"""
|
||||
return await asyncio.to_thread(hash_password, password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash (non-blocking).
|
||||
|
||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop during password verification.
|
||||
"""
|
||||
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Auth provider abstraction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, credentials: dict) -> "User | None":
|
||||
"""Authenticate user with given credentials.
|
||||
|
||||
Returns User if authentication succeeds, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(self, user_id: str) -> "User | None":
|
||||
"""Retrieve user by ID."""
|
||||
...
|
||||
|
||||
|
||||
# Import User at runtime to avoid circular imports
|
||||
from app.gateway.auth.models import User # noqa: E402
|
||||
@@ -0,0 +1,82 @@
|
||||
"""User repository interface for abstracting database operations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
|
||||
class UserRepository(ABC):
|
||||
"""Abstract interface for user data storage.
|
||||
|
||||
Implement this interface to support different storage backends
|
||||
(SQLite)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
user: User object to create
|
||||
|
||||
Returns:
|
||||
Created User with ID assigned
|
||||
|
||||
Raises:
|
||||
ValueError: If email already exists
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID.
|
||||
|
||||
Args:
|
||||
user_id: User UUID as string
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user.
|
||||
|
||||
Args:
|
||||
user: User object with updated fields
|
||||
|
||||
Returns:
|
||||
Updated User
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (e.g. 'github', 'google')
|
||||
oauth_id: User ID from the OAuth provider
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,196 @@
|
||||
"""SQLite implementation of UserRepository."""
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
_resolved_db_path: Path | None = None
|
||||
_table_initialized: bool = False
|
||||
|
||||
|
||||
def _get_users_db_path() -> Path:
|
||||
"""Get the users database path (resolved and cached once)."""
|
||||
global _resolved_db_path
|
||||
if _resolved_db_path is not None:
|
||||
return _resolved_db_path
|
||||
config = get_auth_config()
|
||||
if config.users_db_path:
|
||||
_resolved_db_path = Path(config.users_db_path)
|
||||
else:
|
||||
_resolved_db_path = Path(".deer-flow/users.db")
|
||||
_resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return _resolved_db_path
|
||||
|
||||
|
||||
def _get_connection() -> sqlite3.Connection:
|
||||
"""Get a SQLite connection for the users database."""
|
||||
db_path = _get_users_db_path()
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def _init_users_table(conn: sqlite3.Connection) -> None:
|
||||
"""Initialize the users table if it doesn't exist."""
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT,
|
||||
system_role TEXT NOT NULL DEFAULT 'user',
|
||||
created_at REAL NOT NULL,
|
||||
oauth_provider TEXT,
|
||||
oauth_id TEXT,
|
||||
needs_setup INTEGER NOT NULL DEFAULT 0,
|
||||
token_version INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Add unique constraint for OAuth identity to prevent duplicate social logins
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
|
||||
ON users(oauth_provider, oauth_id)
|
||||
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _get_users_conn():
|
||||
"""Context manager for users database connection."""
|
||||
global _table_initialized
|
||||
conn = _get_connection()
|
||||
try:
|
||||
if not _table_initialized:
|
||||
_init_users_table(conn)
|
||||
_table_initialized = True
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class SQLiteUserRepository(UserRepository):
|
||||
"""SQLite implementation of UserRepository."""
|
||||
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Create a new user in SQLite."""
|
||||
return await asyncio.to_thread(self._create_user_sync, user)
|
||||
|
||||
def _create_user_sync(self, user: User) -> User:
|
||||
"""Synchronous user creation (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(user.id),
|
||||
user.email,
|
||||
user.password_hash,
|
||||
user.system_role,
|
||||
datetime.now(UTC).timestamp(),
|
||||
user.oauth_provider,
|
||||
user.oauth_id,
|
||||
int(user.needs_setup),
|
||||
user.token_version,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError as e:
|
||||
if "UNIQUE constraint failed: users.email" in str(e):
|
||||
raise ValueError(f"Email already registered: {user.email}") from e
|
||||
raise
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_id_sync, user_id)
|
||||
|
||||
def _get_user_by_id_sync(self, user_id: str) -> User | None:
|
||||
"""Synchronous get by ID (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_email_sync, email)
|
||||
|
||||
def _get_user_by_email_sync(self, email: str) -> User | None:
|
||||
"""Synchronous get by email (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user in SQLite."""
|
||||
return await asyncio.to_thread(self._update_user_sync, user)
|
||||
|
||||
def _update_user_sync(self, user: User) -> User:
|
||||
with _get_users_conn() as conn:
|
||||
conn.execute(
|
||||
"UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
|
||||
(user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
|
||||
)
|
||||
conn.commit()
|
||||
return user
|
||||
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
return await asyncio.to_thread(self._count_users_sync)
|
||||
|
||||
def _count_users_sync(self) -> int:
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)
|
||||
|
||||
def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Synchronous get by OAuth (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
(provider, oauth_id),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: dict[str, Any]) -> User:
|
||||
"""Convert a database row to a User model."""
|
||||
return User(
|
||||
id=UUID(row["id"]),
|
||||
email=row["email"],
|
||||
password_hash=row["password_hash"],
|
||||
system_role=row["system_role"],
|
||||
created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
|
||||
oauth_provider=row.get("oauth_provider"),
|
||||
oauth_id=row.get("oauth_id"),
|
||||
needs_setup=bool(row["needs_setup"]),
|
||||
token_version=int(row["token_version"]),
|
||||
)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""CLI tool to reset admin password.
|
||||
|
||||
Usage:
|
||||
python -m app.gateway.auth.reset_admin
|
||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from app.gateway.auth.password import hash_password
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Reset admin password")
|
||||
parser.add_argument("--email", help="Admin email (default: first admin found)")
|
||||
args = parser.parse_args()
|
||||
|
||||
repo = SQLiteUserRepository()
|
||||
|
||||
# Find admin user synchronously (CLI context, no event loop)
|
||||
import asyncio
|
||||
|
||||
user = asyncio.run(_find_admin(repo, args.email))
|
||||
if user is None:
|
||||
if args.email:
|
||||
print(f"Error: user '{args.email}' not found.", file=sys.stderr)
|
||||
else:
|
||||
print("Error: no admin user found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
user.password_hash = hash_password(new_password)
|
||||
user.token_version += 1
|
||||
user.needs_setup = True
|
||||
asyncio.run(repo.update_user(user))
|
||||
|
||||
print(f"Password reset for: {user.email}")
|
||||
print(f"New password: {new_password}")
|
||||
print("Next login will require setup (new email + password).")
|
||||
|
||||
|
||||
async def _find_admin(repo: SQLiteUserRepository, email: str | None):
|
||||
if email:
|
||||
return await repo.get_user_by_email(email)
|
||||
# Find first admin
|
||||
import asyncio
|
||||
|
||||
from app.gateway.auth.repositories.sqlite import _get_users_conn
|
||||
|
||||
def _find_sync():
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT id FROM users WHERE system_role = 'admin' LIMIT 1")
|
||||
row = cursor.fetchone()
|
||||
return dict(row)["id"] if row else None
|
||||
|
||||
admin_id = await asyncio.to_thread(_find_sync)
|
||||
if admin_id:
|
||||
return await repo.get_user_by_id(admin_id)
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Global authentication middleware — fail-closed safety net.
|
||||
|
||||
Rejects unauthenticated requests to non-public paths with 401.
|
||||
Fine-grained permission checks remain in authz.py decorators.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.errors import AuthErrorCode
|
||||
|
||||
# Paths that never require authentication.
|
||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
)
|
||||
|
||||
# Exact auth paths that are public (login/register/status check).
|
||||
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
|
||||
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_public(path: str) -> bool:
|
||||
stripped = path.rstrip("/")
|
||||
if stripped in _PUBLIC_EXACT_PATHS:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Coarse-grained auth gate: reject requests without a valid session cookie.
|
||||
|
||||
This does NOT verify JWT signature or user existence — that is the job of
|
||||
``get_current_user_from_request`` in deps.py (called by ``@require_auth``).
|
||||
The middleware only checks *presence* of the cookie so that new endpoints
|
||||
that forget ``@require_auth`` are not completely exposed.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
if _is_public(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Non-public path: require session cookie
|
||||
if not request.cookies.get("access_token"):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"detail": {
|
||||
"code": AuthErrorCode.NOT_AUTHENTICATED,
|
||||
"message": "Authentication required",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
@@ -0,0 +1,261 @@
|
||||
"""Authorization decorators and context for DeerFlow.
|
||||
|
||||
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
|
||||
|
||||
**Usage:**
|
||||
|
||||
1. Use ``@require_auth`` on routes that need authentication
|
||||
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
|
||||
3. The decorator chain processes from bottom to top
|
||||
|
||||
**Example:**
|
||||
|
||||
@router.get("/{thread_id}")
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
# User is authenticated and has threads:read permission
|
||||
...
|
||||
|
||||
**Permission Model:**
|
||||
|
||||
- threads:read - View thread
|
||||
- threads:write - Create/update thread
|
||||
- threads:delete - Delete thread
|
||||
- runs:create - Run agent
|
||||
- runs:read - View run
|
||||
- runs:cancel - Cancel run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Permission constants
|
||||
class Permissions:
|
||||
"""Permission constants for resource:action format."""
|
||||
|
||||
# Threads
|
||||
THREADS_READ = "threads:read"
|
||||
THREADS_WRITE = "threads:write"
|
||||
THREADS_DELETE = "threads:delete"
|
||||
|
||||
# Runs
|
||||
RUNS_CREATE = "runs:create"
|
||||
RUNS_READ = "runs:read"
|
||||
RUNS_CANCEL = "runs:cancel"
|
||||
|
||||
|
||||
class AuthContext:
|
||||
"""Authentication context for the current request.
|
||||
|
||||
Stored in request.state.auth after require_auth decoration.
|
||||
|
||||
Attributes:
|
||||
user: The authenticated user, or None if anonymous
|
||||
permissions: List of permission strings (e.g., "threads:read")
|
||||
"""
|
||||
|
||||
__slots__ = ("user", "permissions")
|
||||
|
||||
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
||||
self.user = user
|
||||
self.permissions = permissions or []
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if user is authenticated."""
|
||||
return self.user is not None
|
||||
|
||||
def has_permission(self, resource: str, action: str) -> bool:
|
||||
"""Check if context has permission for resource:action.
|
||||
|
||||
Args:
|
||||
resource: Resource name (e.g., "threads")
|
||||
action: Action name (e.g., "read")
|
||||
|
||||
Returns:
|
||||
True if user has permission
|
||||
"""
|
||||
permission = f"{resource}:{action}"
|
||||
return permission in self.permissions
|
||||
|
||||
def require_user(self) -> User:
|
||||
"""Get user or raise 401.
|
||||
|
||||
Raises:
|
||||
HTTPException 401 if not authenticated
|
||||
"""
|
||||
if not self.user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return self.user
|
||||
|
||||
|
||||
def get_auth_context(request: Request) -> AuthContext | None:
|
||||
"""Get AuthContext from request state."""
|
||||
return getattr(request.state, "auth", None)
|
||||
|
||||
|
||||
_ALL_PERMISSIONS: list[str] = [
|
||||
Permissions.THREADS_READ,
|
||||
Permissions.THREADS_WRITE,
|
||||
Permissions.THREADS_DELETE,
|
||||
Permissions.RUNS_CREATE,
|
||||
Permissions.RUNS_READ,
|
||||
Permissions.RUNS_CANCEL,
|
||||
]
|
||||
|
||||
|
||||
async def _authenticate(request: Request) -> AuthContext:
|
||||
"""Authenticate request and return AuthContext.
|
||||
|
||||
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
|
||||
Returns AuthContext with user=None for anonymous requests.
|
||||
"""
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
if user is None:
|
||||
return AuthContext(user=None, permissions=[])
|
||||
|
||||
# In future, permissions could be stored in user record
|
||||
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||
|
||||
|
||||
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||
"""Decorator that authenticates the request and sets AuthContext.
|
||||
|
||||
Must be placed ABOVE other decorators (executes after them).
|
||||
|
||||
Usage:
|
||||
@router.get("/{thread_id}")
|
||||
@require_auth # Bottom decorator (executes first after permission check)
|
||||
@require_permission("threads", "read")
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
auth: AuthContext = request.state.auth
|
||||
...
|
||||
|
||||
Raises:
|
||||
ValueError: If 'request' parameter is missing
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = kwargs.get("request")
|
||||
if request is None:
|
||||
raise ValueError("require_auth decorator requires 'request' parameter")
|
||||
|
||||
# Authenticate and set context
|
||||
auth_context = await _authenticate(request)
|
||||
request.state.auth = auth_context
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_permission(
|
||||
resource: str,
|
||||
action: str,
|
||||
owner_check: bool = False,
|
||||
owner_filter_key: str = "user_id",
|
||||
inject_record: bool = False,
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
"""Decorator that checks permission for resource:action.
|
||||
|
||||
Must be used AFTER @require_auth.
|
||||
|
||||
Args:
|
||||
resource: Resource name (e.g., "threads", "runs")
|
||||
action: Action name (e.g., "read", "write", "delete")
|
||||
owner_check: If True, validates that the current user owns the resource.
|
||||
Requires 'thread_id' path parameter and performs ownership check.
|
||||
owner_filter_key: Field name for ownership filter (default: "user_id")
|
||||
inject_record: If True and owner_check is True, injects the thread record
|
||||
into kwargs['thread_record'] for use in the handler.
|
||||
|
||||
Usage:
|
||||
# Simple permission check
|
||||
@require_permission("threads", "read")
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
...
|
||||
|
||||
# With ownership check (for /threads/{thread_id} endpoints)
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_thread(thread_id: str, request: Request):
|
||||
...
|
||||
|
||||
# With ownership check and record injection
|
||||
@require_permission("threads", "delete", owner_check=True, inject_record=True)
|
||||
async def delete_thread(thread_id: str, request: Request, thread_record: dict = None):
|
||||
# thread_record is injected if found
|
||||
...
|
||||
|
||||
Raises:
|
||||
HTTPException 401: If authentication required but user is anonymous
|
||||
HTTPException 403: If user lacks permission
|
||||
HTTPException 404: If owner_check=True but user doesn't own the thread
|
||||
ValueError: If owner_check=True but 'thread_id' parameter is missing
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = kwargs.get("request")
|
||||
if request is None:
|
||||
raise ValueError("require_permission decorator requires 'request' parameter")
|
||||
|
||||
auth: AuthContext = getattr(request.state, "auth", None)
|
||||
if auth is None:
|
||||
auth = await _authenticate(request)
|
||||
request.state.auth = auth
|
||||
|
||||
if not auth.is_authenticated:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Check permission
|
||||
if not auth.has_permission(resource, action):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Permission denied: {resource}:{action}",
|
||||
)
|
||||
|
||||
# Owner check for thread-specific resources
|
||||
if owner_check:
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
# Get thread and verify ownership
|
||||
from app.gateway.routers.threads import _store_get, get_store
|
||||
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
record = await _store_get(store, thread_id)
|
||||
if record:
|
||||
owner_id = record.get("metadata", {}).get(owner_filter_key)
|
||||
if owner_id and owner_id != str(auth.user.id):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Thread {thread_id} not found",
|
||||
)
|
||||
# Inject record if requested
|
||||
if inject_record:
|
||||
kwargs["thread_record"] = record
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,112 @@
|
||||
"""CSRF protection middleware for FastAPI.
|
||||
|
||||
Per RFC-001:
|
||||
State-changing operations require CSRF protection.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
CSRF_COOKIE_NAME = "csrf_token"
|
||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
|
||||
|
||||
def is_secure_request(request: Request) -> bool:
|
||||
"""Detect whether the original client request was made over HTTPS."""
|
||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""Generate a secure random CSRF token."""
|
||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
|
||||
def should_check_csrf(request: Request) -> bool:
|
||||
"""Determine if a request needs CSRF validation.
|
||||
|
||||
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
|
||||
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
|
||||
"""
|
||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||
return False
|
||||
|
||||
path = request.url.path.rstrip("/")
|
||||
# Exempt /api/v1/auth/me endpoint
|
||||
if path == "/api/v1/auth/me":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/register",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_auth_endpoint(request: Request) -> bool:
|
||||
"""Check if the request is to an auth endpoint.
|
||||
|
||||
Auth endpoints don't need CSRF validation on first call (no token).
|
||||
"""
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
_is_auth = is_auth_endpoint(request)
|
||||
|
||||
if should_check_csrf(request) and not _is_auth:
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token mismatch."},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# For auth endpoints that set up session, also set CSRF cookie
|
||||
if _is_auth and request.method == "POST":
|
||||
# Generate a new CSRF token for the session
|
||||
csrf_token = generate_csrf_token()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key=CSRF_COOKIE_NAME,
|
||||
value=csrf_token,
|
||||
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
|
||||
secure=is_https,
|
||||
samesite="strict",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_csrf_token(request: Request) -> str | None:
|
||||
"""Get the CSRF token from the current request's cookies.
|
||||
|
||||
This is useful for server-side rendering where you need to embed
|
||||
token in forms or headers.
|
||||
"""
|
||||
return request.cookies.get(CSRF_COOKIE_NAME)
|
||||
+104
-21
@@ -3,38 +3,22 @@
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import RunManager, StreamBridge
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
|
||||
Usage in ``app.py``::
|
||||
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
app.state.run_manager = RunManager()
|
||||
yield
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters – called by routers per-request
|
||||
@@ -68,3 +52,102 @@ def get_checkpointer(request: Request):
|
||||
def get_store(request: Request):
|
||||
"""Return the global store (may be ``None`` if not configured)."""
|
||||
return getattr(request.app.state, "store", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Cached singletons to avoid repeated instantiation per request
|
||||
_cached_local_provider: LocalAuthProvider | None = None
|
||||
_cached_repo: SQLiteUserRepository | None = None
|
||||
|
||||
|
||||
def get_local_provider() -> LocalAuthProvider:
|
||||
"""Get or create the cached LocalAuthProvider singleton."""
|
||||
global _cached_local_provider, _cached_repo
|
||||
if _cached_repo is None:
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
_cached_repo = SQLiteUserRepository()
|
||||
if _cached_local_provider is None:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
|
||||
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
|
||||
return _cached_local_provider
|
||||
|
||||
|
||||
async def get_current_user_from_request(request: Request):
|
||||
"""Get the current authenticated user from the request cookie.
|
||||
|
||||
Raises HTTPException 401 if not authenticated.
|
||||
"""
|
||||
from app.gateway.auth import decode_token
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||
|
||||
access_token = request.cookies.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
||||
)
|
||||
|
||||
payload = decode_token(access_token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
|
||||
)
|
||||
|
||||
provider = get_local_provider()
|
||||
user = await provider.get_user(payload.sub)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
||||
)
|
||||
|
||||
# Token version mismatch → password was changed, token is stale
|
||||
if user.token_version != payload.ver:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_optional_user_from_request(request: Request):
|
||||
"""Get optional authenticated user from request.
|
||||
|
||||
Returns None if not authenticated.
|
||||
"""
|
||||
try:
|
||||
return await get_current_user_from_request(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Runtime bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
|
||||
Usage in ``app.py``::
|
||||
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
app.state.run_manager = RunManager()
|
||||
yield
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
"""LangGraph Server auth handler — shares JWT logic with Gateway.
|
||||
|
||||
Loaded by LangGraph Server via langgraph.json ``auth.path``.
|
||||
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
|
||||
so both modes validate tokens with the same secret and rules.
|
||||
|
||||
Two layers:
|
||||
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
||||
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
|
||||
2. @auth.on — returns metadata filter so each user only sees own threads
|
||||
"""
|
||||
|
||||
import secrets
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
auth = Auth()
|
||||
|
||||
# Methods that require CSRF validation (state-changing per RFC 7231).
|
||||
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
|
||||
def _check_csrf(request) -> None:
|
||||
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
|
||||
|
||||
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
|
||||
proxied directly by nginx have the same CSRF protection.
|
||||
"""
|
||||
method = getattr(request, "method", "") or ""
|
||||
if method.upper() not in _CSRF_METHODS:
|
||||
return
|
||||
|
||||
cookie_token = request.cookies.get("csrf_token")
|
||||
header_token = request.headers.get("x-csrf-token")
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token missing. Include X-CSRF-Token header.",
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token mismatch.",
|
||||
)
|
||||
|
||||
|
||||
@auth.authenticate
|
||||
async def authenticate(request):
|
||||
"""Validate the session cookie, decode JWT, and check token_version.
|
||||
|
||||
Same validation chain as Gateway's get_current_user_from_request:
|
||||
cookie → decode JWT → DB lookup → token_version match
|
||||
Also enforces CSRF on state-changing methods.
|
||||
"""
|
||||
# CSRF check before authentication so forged cross-site requests
|
||||
# are rejected early, even if the cookie carries a valid JWT.
|
||||
_check_csrf(request)
|
||||
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="Not authenticated",
|
||||
)
|
||||
|
||||
payload = decode_token(token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Token error: {payload.value}",
|
||||
)
|
||||
|
||||
user = await get_local_provider().get_user(payload.sub)
|
||||
if user is None:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="User not found",
|
||||
)
|
||||
if user.token_version != payload.ver:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="Token revoked (password changed)",
|
||||
)
|
||||
|
||||
return payload.sub
|
||||
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
"""Inject user_id metadata on writes; filter by user_id on reads.
|
||||
|
||||
Gateway stores thread ownership as ``metadata.user_id``.
|
||||
This handler ensures LangGraph Server enforces the same isolation.
|
||||
"""
|
||||
# On create/update: stamp user_id into metadata
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["user_id"] = ctx.user.identity
|
||||
|
||||
# Return filter dict — LangGraph applies it to search/read/delete
|
||||
return {"user_id": ctx.user.identity}
|
||||
@@ -1,3 +1,3 @@
|
||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
||||
from . import artifacts, assistants_compat, auth, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
||||
|
||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
||||
__all__ = ["artifacts", "assistants_compat", "auth", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
||||
|
||||
@@ -24,7 +24,7 @@ class AgentResponse(BaseModel):
|
||||
description: str = Field(default="", description="Agent description")
|
||||
model: str | None = Field(default=None, description="Optional model override")
|
||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||
soul: str | None = Field(default=None, description="SOUL.md content (included on GET /{name})")
|
||||
soul: str | None = Field(default=None, description="SOUL.md content")
|
||||
|
||||
|
||||
class AgentsListResponse(BaseModel):
|
||||
@@ -92,17 +92,17 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
|
||||
"/agents",
|
||||
response_model=AgentsListResponse,
|
||||
summary="List Custom Agents",
|
||||
description="List all custom agents available in the agents directory.",
|
||||
description="List all custom agents available in the agents directory, including their soul content.",
|
||||
)
|
||||
async def list_agents() -> AgentsListResponse:
|
||||
"""List all custom agents.
|
||||
|
||||
Returns:
|
||||
List of all custom agents with their metadata (without soul content).
|
||||
List of all custom agents with their metadata and soul content.
|
||||
"""
|
||||
try:
|
||||
agents = list_custom_agents()
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a) for a in agents])
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
from app.gateway.auth import (
|
||||
UserResponse,
|
||||
create_access_token,
|
||||
)
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||
from app.gateway.csrf_middleware import is_secure_request
|
||||
from app.gateway.deps import get_current_user_from_request, get_local_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request/Response Models ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""Response model for login — token only lives in HttpOnly cookie."""
|
||||
|
||||
expires_in: int # seconds
|
||||
needs_setup: bool = False
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""Request model for user registration."""
|
||||
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""Request model for password change (also handles setup flow)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str = Field(..., min_length=8)
|
||||
new_email: EmailStr | None = None
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Generic message response."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
||||
"""Set the access_token HttpOnly cookie on the response."""
|
||||
config = get_auth_config()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=is_https,
|
||||
samesite="lax",
|
||||
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
||||
)
|
||||
|
||||
|
||||
# ── Rate Limiting ────────────────────────────────────────────────────────
|
||||
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
|
||||
|
||||
_MAX_LOGIN_ATTEMPTS = 5
|
||||
_LOCKOUT_SECONDS = 300 # 5 minutes
|
||||
|
||||
# ip → (fail_count, lock_until_timestamp)
|
||||
_login_attempts: dict[str, tuple[int, float]] = {}
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract the real client IP for rate limiting.
|
||||
|
||||
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
|
||||
$remote_addr``). Nginx unconditionally overwrites any client-supplied
|
||||
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
|
||||
that nginx observed — it cannot be spoofed by the client.
|
||||
|
||||
``request.client.host`` is NOT reliable because uvicorn's default
|
||||
``proxy_headers=True`` replaces it with the *first* entry from
|
||||
``X-Forwarded-For``, which IS client-spoofable.
|
||||
|
||||
``X-Forwarded-For`` is intentionally NOT used for the same reason.
|
||||
"""
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fallback: direct connection without nginx (e.g. unit tests, dev).
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _check_rate_limit(ip: str) -> None:
|
||||
"""Raise 429 if the IP is currently locked out."""
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
return
|
||||
fail_count, lock_until = record
|
||||
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
||||
if time.time() < lock_until:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many login attempts. Try again later.",
|
||||
)
|
||||
del _login_attempts[ip]
|
||||
|
||||
|
||||
_MAX_TRACKED_IPS = 10000
|
||||
|
||||
|
||||
def _record_login_failure(ip: str) -> None:
|
||||
"""Record a failed login attempt for the given IP."""
|
||||
# Evict expired lockouts when dict grows too large
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
now = time.time()
|
||||
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
||||
for k in expired:
|
||||
del _login_attempts[k]
|
||||
# If still too large, evict cheapest-to-lose half: below-threshold
|
||||
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
||||
for k, _ in by_time[: len(by_time) // 2]:
|
||||
del _login_attempts[k]
|
||||
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
_login_attempts[ip] = (1, 0.0)
|
||||
else:
|
||||
new_count = record[0] + 1
|
||||
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
||||
_login_attempts[ip] = (new_count, lock_until)
|
||||
|
||||
|
||||
def _record_login_success(ip: str) -> None:
|
||||
"""Clear failure counter for the given IP on successful login."""
|
||||
_login_attempts.pop(ip, None)
|
||||
|
||||
|
||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/login/local", response_model=LoginResponse)
|
||||
async def login_local(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
"""Local email/password login."""
|
||||
client_ip = _get_client_ip(request)
|
||||
_check_rate_limit(client_ip)
|
||||
|
||||
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
||||
|
||||
if user is None:
|
||||
_record_login_failure(client_ip)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
||||
)
|
||||
|
||||
_record_login_success(client_ip)
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return LoginResponse(
|
||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||
needs_setup=user.needs_setup,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
||||
"""Register a new user account (always 'user' role).
|
||||
|
||||
Admin is auto-created on first boot. This endpoint creates regular users.
|
||||
Auto-login by setting the session cookie.
|
||||
"""
|
||||
try:
|
||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
||||
)
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: Request, response: Response):
|
||||
"""Logout current user by clearing the cookie."""
|
||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/change-password", response_model=MessageResponse)
|
||||
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
||||
"""Change password for the currently authenticated user.
|
||||
|
||||
Also handles the first-boot setup flow:
|
||||
- If new_email is provided, updates email (checks uniqueness)
|
||||
- If user.needs_setup is True and new_email is given, clears needs_setup
|
||||
- Always increments token_version to invalidate old sessions
|
||||
- Re-issues session cookie with new token_version
|
||||
"""
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||
|
||||
if not await verify_password_async(body.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
||||
|
||||
provider = get_local_provider()
|
||||
|
||||
# Update email if provided
|
||||
if body.new_email is not None:
|
||||
existing = await provider.get_user_by_email(body.new_email)
|
||||
if existing and str(existing.id) != str(user.id):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
||||
user.email = body.new_email
|
||||
|
||||
# Update password + bump version
|
||||
user.password_hash = await hash_password_async(body.new_password)
|
||||
user.token_version += 1
|
||||
|
||||
# Clear setup flag if this is the setup flow
|
||||
if user.needs_setup and body.new_email is not None:
|
||||
user.needs_setup = False
|
||||
|
||||
await provider.update_user(user)
|
||||
|
||||
# Re-issue cookie with new token_version
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return MessageResponse(message="Password changed successfully")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(request: Request):
|
||||
"""Get current authenticated user info."""
|
||||
user = await get_current_user_from_request(request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||
|
||||
|
||||
@router.get("/setup-status")
|
||||
async def setup_status():
|
||||
"""Check if admin account exists. Always False after first boot."""
|
||||
user_count = await get_local_provider().count_users()
|
||||
return {"needs_setup": user_count == 0}
|
||||
|
||||
|
||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}")
|
||||
async def oauth_login(provider: str):
|
||||
"""Initiate OAuth login flow.
|
||||
|
||||
Redirects to the OAuth provider's authorization URL.
|
||||
Currently a placeholder - requires OAuth provider implementation.
|
||||
"""
|
||||
if provider not in ["github", "google"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported OAuth provider: {provider}",
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth login not yet implemented",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/callback/{provider}")
|
||||
async def oauth_callback(provider: str, code: str, state: str):
|
||||
"""OAuth callback endpoint.
|
||||
|
||||
Handles the OAuth provider's callback after user authorization.
|
||||
Currently a placeholder.
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth callback not yet implemented",
|
||||
)
|
||||
@@ -49,6 +49,7 @@ class Fact(BaseModel):
|
||||
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
|
||||
createdAt: str = Field(default="", description="Creation timestamp")
|
||||
source: str = Field(default="unknown", description="Source thread ID")
|
||||
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
@@ -108,6 +109,7 @@ class MemoryStatusResponse(BaseModel):
|
||||
@router.get(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Data",
|
||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||
)
|
||||
@@ -152,6 +154,7 @@ async def get_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/reload",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Reload Memory Data",
|
||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||
)
|
||||
@@ -171,6 +174,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
@router.delete(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Clear All Memory Data",
|
||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||
)
|
||||
@@ -187,6 +191,7 @@ async def clear_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/facts",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Create Memory Fact",
|
||||
description="Create a single saved memory fact manually.",
|
||||
)
|
||||
@@ -209,6 +214,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
@router.delete(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Delete Memory Fact",
|
||||
description="Delete a single saved memory fact by its fact id.",
|
||||
)
|
||||
@@ -227,6 +233,7 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
@router.patch(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Patch Memory Fact",
|
||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||
)
|
||||
@@ -252,6 +259,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
@router.get(
|
||||
"/memory/export",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Export Memory Data",
|
||||
description="Export the current global memory data as JSON for backup or transfer.",
|
||||
)
|
||||
@@ -264,6 +272,7 @@ async def export_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/import",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Import Memory Data",
|
||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||
)
|
||||
@@ -317,6 +326,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
@router.get(
|
||||
"/memory/status",
|
||||
response_model=MemoryStatusResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Status",
|
||||
description="Retrieve both memory configuration and current data in a single request.",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.models import create_chat_model
|
||||
@@ -106,22 +107,21 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
|
||||
if not conversation:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
prompt = (
|
||||
system_instruction = (
|
||||
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||
"Requirements:\n"
|
||||
"- Questions must be relevant to the conversation.\n"
|
||||
"- Questions must be relevant to the preceding conversation.\n"
|
||||
"- Questions must be written in the same language as the user.\n"
|
||||
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||
"- Output MUST be a JSON array of strings only.\n\n"
|
||||
"Conversation:\n"
|
||||
f"{conversation}\n"
|
||||
"- Output MUST be a JSON array of strings only.\n"
|
||||
)
|
||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||
|
||||
try:
|
||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||
response = model.invoke(prompt)
|
||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||
raw = _extract_response_text(response.content)
|
||||
suggestions = _parse_json_string_list(raw) or []
|
||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||
|
||||
@@ -19,6 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_auth, require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
@@ -38,6 +39,7 @@ class RunCreateRequest(BaseModel):
|
||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
@@ -91,19 +93,28 @@ def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||
@require_auth
|
||||
@require_permission("runs", "create", owner_check=True)
|
||||
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
|
||||
"""Create a background run (returns immediately)."""
|
||||
"""Create a background run (returns immediately).
|
||||
|
||||
Multi-tenant isolation: only the thread owner can create runs.
|
||||
"""
|
||||
record = await start_run(body, thread_id, request)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/stream")
|
||||
@require_auth
|
||||
@require_permission("runs", "create", owner_check=True)
|
||||
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
"""Create a run and stream events via SSE.
|
||||
|
||||
The response includes a ``Content-Location`` header with the run's
|
||||
resource URL, matching the LangGraph Platform protocol. The
|
||||
``useStream`` React hook uses this to extract run metadata.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can stream runs.
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
@@ -124,8 +135,13 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/wait", response_model=dict)
|
||||
@require_auth
|
||||
@require_permission("runs", "create", owner_check=True)
|
||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until it completes, returning the final state."""
|
||||
"""Create a run and block until it completes, returning the final state.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can wait for runs.
|
||||
"""
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
if record.task is not None:
|
||||
@@ -149,16 +165,26 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||
@require_auth
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
||||
"""List all runs for a thread."""
|
||||
"""List all runs for a thread.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can list runs.
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
records = await run_mgr.list_by_thread(thread_id)
|
||||
return [_record_to_response(r) for r in records]
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||
@require_auth
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||
"""Get details of a specific run."""
|
||||
"""Get details of a specific run.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can get runs.
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
@@ -167,6 +193,8 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||
@require_auth
|
||||
@require_permission("runs", "cancel", owner_check=True)
|
||||
async def cancel_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
@@ -180,6 +208,8 @@ async def cancel_run(
|
||||
- action=rollback: Stop execution, revert to pre-run checkpoint state
|
||||
- wait=true: Block until the run fully stops, return 204
|
||||
- wait=false: Return immediately with 202
|
||||
|
||||
Multi-tenant isolation: only the thread owner can cancel runs.
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
@@ -204,8 +234,13 @@ async def cancel_run(
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||
@require_auth
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
||||
"""Join an existing run's SSE stream."""
|
||||
"""Join an existing run's SSE stream.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can join runs.
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
|
||||
@@ -13,17 +13,26 @@ matching the LangGraph Platform wire format expected by the
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter, HTTPException, Path, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_auth, require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_store
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread ID validation (prevents log-injection via control characters)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
ThreadId = Annotated[str, Path(description="Thread UUID", pattern=_UUID_RE.pattern)]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store namespace
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -65,6 +74,13 @@ class ThreadCreateRequest(BaseModel):
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
@field_validator("thread_id")
|
||||
@classmethod
|
||||
def _validate_uuid(cls, v: str | None) -> str | None:
|
||||
if v is not None and not _UUID_RE.match(v):
|
||||
raise ValueError("thread_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
|
||||
class ThreadSearchRequest(BaseModel):
|
||||
"""Request body for searching threads."""
|
||||
@@ -215,17 +231,23 @@ def _derive_thread_status(checkpoint_tuple) -> str:
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_thread_data(thread_id: ThreadId, request: Request) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread.
|
||||
|
||||
Cleans DeerFlow-managed thread directories, removes checkpoint data,
|
||||
and removes the thread record from the Store.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can delete their thread.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove from Store (best-effort)
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
try:
|
||||
await store.adelete(THREADS_NS, thread_id)
|
||||
@@ -233,7 +255,6 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
if hasattr(checkpointer, "adelete_thread"):
|
||||
@@ -251,12 +272,23 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
The thread record is written to the Store (for fast listing) and an
|
||||
empty checkpoint is written to the checkpointer (for state reads).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
|
||||
If authenticated, the user's ID is injected into the thread metadata
|
||||
for multi-tenant isolation.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
|
||||
thread_metadata = dict(body.metadata)
|
||||
if user:
|
||||
thread_metadata["user_id"] = str(user.id)
|
||||
|
||||
# Idempotency: return existing record from Store when already present
|
||||
if store is not None:
|
||||
existing_record = await _store_get(store, thread_id)
|
||||
@@ -279,7 +311,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
"status": "idle",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"metadata": body.metadata,
|
||||
"metadata": thread_metadata,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -296,7 +328,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
"source": "input",
|
||||
"writes": None,
|
||||
"parents": {},
|
||||
**body.metadata,
|
||||
**thread_metadata,
|
||||
"created_at": now,
|
||||
}
|
||||
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
|
||||
@@ -304,13 +336,13 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
logger.exception("Failed to create checkpoint for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
logger.info("Thread created: %s", thread_id)
|
||||
logger.info("Thread created: %s (user_id=%s)", thread_id, thread_metadata.get("user_id"))
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status="idle",
|
||||
created_at=str(now),
|
||||
updated_at=str(now),
|
||||
metadata=body.metadata,
|
||||
metadata=thread_metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -330,10 +362,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
newly found thread is immediately written to the Store so that the next
|
||||
search skips Phase 2 for that thread — the Store converges to a full
|
||||
index over time without a one-shot migration job.
|
||||
|
||||
If authenticated, only threads belonging to the current user are returned
|
||||
(enforced by user_id metadata filter for multi-tenant isolation).
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
user_id = str(user.id) if user else None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1: Store
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -409,6 +449,10 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
# -----------------------------------------------------------------------
|
||||
results = list(merged.values())
|
||||
|
||||
# Multi-tenant isolation: filter by user_id if authenticated
|
||||
if user_id:
|
||||
results = [r for r in results if r.metadata.get("user_id") == user_id]
|
||||
|
||||
if body.metadata:
|
||||
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
|
||||
|
||||
@@ -420,13 +464,20 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
|
||||
|
||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
@require_auth
|
||||
@require_permission("threads", "write", owner_check=True, inject_record=True)
|
||||
async def patch_thread(thread_id: ThreadId, request: Request, body: ThreadPatchRequest, thread_record: dict = None) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can patch their thread.
|
||||
"""
|
||||
store = get_store(request)
|
||||
if store is None:
|
||||
raise HTTPException(status_code=503, detail="Store not available")
|
||||
|
||||
record = await _store_get(store, thread_id)
|
||||
record = thread_record
|
||||
if record is None:
|
||||
record = await _store_get(store, thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
@@ -451,12 +502,17 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=ThreadResponse)
|
||||
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread(thread_id: ThreadId, request: Request) -> ThreadResponse:
|
||||
"""Get thread info.
|
||||
|
||||
Reads metadata from the Store and derives the accurate execution
|
||||
status from the checkpointer. Falls back to the checkpointer alone
|
||||
for threads that pre-date Store adoption (backward compat).
|
||||
|
||||
Multi-tenant isolation: returns 404 if the thread does not belong to
|
||||
the authenticated user.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
@@ -488,26 +544,33 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") # type: ignore[union-attr]
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=str(record.get("created_at", "")), # type: ignore[union-attr]
|
||||
updated_at=str(record.get("updated_at", "")), # type: ignore[union-attr]
|
||||
metadata=record.get("metadata", {}), # type: ignore[union-attr]
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_state(thread_id: ThreadId, request: Request) -> ThreadStateResponse:
|
||||
"""Get the latest state snapshot for a thread.
|
||||
|
||||
Channel values are serialized to ensure LangChain message objects
|
||||
are converted to JSON-safe dicts.
|
||||
|
||||
Multi-tenant isolation: returns 404 if thread does not belong to user.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
@@ -552,12 +615,16 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
|
||||
|
||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "write", owner_check=True)
|
||||
async def update_thread_state(thread_id: ThreadId, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
||||
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
|
||||
|
||||
Writes a new checkpoint that merges *body.values* into the latest
|
||||
channel values, then syncs any updated ``title`` field back to the Store
|
||||
so that ``/threads/search`` reflects the change immediately.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can update their thread.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
@@ -635,8 +702,13 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
|
||||
|
||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread."""
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_history(thread_id: ThreadId, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread.
|
||||
|
||||
Multi-tenant isolation: returns 404 if thread does not belong to user.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
@@ -116,6 +116,7 @@ def build_run_config(
|
||||
metadata: dict[str, Any] | None,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a RunnableConfig dict for the agent.
|
||||
|
||||
@@ -128,27 +129,47 @@ def build_run_config(
|
||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||
identically.
|
||||
|
||||
If *user_id* is provided, it is injected into the config metadata for
|
||||
multi-tenant isolation.
|
||||
"""
|
||||
configurable: dict[str, Any] = {"thread_id": thread_id}
|
||||
config: dict[str, Any] = {"recursion_limit": 100}
|
||||
if request_config:
|
||||
configurable.update(request_config.get("configurable", {}))
|
||||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
||||
# pass thread-level data and rejects requests that include both
|
||||
# ``configurable`` and ``context``. If the caller already sends
|
||||
# ``context``, honour it and skip our own ``configurable`` dict.
|
||||
if "context" in request_config:
|
||||
if "configurable" in request_config:
|
||||
logger.warning(
|
||||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
||||
thread_id,
|
||||
list(request_config.get("configurable", {}).keys()),
|
||||
)
|
||||
config["context"] = request_config["context"]
|
||||
else:
|
||||
configurable = {"thread_id": thread_id}
|
||||
configurable.update(request_config.get("configurable", {}))
|
||||
config["configurable"] = configurable
|
||||
for k, v in request_config.items():
|
||||
if k not in ("configurable", "context"):
|
||||
config[k] = v
|
||||
else:
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
|
||||
# Inject custom agent name when the caller specified a non-default assistant.
|
||||
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "agent_name" not in configurable:
|
||||
# Normalize the same way ChannelManager does: strip, lowercase,
|
||||
# replace underscores with hyphens, then validate to prevent path
|
||||
# traversal and invalid agent directory lookups.
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
configurable["agent_name"] = normalized
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
||||
if "agent_name" not in config["configurable"]:
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
config["configurable"]["agent_name"] = normalized
|
||||
|
||||
# Multi-tenant isolation: inject user_id into metadata
|
||||
if user_id:
|
||||
config.setdefault("metadata", {})["user_id"] = user_id
|
||||
|
||||
config: dict[str, Any] = {"configurable": configurable, "recursion_limit": 100}
|
||||
if request_config:
|
||||
for k, v in request_config.items():
|
||||
if k != "configurable":
|
||||
config[k] = v
|
||||
if metadata:
|
||||
config.setdefault("metadata", {}).update(metadata)
|
||||
return config
|
||||
@@ -248,6 +269,10 @@ async def start_run(
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
# Reuse auth context set by @require_auth decorator to avoid redundant DB lookup
|
||||
auth = getattr(request.state, "auth", None)
|
||||
user_id = str(auth.user.id) if auth and auth.user else None
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
@@ -270,7 +295,34 @@ async def start_run(
|
||||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
config = build_run_config(
|
||||
thread_id,
|
||||
body.config,
|
||||
body.metadata,
|
||||
assistant_id=body.assistant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Merge DeerFlow-specific context overrides into configurable.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
context = getattr(body, "context", None)
|
||||
if context:
|
||||
_CONTEXT_CONFIGURABLE_KEYS = {
|
||||
"model_name",
|
||||
"mode",
|
||||
"thinking_enabled",
|
||||
"reasoning_effort",
|
||||
"is_plan_mode",
|
||||
"subagent_enabled",
|
||||
"max_concurrent_subagents",
|
||||
}
|
||||
configurable = config.setdefault("configurable", {})
|
||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||
if key in context:
|
||||
configurable.setdefault(key, context[key])
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,129 @@
|
||||
# Authentication Upgrade Guide
|
||||
|
||||
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
|
||||
|
||||
## 核心概念
|
||||
|
||||
认证模块采用**始终强制**策略:
|
||||
|
||||
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
|
||||
- 认证从一开始就是强制的,无竞争窗口
|
||||
- 历史对话(升级前创建的 thread)自动迁移到 admin 名下
|
||||
|
||||
## 升级步骤
|
||||
|
||||
### 1. 更新代码
|
||||
|
||||
```bash
|
||||
git pull origin main
|
||||
cd backend && make install
|
||||
```
|
||||
|
||||
### 2. 首次启动
|
||||
|
||||
```bash
|
||||
make dev
|
||||
```
|
||||
|
||||
控制台会输出:
|
||||
|
||||
```
|
||||
============================================================
|
||||
Admin account created on first boot
|
||||
Email: admin@deerflow.dev
|
||||
Password: aB3xK9mN_pQ7rT2w
|
||||
Change it after login: Settings → Account
|
||||
============================================================
|
||||
```
|
||||
|
||||
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。
|
||||
|
||||
### 3. 登录
|
||||
|
||||
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。
|
||||
|
||||
### 4. 修改密码
|
||||
|
||||
登录后进入 Settings → Account → Change Password。
|
||||
|
||||
### 5. 添加用户(可选)
|
||||
|
||||
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
|
||||
|
||||
## 安全机制
|
||||
|
||||
| 机制 | 说明 |
|
||||
|------|------|
|
||||
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
|
||||
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
|
||||
| bcrypt 密码哈希 | 密码不以明文存储 |
|
||||
| 多租户隔离 | 用户只能访问自己的 thread |
|
||||
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
|
||||
|
||||
## 常见操作
|
||||
|
||||
### 忘记密码
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# 重置 admin 密码
|
||||
python -m app.gateway.auth.reset_admin
|
||||
|
||||
# 重置指定用户密码
|
||||
python -m app.gateway.auth.reset_admin --email user@example.com
|
||||
```
|
||||
|
||||
会输出新的随机密码。
|
||||
|
||||
### 完全重置
|
||||
|
||||
删除用户数据库,重启后自动创建新 admin:
|
||||
|
||||
```bash
|
||||
rm -f backend/.deer-flow/users.db
|
||||
# 重启服务,控制台输出新密码
|
||||
```
|
||||
|
||||
## 数据存储
|
||||
|
||||
| 文件 | 内容 |
|
||||
|------|------|
|
||||
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) |
|
||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
|
||||
|
||||
### 生产环境建议
|
||||
|
||||
```bash
|
||||
# 生成持久化 JWT 密钥,避免重启后所有用户需重新登录
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
# 将输出添加到 .env:
|
||||
# AUTH_JWT_SECRET=<生成的密钥>
|
||||
```
|
||||
|
||||
## API 端点
|
||||
|
||||
| 端点 | 方法 | 说明 |
|
||||
|------|------|------|
|
||||
| `/api/v1/auth/login/local` | POST | 邮箱密码登录(OAuth2 form) |
|
||||
| `/api/v1/auth/register` | POST | 注册新用户(user 角色) |
|
||||
| `/api/v1/auth/logout` | POST | 登出(清除 cookie) |
|
||||
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
|
||||
| `/api/v1/auth/change-password` | POST | 修改密码 |
|
||||
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
|
||||
|
||||
## 兼容性
|
||||
|
||||
- **标准模式**(`make dev`):完全兼容,admin 自动创建
|
||||
- **Gateway 模式**(`make dev-pro`):完全兼容
|
||||
- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载
|
||||
- **IM 渠道**(Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层
|
||||
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
|
||||
|
||||
## 故障排查
|
||||
|
||||
| 症状 | 原因 | 解决 |
|
||||
|------|------|------|
|
||||
| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` |
|
||||
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
||||
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
|
||||
@@ -248,7 +248,7 @@ def after_agent(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | N
|
||||
- [`packages/harness/deerflow/agents/thread_state.py`](../packages/harness/deerflow/agents/thread_state.py) - ThreadState 定义
|
||||
- [`packages/harness/deerflow/agents/middlewares/title_middleware.py`](../packages/harness/deerflow/agents/middlewares/title_middleware.py) - TitleMiddleware 实现
|
||||
- [`packages/harness/deerflow/config/title_config.py`](../packages/harness/deerflow/config/title_config.py) - 配置管理
|
||||
- [`config.yaml`](../config.yaml) - 配置文件
|
||||
- [`config.yaml`](../../config.example.yaml) - 配置文件
|
||||
- [`packages/harness/deerflow/agents/lead_agent/agent.py`](../packages/harness/deerflow/agents/lead_agent/agent.py) - Middleware 注册
|
||||
|
||||
## 参考资料
|
||||
|
||||
@@ -278,6 +278,12 @@ skills:
|
||||
- Skills are automatically discovered and loaded
|
||||
- Available in both local and Docker sandbox via path mapping
|
||||
|
||||
**Per-Agent Skill Filtering**:
|
||||
Custom agents can restrict which skills they load by defining a `skills` field in their `config.yaml` (located at `workspace/agents/<agent_name>/config.yaml`):
|
||||
- **Omitted or `null`**: Loads all globally enabled skills (default fallback).
|
||||
- **`[]` (empty list)**: Disables all skills for this specific agent.
|
||||
- **`["skill-name"]`**: Loads only the explicitly specified skills.
|
||||
|
||||
### Title Generation
|
||||
|
||||
Automatic conversation title generation:
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
|
||||
### 2. 配置文件
|
||||
|
||||
#### [`config.yaml`](../config.yaml)
|
||||
#### [`config.yaml`](../../config.example.yaml)
|
||||
- ✅ 添加 title 配置段:
|
||||
```yaml
|
||||
title:
|
||||
@@ -51,7 +51,7 @@ title:
|
||||
- ✅ 故障排查指南
|
||||
- ✅ State vs Metadata 对比
|
||||
|
||||
#### [`BACKEND_TODO.md`](../BACKEND_TODO.md)
|
||||
#### [`TODO.md`](TODO.md)
|
||||
- ✅ 添加功能完成记录
|
||||
|
||||
### 4. 测试
|
||||
|
||||
@@ -0,0 +1,446 @@
|
||||
# [RFC] 在 DeerFlow 中增加 `grep` 与 `glob` 文件搜索工具
|
||||
|
||||
## Summary
|
||||
|
||||
我认为这个方向是对的,而且值得做。
|
||||
|
||||
如果 DeerFlow 想更接近 Claude Code 这类 coding agent 的实际工作流,仅有 `ls` / `read_file` / `write_file` / `str_replace` 还不够。模型在进入修改前,通常还需要两类能力:
|
||||
|
||||
- `glob`: 快速按路径模式找文件
|
||||
- `grep`: 快速按内容模式找候选位置
|
||||
|
||||
这两类工具的价值,不是“功能上 bash 也能做”,而是它们能以更低 token 成本、更强约束、更稳定的输出格式,替代模型频繁走 `bash find` / `bash grep` / `rg` 的习惯。
|
||||
|
||||
但前提是实现方式要对:**它们应该是只读、结构化、受限、可审计的原生工具,而不是对 shell 命令的简单包装。**
|
||||
|
||||
## Problem
|
||||
|
||||
当前 DeerFlow 的文件工具层主要覆盖:
|
||||
|
||||
- `ls`: 浏览目录结构
|
||||
- `read_file`: 读取文件内容
|
||||
- `write_file`: 写文件
|
||||
- `str_replace`: 做局部字符串替换
|
||||
- `bash`: 兜底执行命令
|
||||
|
||||
这套能力能完成任务,但在代码库探索阶段效率不高。
|
||||
|
||||
典型问题:
|
||||
|
||||
1. 模型想找 “所有 `*.tsx` 的 page 文件” 时,只能反复 `ls` 多层目录,或者退回 `bash find`
|
||||
2. 模型想找 “某个 symbol / 文案 / 配置键在哪里出现” 时,只能逐文件 `read_file`,或者退回 `bash grep` / `rg`
|
||||
3. 一旦退回 `bash`,工具调用就失去结构化输出,结果也更难做裁剪、分页、审计和跨 sandbox 一致化
|
||||
4. 对没有开启 host bash 的本地模式,`bash` 甚至可能不可用,此时缺少足够强的只读检索能力
|
||||
|
||||
结论:DeerFlow 现在缺的不是“再多一个 shell 命令”,而是**文件系统检索层**。
|
||||
|
||||
## Goals
|
||||
|
||||
- 为 agent 提供稳定的路径搜索和内容搜索能力
|
||||
- 减少对 `bash` 的依赖,特别是在仓库探索阶段
|
||||
- 保持与现有 sandbox 安全模型一致
|
||||
- 输出格式结构化,便于模型后续串联 `read_file` / `str_replace`
|
||||
- 让本地 sandbox、容器 sandbox、未来 MCP 文件系统工具都能遵守同一语义
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- 不做通用 shell 兼容层
|
||||
- 不暴露完整 grep/find/rg CLI 语法
|
||||
- 不在第一版支持二进制检索、复杂 PCRE 特性、上下文窗口高亮渲染等重功能
|
||||
- 不把它做成“任意磁盘搜索”,仍然只允许在 DeerFlow 已授权的路径内执行
|
||||
|
||||
## Why This Is Worth Doing
|
||||
|
||||
参考 Claude Code 这一类 agent 的设计思路,`glob` 和 `grep` 的核心价值不是新能力本身,而是把“探索代码库”的常见动作从开放式 shell 降到受控工具层。
|
||||
|
||||
这样有几个直接收益:
|
||||
|
||||
1. **更低的模型负担**
|
||||
模型不需要自己拼 `find`, `grep`, `rg`, `xargs`, quoting 等命令细节。
|
||||
|
||||
2. **更稳定的跨环境行为**
|
||||
本地、Docker、AIO sandbox 不必依赖容器里是否装了 `rg`,也不会因为 shell 差异导致行为漂移。
|
||||
|
||||
3. **更强的安全与审计**
|
||||
调用参数就是“搜索什么、在哪搜、最多返回多少”,天然比任意命令更容易审计和限流。
|
||||
|
||||
4. **更好的 token 效率**
|
||||
`grep` 返回的是命中摘要而不是整段文件,模型只对少数候选路径再调用 `read_file`。
|
||||
|
||||
5. **对 `tool_search` 友好**
|
||||
当 DeerFlow 持续扩展工具集时,`grep` / `glob` 会成为非常高频的基础工具,值得保留为 built-in,而不是让模型总是退回通用 bash。
|
||||
|
||||
## Proposal
|
||||
|
||||
增加两个 built-in sandbox tools:
|
||||
|
||||
- `glob`
|
||||
- `grep`
|
||||
|
||||
推荐继续放在:
|
||||
|
||||
- `backend/packages/harness/deerflow/sandbox/tools.py`
|
||||
|
||||
并在 `config.example.yaml` 中默认加入 `file:read` 组。
|
||||
|
||||
### 1. `glob` 工具
|
||||
|
||||
用途:按路径模式查找文件或目录。
|
||||
|
||||
建议 schema:
|
||||
|
||||
```python
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
include_dirs: bool = False,
|
||||
max_results: int = 200,
|
||||
) -> str:
|
||||
...
|
||||
```
|
||||
|
||||
参数语义:
|
||||
|
||||
- `description`: 与现有工具保持一致
|
||||
- `pattern`: glob 模式,例如 `**/*.py`、`src/**/test_*.ts`
|
||||
- `path`: 搜索根目录,必须是绝对路径
|
||||
- `include_dirs`: 是否返回目录
|
||||
- `max_results`: 最大返回条数,防止一次性打爆上下文
|
||||
|
||||
建议返回格式:
|
||||
|
||||
```text
|
||||
Found 3 paths under /mnt/user-data/workspace
|
||||
1. /mnt/user-data/workspace/backend/app.py
|
||||
2. /mnt/user-data/workspace/backend/tests/test_app.py
|
||||
3. /mnt/user-data/workspace/scripts/build.py
|
||||
```
|
||||
|
||||
如果后续想更适合前端消费,也可以改成 JSON 字符串;但第一版为了兼容现有工具风格,返回可读文本即可。
|
||||
|
||||
### 2. `grep` 工具
|
||||
|
||||
用途:按内容模式搜索文件,返回命中位置摘要。
|
||||
|
||||
建议 schema:
|
||||
|
||||
```python
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> str:
|
||||
...
|
||||
```
|
||||
|
||||
参数语义:
|
||||
|
||||
- `pattern`: 搜索词或正则
|
||||
- `path`: 搜索根目录,必须是绝对路径
|
||||
- `glob`: 可选路径过滤,例如 `**/*.py`
|
||||
- `literal`: 为 `True` 时按普通字符串匹配,不解释为正则
|
||||
- `case_sensitive`: 是否大小写敏感
|
||||
- `max_results`: 最大返回命中数,不是文件数
|
||||
|
||||
建议返回格式:
|
||||
|
||||
```text
|
||||
Found 4 matches under /mnt/user-data/workspace
|
||||
/mnt/user-data/workspace/backend/config.py:12: TOOL_GROUPS = [...]
|
||||
/mnt/user-data/workspace/backend/config.py:48: def load_tool_config(...):
|
||||
/mnt/user-data/workspace/backend/tools.py:91: "tool_groups"
|
||||
/mnt/user-data/workspace/backend/tests/test_config.py:22: assert "tool_groups" in data
|
||||
```
|
||||
|
||||
第一版建议只返回:
|
||||
|
||||
- 文件路径
|
||||
- 行号
|
||||
- 命中行摘要
|
||||
|
||||
不返回上下文块,避免结果过大。模型如果需要上下文,再调用 `read_file(path, start_line, end_line)`。
|
||||
|
||||
## Design Principles
|
||||
|
||||
### A. 不做 shell wrapper
|
||||
|
||||
不建议把 `grep` 实现为:
|
||||
|
||||
```python
|
||||
subprocess.run("grep ...")
|
||||
```
|
||||
|
||||
也不建议在容器里直接拼 `find` / `rg` 命令。
|
||||
|
||||
原因:
|
||||
|
||||
- 会引入 shell quoting 和注入面
|
||||
- 会依赖不同 sandbox 内镜像是否安装同一套命令
|
||||
- Windows / macOS / Linux 行为不一致
|
||||
- 很难稳定控制输出条数与格式
|
||||
|
||||
正确方向是:
|
||||
|
||||
- `glob` 使用 Python 标准库路径遍历
|
||||
- `grep` 使用 Python 逐文件扫描
|
||||
- 输出由 DeerFlow 自己格式化
|
||||
|
||||
如果未来为了性能考虑要优先调用 `rg`,也应该封装在 provider 内部,并保证外部语义不变,而不是把 CLI 暴露给模型。
|
||||
|
||||
### B. 继续沿用 DeerFlow 的路径权限模型
|
||||
|
||||
这两个工具必须复用当前 `ls` / `read_file` 的路径校验逻辑:
|
||||
|
||||
- 本地模式走 `validate_local_tool_path(..., read_only=True)`
|
||||
- 支持 `/mnt/skills/...`
|
||||
- 支持 `/mnt/acp-workspace/...`
|
||||
- 支持 thread workspace / uploads / outputs 的虚拟路径解析
|
||||
- 明确拒绝越权路径与 path traversal
|
||||
|
||||
也就是说,它们属于 **file:read**,不是 `bash` 的替代越权入口。
|
||||
|
||||
### C. 结果必须硬限制
|
||||
|
||||
没有硬限制的 `glob` / `grep` 很容易炸上下文。
|
||||
|
||||
建议第一版至少限制:
|
||||
|
||||
- `glob.max_results` 默认 200,最大 1000
|
||||
- `grep.max_results` 默认 100,最大 500
|
||||
- 单行摘要最大长度,例如 200 字符
|
||||
- 二进制文件跳过
|
||||
- 超大文件跳过,例如单文件大于 1 MB 或按配置控制
|
||||
|
||||
此外,命中数超过阈值时应返回:
|
||||
|
||||
- 已展示的条数
|
||||
- 被截断的事实
|
||||
- 建议用户缩小搜索范围
|
||||
|
||||
例如:
|
||||
|
||||
```text
|
||||
Found more than 100 matches, showing first 100. Narrow the path or add a glob filter.
|
||||
```
|
||||
|
||||
### D. 工具语义要彼此互补
|
||||
|
||||
推荐模型工作流应该是:
|
||||
|
||||
1. `glob` 找候选文件
|
||||
2. `grep` 找候选位置
|
||||
3. `read_file` 读局部上下文
|
||||
4. `str_replace` / `write_file` 执行修改
|
||||
|
||||
这样工具边界清晰,也更利于 prompt 中教模型形成稳定习惯。
|
||||
|
||||
## Implementation Approach
|
||||
|
||||
## Option A: 直接在 `sandbox/tools.py` 实现第一版
|
||||
|
||||
这是我推荐的起步方案。
|
||||
|
||||
做法:
|
||||
|
||||
- 在 `sandbox/tools.py` 新增 `glob_tool` 与 `grep_tool`
|
||||
- 在 local sandbox 场景直接使用 Python 文件系统 API
|
||||
- 在非 local sandbox 场景,优先也通过 DeerFlow 自己控制的路径访问层实现
|
||||
|
||||
优点:
|
||||
|
||||
- 改动小
|
||||
- 能尽快验证 agent 效果
|
||||
- 不需要先改 `Sandbox` 抽象
|
||||
|
||||
缺点:
|
||||
|
||||
- `tools.py` 会继续变胖
|
||||
- 如果未来想在 provider 侧做性能优化,需要再抽象一次
|
||||
|
||||
## Option B: 先扩展 `Sandbox` 抽象
|
||||
|
||||
例如新增:
|
||||
|
||||
```python
|
||||
class Sandbox(ABC):
|
||||
def glob(self, path: str, pattern: str, include_dirs: bool = False, max_results: int = 200) -> list[str]:
|
||||
...
|
||||
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> list[GrepMatch]:
|
||||
...
|
||||
```
|
||||
|
||||
优点:
|
||||
|
||||
- 抽象更干净
|
||||
- 容器 / 远程 sandbox 可以各自优化
|
||||
|
||||
缺点:
|
||||
|
||||
- 首次引入成本更高
|
||||
- 需要同步改所有 sandbox provider
|
||||
|
||||
结论:
|
||||
|
||||
**第一版建议走 Option A,等工具价值验证后再下沉到 `Sandbox` 抽象层。**
|
||||
|
||||
## Detailed Behavior
|
||||
|
||||
### `glob` 行为
|
||||
|
||||
- 输入根目录不存在:返回清晰错误
|
||||
- 根路径不是目录:返回清晰错误
|
||||
- 模式非法:返回清晰错误
|
||||
- 结果为空:返回 `No files matched`
|
||||
- 默认忽略项应尽量与当前 `list_dir` 对齐,例如:
|
||||
- `.git`
|
||||
- `node_modules`
|
||||
- `__pycache__`
|
||||
- `.venv`
|
||||
- 构建产物目录
|
||||
|
||||
这里建议抽一个共享 ignore 集,避免 `ls` 与 `glob` 结果风格不一致。
|
||||
|
||||
### `grep` 行为
|
||||
|
||||
- 默认只扫描文本文件
|
||||
- 检测到二进制文件直接跳过
|
||||
- 对超大文件直接跳过或只扫前 N KB
|
||||
- regex 编译失败时返回参数错误
|
||||
- 输出中的路径继续使用虚拟路径,而不是暴露宿主真实路径
|
||||
- 建议默认按文件路径、行号排序,保持稳定输出
|
||||
|
||||
## Prompting Guidance
|
||||
|
||||
如果引入这两个工具,建议同步更新系统提示中的文件操作建议:
|
||||
|
||||
- 查找文件名模式时优先用 `glob`
|
||||
- 查找代码符号、配置项、文案时优先用 `grep`
|
||||
- 只有在工具不足以完成目标时才退回 `bash`
|
||||
|
||||
否则模型仍会习惯性先调用 `bash`。
|
||||
|
||||
## Risks
|
||||
|
||||
### 1. 与 `bash` 能力重叠
|
||||
|
||||
这是事实,但不是问题。
|
||||
|
||||
`ls` 和 `read_file` 也都能被 `bash` 替代,但我们仍然保留它们,因为结构化工具更适合 agent。
|
||||
|
||||
### 2. 性能问题
|
||||
|
||||
在大仓库上,纯 Python `grep` 可能比 `rg` 慢。
|
||||
|
||||
缓解方式:
|
||||
|
||||
- 第一版先加结果上限和文件大小上限
|
||||
- 路径上强制要求 root path
|
||||
- 提供 `glob` 过滤缩小扫描范围
|
||||
- 后续如有必要,在 provider 内部做 `rg` 优化,但保持同一 schema
|
||||
|
||||
### 3. 忽略规则不一致
|
||||
|
||||
如果 `ls` 能看到的路径,`glob` 却看不到,模型会困惑。
|
||||
|
||||
缓解方式:
|
||||
|
||||
- 统一 ignore 规则
|
||||
- 在文档里明确“默认跳过常见依赖和构建目录”
|
||||
|
||||
### 4. 正则搜索过于复杂
|
||||
|
||||
如果第一版就支持大量 grep 方言,边界会很乱。
|
||||
|
||||
缓解方式:
|
||||
|
||||
- 第一版只支持 Python `re`
|
||||
- 并提供 `literal=True` 的简单模式
|
||||
|
||||
## Alternatives Considered
|
||||
|
||||
### A. 不增加工具,完全依赖 `bash`
|
||||
|
||||
不推荐。
|
||||
|
||||
这会让 DeerFlow 在代码探索体验上持续落后,也削弱无 bash 或受限 bash 场景下的能力。
|
||||
|
||||
### B. 只加 `glob`,不加 `grep`
|
||||
|
||||
不推荐。
|
||||
|
||||
只解决“找文件”,没有解决“找位置”。模型最终还是会退回 `bash grep`。
|
||||
|
||||
### C. 只加 `grep`,不加 `glob`
|
||||
|
||||
也不推荐。
|
||||
|
||||
`grep` 缺少路径模式过滤时,扫描范围经常太大;`glob` 是它的天然前置工具。
|
||||
|
||||
### D. 直接接入 MCP filesystem server 的搜索能力
|
||||
|
||||
短期不推荐作为主路径。
|
||||
|
||||
MCP 可以是补充,但 `glob` / `grep` 作为 DeerFlow 的基础 coding tool,最好仍然是 built-in,这样才能在默认安装中稳定可用。
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- `config.example.yaml` 中可默认启用 `glob` 与 `grep`
|
||||
- 两个工具归属 `file:read` 组
|
||||
- 本地 sandbox 下严格遵守现有路径权限
|
||||
- 输出不泄露宿主机真实路径
|
||||
- 大结果集会被截断并明确提示
|
||||
- 模型可以通过 `glob -> grep -> read_file -> str_replace` 完成典型改码流
|
||||
- 在禁用 host bash 的本地模式下,仓库探索能力明显提升
|
||||
|
||||
## Rollout Plan
|
||||
|
||||
1. 在 `sandbox/tools.py` 中实现 `glob_tool` 与 `grep_tool`
|
||||
2. 抽取与 `list_dir` 一致的 ignore 规则,避免行为漂移
|
||||
3. 在 `config.example.yaml` 默认加入工具配置
|
||||
4. 为本地路径校验、虚拟路径映射、结果截断、二进制跳过补测试
|
||||
5. 更新 README / backend docs / prompt guidance
|
||||
6. 收集实际 agent 调用数据,再决定是否下沉到 `Sandbox` 抽象
|
||||
|
||||
## Suggested Config
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
- name: glob
|
||||
group: file:read
|
||||
use: deerflow.sandbox.tools:glob_tool
|
||||
|
||||
- name: grep
|
||||
group: file:read
|
||||
use: deerflow.sandbox.tools:grep_tool
|
||||
```
|
||||
|
||||
## Final Recommendation
|
||||
|
||||
结论是:**可以加,而且应该加。**
|
||||
|
||||
但我会明确卡三个边界:
|
||||
|
||||
1. `grep` / `glob` 必须是 built-in 的只读结构化工具
|
||||
2. 第一版不要做 shell wrapper,不要把 CLI 方言直接暴露给模型
|
||||
3. 先在 `sandbox/tools.py` 验证价值,再考虑是否下沉到 `Sandbox` provider 抽象
|
||||
|
||||
如果按这个方向做,它会明显提升 DeerFlow 在 coding / repo exploration 场景下的可用性,而且风险可控。
|
||||
@@ -8,6 +8,9 @@
|
||||
"graphs": {
|
||||
"lead_agent": "deerflow.agents:make_lead_agent"
|
||||
},
|
||||
"auth": {
|
||||
"path": "./app/gateway/langgraph_auth.py:auth"
|
||||
},
|
||||
"checkpointer": {
|
||||
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
|
||||
}
|
||||
|
||||
@@ -343,6 +343,8 @@ def make_lead_agent(config: RunnableConfig):
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
try:
|
||||
return list(load_skills(enabled_only=True))
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
return []
|
||||
|
||||
|
||||
def _build_subagent_section(max_concurrent: int) -> str:
|
||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||
|
||||
@@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
Returns the <skill_system>...</skill_system> block listing all enabled skills,
|
||||
suitable for injection into any agent's system prompt.
|
||||
"""
|
||||
skills = load_skills(enabled_only=True)
|
||||
skills = _get_enabled_skills()
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
@@ -402,6 +410,10 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
if available_skills is not None:
|
||||
skills = [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
# Check again after filtering
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
skill_items = "\n".join(
|
||||
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
|
||||
)
|
||||
@@ -446,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:
|
||||
|
||||
if not get_app_config().tool_search.enabled:
|
||||
return ""
|
||||
except FileNotFoundError:
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
registry = get_deferred_registry()
|
||||
|
||||
@@ -29,6 +29,17 @@ Instructions:
|
||||
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
||||
3. Update the memory sections as needed following the detailed length guidelines below
|
||||
|
||||
Before extracting facts, perform a structured reflection on the conversation:
|
||||
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
|
||||
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
|
||||
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
|
||||
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
|
||||
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
|
||||
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
|
||||
If yes, record them as facts with the most appropriate category and confidence.
|
||||
|
||||
{correction_hint}
|
||||
|
||||
Memory Section Guidelines:
|
||||
|
||||
**User Context** (Current state - concise summaries):
|
||||
@@ -62,6 +73,7 @@ Memory Section Guidelines:
|
||||
* context: Background facts (job title, projects, locations, languages)
|
||||
* behavior: Working patterns, communication habits, problem-solving approaches
|
||||
* goal: Stated objectives, learning targets, project ambitions
|
||||
* correction: Explicit agent mistakes or user corrections, including the correct approach
|
||||
- Confidence levels:
|
||||
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
||||
* 0.7-0.8: Strongly implied from actions/discussions
|
||||
@@ -94,7 +106,7 @@ Output Format (JSON):
|
||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"newFacts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
],
|
||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||
}}
|
||||
@@ -104,6 +116,8 @@ Important Rules:
|
||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||
- Include specific metrics, version numbers, and proper nouns in facts
|
||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
|
||||
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
|
||||
- Remove facts that are contradicted by new information
|
||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||
@@ -126,7 +140,7 @@ Message:
|
||||
Extract facts in this JSON format:
|
||||
{{
|
||||
"facts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
]
|
||||
}}
|
||||
|
||||
@@ -136,6 +150,7 @@ Categories:
|
||||
- context: Background context (location, job, projects)
|
||||
- behavior: Behavioral patterns
|
||||
- goal: User's goals or objectives
|
||||
- correction: Explicit corrections or mistakes to avoid repeating
|
||||
|
||||
Rules:
|
||||
- Only extract clear, specific facts
|
||||
@@ -231,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
if earlier.get("summary"):
|
||||
history_sections.append(f"Earlier: {earlier['summary']}")
|
||||
|
||||
background = history_data.get("longTermBackground", {})
|
||||
if background.get("summary"):
|
||||
history_sections.append(f"Background: {background['summary']}")
|
||||
|
||||
if history_sections:
|
||||
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
||||
|
||||
@@ -262,7 +281,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
continue
|
||||
category = str(fact.get("category", "context")).strip() or "context"
|
||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
source_error = fact.get("sourceError")
|
||||
if category == "correction" and isinstance(source_error, str) and source_error.strip():
|
||||
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
|
||||
else:
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
|
||||
@@ -20,6 +20,8 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
@@ -37,25 +39,42 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
@@ -115,6 +134,8 @@ class MemoryUpdateQueue:
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
return stripped
|
||||
return stripped.casefold()
|
||||
|
||||
|
||||
class MemoryUpdater:
|
||||
@@ -266,13 +266,22 @@ class MemoryUpdater:
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory(
|
||||
self,
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -295,9 +304,27 @@ class MemoryUpdater:
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
correction_hint=correction_hint,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
@@ -383,6 +410,8 @@ class MemoryUpdater:
|
||||
confidence = fact.get("confidence", 0.5)
|
||||
if confidence >= config.fact_confidence_threshold:
|
||||
raw_content = fact.get("content", "")
|
||||
if not isinstance(raw_content, str):
|
||||
continue
|
||||
normalized_content = raw_content.strip()
|
||||
fact_key = _fact_content_key(normalized_content)
|
||||
if fact_key is not None and fact_key in existing_fact_keys:
|
||||
@@ -396,6 +425,11 @@ class MemoryUpdater:
|
||||
"createdAt": now,
|
||||
"source": thread_id or "unknown",
|
||||
}
|
||||
source_error = fact.get("sourceError")
|
||||
if isinstance(source_error, str):
|
||||
normalized_source_error = source_error.strip()
|
||||
if normalized_source_error:
|
||||
fact_entry["sourceError"] = normalized_source_error
|
||||
current_memory["facts"].append(fact_entry)
|
||||
if fact_key is not None:
|
||||
existing_fact_keys.add(fact_key)
|
||||
@@ -412,16 +446,24 @@ class MemoryUpdater:
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory_from_conversation(
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
|
||||
+275
@@ -0,0 +1,275 @@
|
||||
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
_BUSY_PATTERNS = (
|
||||
"server busy",
|
||||
"temporarily unavailable",
|
||||
"try again later",
|
||||
"please retry",
|
||||
"please try again",
|
||||
"overloaded",
|
||||
"high demand",
|
||||
"rate limit",
|
||||
"负载较高",
|
||||
"服务繁忙",
|
||||
"稍后重试",
|
||||
"请稍后重试",
|
||||
)
|
||||
_QUOTA_PATTERNS = (
|
||||
"insufficient_quota",
|
||||
"quota",
|
||||
"billing",
|
||||
"credit",
|
||||
"payment",
|
||||
"余额不足",
|
||||
"超出限额",
|
||||
"额度不足",
|
||||
"欠费",
|
||||
)
|
||||
_AUTH_PATTERNS = (
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"无权",
|
||||
"未授权",
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||
|
||||
retry_max_attempts: int = 3
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
error_code = _extract_error_code(exc)
|
||||
status_code = _extract_status_code(exc)
|
||||
|
||||
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
|
||||
return False, "quota"
|
||||
if _matches_any(lowered, _AUTH_PATTERNS):
|
||||
return False, "auth"
|
||||
|
||||
exc_name = exc.__class__.__name__
|
||||
if exc_name in {
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"InternalServerError",
|
||||
}:
|
||||
return True, "transient"
|
||||
if status_code in _RETRIABLE_STATUS_CODES:
|
||||
return True, "transient"
|
||||
if _matches_any(lowered, _BUSY_PATTERNS):
|
||||
return True, "busy"
|
||||
|
||||
return False, "generic"
|
||||
|
||||
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
|
||||
retry_after = _extract_retry_after_ms(exc)
|
||||
if retry_after is not None:
|
||||
return retry_after
|
||||
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
|
||||
return min(backoff, self.retry_cap_delay_ms)
|
||||
|
||||
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
|
||||
seconds = max(1, round(wait_ms / 1000))
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
|
||||
if reason == "auth":
|
||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||
if reason in {"busy", "transient"}:
|
||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||
return f"LLM request failed: {detail}"
|
||||
|
||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
writer(
|
||||
{
|
||||
"type": "llm_retry",
|
||||
"attempt": attempt,
|
||||
"max_attempts": self.retry_max_attempts,
|
||||
"wait_ms": wait_ms,
|
||||
"reason": reason,
|
||||
"message": self._build_retry_message(attempt, wait_ms, reason),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to emit llm_retry event", exc_info=True)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
time.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||
return any(pattern in detail for pattern in patterns)
|
||||
|
||||
|
||||
def _extract_error_code(exc: BaseException) -> Any:
|
||||
for attr in ("code", "error_code"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if isinstance(body, dict):
|
||||
error = body.get("error")
|
||||
if isinstance(error, dict):
|
||||
for key in ("code", "type"):
|
||||
value = error.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_status_code(exc: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status"):
|
||||
value = getattr(exc, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
response = getattr(exc, "response", None)
|
||||
status = getattr(response, "status_code", None)
|
||||
return status if isinstance(status, int) else None
|
||||
|
||||
|
||||
def _extract_retry_after_ms(exc: BaseException) -> int | None:
|
||||
response = getattr(exc, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
raw = None
|
||||
header_name = ""
|
||||
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
|
||||
header_name = key
|
||||
if hasattr(headers, "get"):
|
||||
raw = headers.get(key)
|
||||
if raw:
|
||||
break
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
multiplier = 1 if "ms" in header_name.lower() else 1000
|
||||
return max(0, int(float(raw) * multiplier))
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
target = parsedate_to_datetime(str(raw))
|
||||
delta = target.timestamp() - time.time()
|
||||
return max(0, int(delta * 1000))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_detail(exc: BaseException) -> str:
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
return detail
|
||||
message = getattr(exc, "message", None)
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
return exc.__class__.__name__
|
||||
@@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def _append_text(content: str | list | None, text: str) -> str | list:
|
||||
"""Append *text* to AIMessage content, handling str, list, and None.
|
||||
|
||||
When content is a list of content blocks (e.g. Anthropic thinking mode),
|
||||
we append a new ``{"type": "text", ...}`` block instead of concatenating
|
||||
a string to a list, which would raise ``TypeError``.
|
||||
"""
|
||||
if content is None:
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
@@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
|
||||
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
|
||||
}
|
||||
)
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
@@ -14,6 +14,37 @@ from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
@@ -21,6 +52,22 @@ class MemoryMiddlewareState(AgentState):
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
@@ -44,18 +91,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, list):
|
||||
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
|
||||
content_str = str(content)
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
@@ -87,6 +129,48 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||
|
||||
Complements detect_correction() by identifying when the user confirms the
|
||||
agent's approach was correct. This allows the memory system to record what
|
||||
worked well, not just what went wrong.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale signals from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
@@ -150,7 +234,15 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -101,44 +101,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
return user_msg if user_msg else "New Conversation"
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Synchronously generate a title. Returns state update or None."""
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
config = get_title_config()
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
|
||||
try:
|
||||
response = model.invoke(prompt)
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate title (sync)")
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
return {"title": title}
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Asynchronously generate a title. Returns state update or None."""
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
config = get_title_config()
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
|
||||
try:
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt)
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.exception("Failed to generate title (async)")
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
return {"title": title}
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
|
||||
+4
-1
@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
@@ -135,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
|
||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=False,
|
||||
include_dangling_tool_call_patch=False,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
|
||||
@@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_OUTLINE_PREVIEW_LINES = 5
|
||||
|
||||
|
||||
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
|
||||
"""Return the document outline and fallback preview for *file_path*.
|
||||
|
||||
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
|
||||
pipeline.
|
||||
|
||||
Returns:
|
||||
(outline, preview) where:
|
||||
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
|
||||
Empty when no headings are found or no .md exists.
|
||||
- preview: first few non-empty lines of the .md, used as a content
|
||||
anchor when outline is empty so the agent has some context.
|
||||
Empty when outline is non-empty (no fallback needed).
|
||||
"""
|
||||
md_path = file_path.with_suffix(".md")
|
||||
if not md_path.is_file():
|
||||
return [], []
|
||||
|
||||
outline = extract_outline(md_path)
|
||||
if outline:
|
||||
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
|
||||
return outline, []
|
||||
|
||||
# outline is empty — read the first few non-empty lines as a content preview
|
||||
preview: list[str] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
preview.append(stripped)
|
||||
if len(preview) >= _OUTLINE_PREVIEW_LINES:
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
|
||||
return [], preview
|
||||
|
||||
|
||||
class UploadsMiddlewareState(AgentState):
|
||||
"""State schema for uploads middleware."""
|
||||
|
||||
@@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
|
||||
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
|
||||
"""Append a single file entry (name, size, path, optional outline) to lines."""
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
outline = file.get("outline") or []
|
||||
if outline:
|
||||
truncated = outline[-1].get("truncated", False)
|
||||
visible = [e for e in outline if not e.get("truncated")]
|
||||
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
|
||||
for entry in visible:
|
||||
lines.append(f" L{entry['line']}: {entry['title']}")
|
||||
if truncated:
|
||||
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
|
||||
else:
|
||||
preview = file.get("outline_preview") or []
|
||||
if preview:
|
||||
lines.append(" No structural headings detected. Document begins with:")
|
||||
for text in preview:
|
||||
lines.append(f" > {text}")
|
||||
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("")
|
||||
|
||||
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
|
||||
"""Create a formatted message listing uploaded files.
|
||||
|
||||
Args:
|
||||
new_files: Files uploaded in the current message.
|
||||
historical_files: Files uploaded in previous messages.
|
||||
Each file dict may contain an optional ``outline`` key — a list of
|
||||
``{title, line}`` dicts extracted from the converted Markdown file.
|
||||
|
||||
Returns:
|
||||
Formatted string inside <uploaded_files> tags.
|
||||
@@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
lines.append("")
|
||||
if new_files:
|
||||
for file in new_files:
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
lines.append("")
|
||||
self._format_file_entry(file, lines)
|
||||
else:
|
||||
lines.append("(empty)")
|
||||
lines.append("")
|
||||
|
||||
if historical_files:
|
||||
lines.append("The following files were uploaded in previous messages and are still available:")
|
||||
lines.append("")
|
||||
for file in historical_files:
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
lines.append("")
|
||||
self._format_file_entry(file, lines)
|
||||
|
||||
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
|
||||
lines.append("To work with these files:")
|
||||
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
|
||||
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
|
||||
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Use `glob` to find files by name pattern")
|
||||
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
|
||||
lines.append("</uploaded_files>")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
@@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
for file_path in sorted(uploads_dir.iterdir()):
|
||||
if file_path.is_file() and file_path.name not in new_filenames:
|
||||
stat = file_path.stat()
|
||||
outline, preview = _extract_outline_for_file(file_path)
|
||||
historical_files.append(
|
||||
{
|
||||
"filename": file_path.name,
|
||||
"size": stat.st_size,
|
||||
"path": f"/mnt/user-data/uploads/{file_path.name}",
|
||||
"extension": file_path.suffix,
|
||||
"outline": outline,
|
||||
"outline_preview": preview,
|
||||
}
|
||||
)
|
||||
|
||||
# Attach outlines to new files as well
|
||||
if uploads_dir:
|
||||
for file in new_files:
|
||||
phys_path = uploads_dir / file["filename"]
|
||||
outline, preview = _extract_outline_for_file(phys_path)
|
||||
file["outline"] = outline
|
||||
file["outline_preview"] = preview
|
||||
|
||||
if not new_files and not historical_files:
|
||||
return None
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ class DeerFlowClient:
|
||||
subagent_enabled: bool = False,
|
||||
plan_mode: bool = False,
|
||||
agent_name: str | None = None,
|
||||
available_skills: set[str] | None = None,
|
||||
middlewares: Sequence[AgentMiddleware] | None = None,
|
||||
):
|
||||
"""Initialize the client.
|
||||
@@ -133,6 +134,7 @@ class DeerFlowClient:
|
||||
subagent_enabled: Enable subagent delegation.
|
||||
plan_mode: Enable TodoList middleware for plan mode.
|
||||
agent_name: Name of the agent to use.
|
||||
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
|
||||
middlewares: Optional list of custom middlewares to inject into the agent.
|
||||
"""
|
||||
if config_path is not None:
|
||||
@@ -148,6 +150,7 @@ class DeerFlowClient:
|
||||
self._subagent_enabled = subagent_enabled
|
||||
self._plan_mode = plan_mode
|
||||
self._agent_name = agent_name
|
||||
self._available_skills = set(available_skills) if available_skills is not None else None
|
||||
self._middlewares = list(middlewares) if middlewares else []
|
||||
|
||||
# Lazy agent — created on first call, recreated when config changes.
|
||||
@@ -208,6 +211,8 @@ class DeerFlowClient:
|
||||
cfg.get("thinking_enabled"),
|
||||
cfg.get("is_plan_mode"),
|
||||
cfg.get("subagent_enabled"),
|
||||
self._agent_name,
|
||||
frozenset(self._available_skills) if self._available_skills is not None else None,
|
||||
)
|
||||
|
||||
if self._agent is not None and self._agent_config_key == key:
|
||||
@@ -226,6 +231,7 @@ class DeerFlowClient:
|
||||
subagent_enabled=subagent_enabled,
|
||||
max_concurrent_subagents=max_concurrent_subagents,
|
||||
agent_name=self._agent_name,
|
||||
available_skills=self._available_skills,
|
||||
),
|
||||
"state_schema": ThreadState,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import uuid
|
||||
from agent_sandbox import Sandbox as AioSandboxClient
|
||||
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -124,16 +125,96 @@ class AioSandbox(Sandbox):
|
||||
content: The text content to write to the file.
|
||||
append: Whether to append the content to the file.
|
||||
"""
|
||||
try:
|
||||
if append:
|
||||
# Read existing content first and append
|
||||
existing = self.read_file(path)
|
||||
if not existing.startswith("Error:"):
|
||||
content = existing + content
|
||||
self._client.file.write_file(file=path, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write file in sandbox: {e}")
|
||||
raise
|
||||
with self._lock:
|
||||
try:
|
||||
if append:
|
||||
existing = self.read_file(path)
|
||||
if not existing.startswith("Error:"):
|
||||
content = existing + content
|
||||
self._client.file.write_file(file=path, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write file in sandbox: {e}")
|
||||
raise
|
||||
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
if not include_dirs:
|
||||
result = self._client.file.find_files(path=path, glob=pattern)
|
||||
files = result.data.files if result.data and result.data.files else []
|
||||
filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
|
||||
truncated = len(filtered) > max_results
|
||||
return filtered[:max_results], truncated
|
||||
|
||||
result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
|
||||
entries = result.data.files if result.data and result.data.files else []
|
||||
matches: list[str] = []
|
||||
root_path = path.rstrip("/") or "/"
|
||||
root_prefix = root_path if root_path == "/" else f"{root_path}/"
|
||||
for entry in entries:
|
||||
if entry.path != root_path and not entry.path.startswith(root_prefix):
|
||||
continue
|
||||
if should_ignore_path(entry.path):
|
||||
continue
|
||||
rel_path = entry.path[len(root_path) :].lstrip("/")
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(entry.path)
|
||||
if len(matches) >= max_results:
|
||||
return matches, True
|
||||
return matches, False
|
||||
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
import re as _re
|
||||
|
||||
regex_source = _re.escape(pattern) if literal else pattern
|
||||
# Validate the pattern locally so an invalid regex raises re.error
|
||||
# (caught by grep_tool's except re.error handler) rather than a
|
||||
# generic remote API error.
|
||||
_re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
|
||||
regex = regex_source if case_sensitive else f"(?i){regex_source}"
|
||||
|
||||
if glob is not None:
|
||||
find_result = self._client.file.find_files(path=path, glob=glob)
|
||||
candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
|
||||
else:
|
||||
list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
|
||||
entries = list_result.data.files if list_result.data and list_result.data.files else []
|
||||
candidate_paths = [entry.path for entry in entries if not entry.is_directory]
|
||||
|
||||
matches: list[GrepMatch] = []
|
||||
truncated = False
|
||||
|
||||
for file_path in candidate_paths:
|
||||
if should_ignore_path(file_path):
|
||||
continue
|
||||
|
||||
search_result = self._client.file.search_in_file(file=file_path, regex=regex)
|
||||
data = search_result.data
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
line_numbers = data.line_numbers or []
|
||||
matched_lines = data.matches or []
|
||||
for line_number, line in zip(line_numbers, matched_lines):
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=file_path,
|
||||
line_number=line_number if isinstance(line_number, int) else 0,
|
||||
line=truncate_line(line),
|
||||
)
|
||||
)
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
return matches, truncated
|
||||
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
"""Update a file with binary content in the sandbox.
|
||||
@@ -142,9 +223,10 @@ class AioSandbox(Sandbox):
|
||||
path: The absolute path of the file to update.
|
||||
content: The binary content to write to the file.
|
||||
"""
|
||||
try:
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file in sandbox: {e}")
|
||||
raise
|
||||
with self._lock:
|
||||
try:
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file in sandbox: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_api_key_warned = False
|
||||
|
||||
|
||||
class JinaClient:
|
||||
def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||
global _api_key_warned
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Return-Format": return_format,
|
||||
@@ -15,11 +18,13 @@ class JinaClient:
|
||||
}
|
||||
if os.getenv("JINA_API_KEY"):
|
||||
headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
|
||||
else:
|
||||
elif not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
||||
data = {"url": url}
|
||||
try:
|
||||
response = requests.post("https://r.jina.ai/", headers=headers, json=data)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"Jina API returned status {response.status_code}: {response.text}"
|
||||
@@ -34,5 +39,5 @@ class JinaClient:
|
||||
return response.text
|
||||
except Exception as e:
|
||||
error_message = f"Request to Jina API failed: {str(e)}"
|
||||
logger.error(error_message)
|
||||
logger.exception(error_message)
|
||||
return f"Error: {error_message}"
|
||||
|
||||
@@ -8,7 +8,7 @@ readability_extractor = ReadabilityExtractor()
|
||||
|
||||
|
||||
@tool("web_fetch", parse_docstring=True)
|
||||
def web_fetch_tool(url: str) -> str:
|
||||
async def web_fetch_tool(url: str) -> str:
|
||||
"""Fetch the contents of a web page at a given URL.
|
||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||
@@ -23,6 +23,8 @@ def web_fetch_tool(url: str) -> str:
|
||||
config = get_app_config().get_tool_config("web_fetch")
|
||||
if config is not None and "timeout" in config.model_extra:
|
||||
timeout = config.model_extra.get("timeout")
|
||||
html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||
return html_content
|
||||
article = readability_extractor.extract_article(html_content)
|
||||
return article.to_markdown()[:4096]
|
||||
|
||||
@@ -3,7 +3,13 @@ from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||
from .memory_config import MemoryConfig, get_memory_config
|
||||
from .paths import Paths, get_paths
|
||||
from .skills_config import SkillsConfig
|
||||
from .tracing_config import get_tracing_config, is_tracing_enabled
|
||||
from .tracing_config import (
|
||||
get_enabled_tracing_providers,
|
||||
get_explicitly_enabled_tracing_providers,
|
||||
get_tracing_config,
|
||||
is_tracing_enabled,
|
||||
validate_enabled_tracing_providers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_app_config",
|
||||
@@ -15,5 +21,8 @@ __all__ = [
|
||||
"MemoryConfig",
|
||||
"get_memory_config",
|
||||
"get_tracing_config",
|
||||
"get_explicitly_enabled_tracing_providers",
|
||||
"get_enabled_tracing_providers",
|
||||
"is_tracing_enabled",
|
||||
"validate_enabled_tracing_providers",
|
||||
]
|
||||
|
||||
@@ -22,6 +22,11 @@ class AgentConfig(BaseModel):
|
||||
description: str = ""
|
||||
model: str | None = None
|
||||
tool_groups: list[str] | None = None
|
||||
# skills controls which skills are loaded into the agent's prompt:
|
||||
# - None (or omitted): load all enabled skills (default fallback behavior)
|
||||
# - [] (explicit empty list): disable all skills
|
||||
# - ["skill1", "skill2"]: load only the specified skills
|
||||
skills: list[str] | None = None
|
||||
|
||||
|
||||
def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
|
||||
@@ -10,15 +11,15 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.guardrails_config import load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import load_memory_config_from_dict
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
|
||||
from deerflow.config.subagents_config import load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import load_title_config_from_dict
|
||||
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
|
||||
from deerflow.config.token_usage_config import TokenUsageConfig
|
||||
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
@@ -28,6 +29,13 @@ load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
|
||||
"""Return deterministic config.yaml locations without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
return (backend_dir / "config.yaml", repo_root / "config.yaml")
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Config for the DeerFlow application"""
|
||||
|
||||
@@ -40,6 +48,11 @@ class AppConfig(BaseModel):
|
||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
|
||||
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
|
||||
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
|
||||
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
||||
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
|
||||
@@ -51,7 +64,7 @@ class AppConfig(BaseModel):
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
|
||||
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
@@ -64,14 +77,10 @@ class AppConfig(BaseModel):
|
||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
# Check if the config.yaml is in the current directory
|
||||
path = Path(os.getcwd()) / "config.yaml"
|
||||
if not path.exists():
|
||||
# Check if the config.yaml is in the parent directory of CWD
|
||||
path = Path(os.getcwd()).parent / "config.yaml"
|
||||
if not path.exists():
|
||||
raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
|
||||
return path
|
||||
for path in _default_config_candidates():
|
||||
if path.exists():
|
||||
return path
|
||||
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> Self:
|
||||
@@ -244,6 +253,8 @@ _app_config: AppConfig | None = None
|
||||
_app_config_path: Path | None = None
|
||||
_app_config_mtime: float | None = None
|
||||
_app_config_is_custom = False
|
||||
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
|
||||
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
|
||||
|
||||
|
||||
def _get_config_mtime(config_path: Path) -> float | None:
|
||||
@@ -276,6 +287,10 @@ def get_app_config() -> AppConfig:
|
||||
"""
|
||||
global _app_config, _app_config_path, _app_config_mtime
|
||||
|
||||
runtime_override = _current_app_config.get()
|
||||
if runtime_override is not None:
|
||||
return runtime_override
|
||||
|
||||
if _app_config is not None and _app_config_is_custom:
|
||||
return _app_config
|
||||
|
||||
@@ -337,3 +352,26 @@ def set_app_config(config: AppConfig) -> None:
|
||||
_app_config_path = None
|
||||
_app_config_mtime = None
|
||||
_app_config_is_custom = True
|
||||
|
||||
|
||||
def peek_current_app_config() -> AppConfig | None:
|
||||
"""Return the runtime-scoped AppConfig override, if one is active."""
|
||||
return _current_app_config.get()
|
||||
|
||||
|
||||
def push_current_app_config(config: AppConfig) -> None:
|
||||
"""Push a runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
_current_app_config_stack.set(stack + (_current_app_config.get(),))
|
||||
_current_app_config.set(config)
|
||||
|
||||
|
||||
def pop_current_app_config() -> None:
|
||||
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
if not stack:
|
||||
_current_app_config.set(None)
|
||||
return
|
||||
previous = stack[-1]
|
||||
_current_app_config_stack.set(stack[:-1])
|
||||
_current_app_config.set(previous)
|
||||
|
||||
@@ -80,6 +80,12 @@ class ExtensionsConfig(BaseModel):
|
||||
Args:
|
||||
config_path: Optional path to extensions config file.
|
||||
|
||||
Resolution order:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, search backend/repository-root defaults for
|
||||
`extensions_config.json`, then legacy `mcp_config.json`.
|
||||
|
||||
Returns:
|
||||
Path to the extensions config file if found, otherwise None.
|
||||
"""
|
||||
@@ -94,24 +100,16 @@ class ExtensionsConfig(BaseModel):
|
||||
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
# Check if the extensions_config.json is in the current directory
|
||||
path = Path(os.getcwd()) / "extensions_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Check if the extensions_config.json is in the parent directory of CWD
|
||||
path = Path(os.getcwd()).parent / "extensions_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Backward compatibility: check for mcp_config.json
|
||||
path = Path(os.getcwd()) / "mcp_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
path = Path(os.getcwd()).parent / "mcp_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
for path in (
|
||||
backend_dir / "extensions_config.json",
|
||||
repo_root / "extensions_config.json",
|
||||
backend_dir / "mcp_config.json",
|
||||
repo_root / "mcp_config.json",
|
||||
):
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Extensions are optional, so return None if not found
|
||||
return None
|
||||
|
||||
@@ -9,6 +9,12 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
def _default_local_base_dir() -> Path:
|
||||
"""Return the repo-local DeerFlow state directory without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
return backend_dir / ".deer-flow"
|
||||
|
||||
|
||||
def _validate_thread_id(thread_id: str) -> str:
|
||||
"""Validate a thread ID before using it in filesystem paths."""
|
||||
if not _SAFE_THREAD_ID_RE.match(thread_id):
|
||||
@@ -67,8 +73,7 @@ class Paths:
|
||||
BaseDir resolution (in priority order):
|
||||
1. Constructor argument `base_dir`
|
||||
2. DEER_FLOW_HOME environment variable
|
||||
3. Local dev fallback: cwd/.deer-flow (when cwd is the backend/ dir)
|
||||
4. Default: $HOME/.deer-flow
|
||||
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
|
||||
"""
|
||||
|
||||
def __init__(self, base_dir: str | Path | None = None) -> None:
|
||||
@@ -104,11 +109,7 @@ class Paths:
|
||||
if env_home := os.getenv("DEER_FLOW_HOME"):
|
||||
return Path(env_home).resolve()
|
||||
|
||||
cwd = Path.cwd()
|
||||
if cwd.name == "backend" or (cwd / "pyproject.toml").exists():
|
||||
return cwd / ".deer-flow"
|
||||
|
||||
return Path.home() / ".deer-flow"
|
||||
return _default_local_base_dir()
|
||||
|
||||
@property
|
||||
def memory_file(self) -> Path:
|
||||
|
||||
@@ -64,4 +64,15 @@ class SandboxConfig(BaseModel):
|
||||
description="Environment variables to inject into the sandbox container. Values starting with $ will be resolved from host environment variables.",
|
||||
)
|
||||
|
||||
bash_output_max_chars: int = Field(
|
||||
default=20000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from bash tool output. Output exceeding this limit is middle-truncated (head + tail), preserving the first and last half. Set to 0 to disable truncation.",
|
||||
)
|
||||
read_file_output_max_chars: int = Field(
|
||||
default=50000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -3,6 +3,11 @@ from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _default_repo_root() -> Path:
|
||||
"""Resolve the repo root without relying on the current working directory."""
|
||||
return Path(__file__).resolve().parents[5]
|
||||
|
||||
|
||||
class SkillsConfig(BaseModel):
|
||||
"""Configuration for skills system"""
|
||||
|
||||
@@ -26,8 +31,8 @@ class SkillsConfig(BaseModel):
|
||||
# Use configured path (can be absolute or relative)
|
||||
path = Path(self.path)
|
||||
if not path.is_absolute():
|
||||
# If relative, resolve from current working directory
|
||||
path = Path.cwd() / path
|
||||
# If relative, resolve from the repo root for deterministic behavior.
|
||||
path = _default_repo_root() / path
|
||||
return path.resolve()
|
||||
else:
|
||||
# Default: ../skills relative to backend directory
|
||||
|
||||
@@ -15,6 +15,11 @@ class SubagentOverrideConfig(BaseModel):
|
||||
ge=1,
|
||||
description="Timeout in seconds for this subagent (None = use global default)",
|
||||
)
|
||||
max_turns: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Maximum turns for this subagent (None = use global or builtin default)",
|
||||
)
|
||||
|
||||
|
||||
class SubagentsAppConfig(BaseModel):
|
||||
@@ -25,6 +30,11 @@ class SubagentsAppConfig(BaseModel):
|
||||
ge=1,
|
||||
description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
|
||||
)
|
||||
max_turns: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
|
||||
)
|
||||
agents: dict[str, SubagentOverrideConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Per-agent configuration overrides keyed by agent name",
|
||||
@@ -44,6 +54,15 @@ class SubagentsAppConfig(BaseModel):
|
||||
return override.timeout_seconds
|
||||
return self.timeout_seconds
|
||||
|
||||
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
|
||||
"""Get the effective max_turns for a specific agent."""
|
||||
override = self.agents.get(agent_name)
|
||||
if override is not None and override.max_turns is not None:
|
||||
return override.max_turns
|
||||
if self.max_turns is not None:
|
||||
return self.max_turns
|
||||
return builtin_default
|
||||
|
||||
|
||||
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
|
||||
|
||||
@@ -58,8 +77,26 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
|
||||
global _subagents_config
|
||||
_subagents_config = SubagentsAppConfig(**config_dict)
|
||||
|
||||
overrides_summary = {name: f"{override.timeout_seconds}s" for name, override in _subagents_config.agents.items() if override.timeout_seconds is not None}
|
||||
overrides_summary = {}
|
||||
for name, override in _subagents_config.agents.items():
|
||||
parts = []
|
||||
if override.timeout_seconds is not None:
|
||||
parts.append(f"timeout={override.timeout_seconds}s")
|
||||
if override.max_turns is not None:
|
||||
parts.append(f"max_turns={override.max_turns}")
|
||||
if parts:
|
||||
overrides_summary[name] = ", ".join(parts)
|
||||
|
||||
if overrides_summary:
|
||||
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, per-agent overrides={overrides_summary}")
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
overrides_summary,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, no per-agent overrides")
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_config_lock = threading.Lock()
|
||||
|
||||
|
||||
class TracingConfig(BaseModel):
|
||||
class LangSmithTracingConfig(BaseModel):
|
||||
"""Configuration for LangSmith tracing."""
|
||||
|
||||
enabled: bool = Field(...)
|
||||
@@ -18,9 +16,69 @@ class TracingConfig(BaseModel):
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if tracing is fully configured (enabled and has API key)."""
|
||||
return self.enabled and bool(self.api_key)
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.enabled and not self.api_key:
|
||||
raise ValueError("LangSmith tracing is enabled but LANGSMITH_API_KEY (or LANGCHAIN_API_KEY) is not set.")
|
||||
|
||||
|
||||
class LangfuseTracingConfig(BaseModel):
|
||||
"""Configuration for Langfuse tracing."""
|
||||
|
||||
enabled: bool = Field(...)
|
||||
public_key: str | None = Field(...)
|
||||
secret_key: str | None = Field(...)
|
||||
host: str = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return self.enabled and bool(self.public_key) and bool(self.secret_key)
|
||||
|
||||
def validate(self) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
missing: list[str] = []
|
||||
if not self.public_key:
|
||||
missing.append("LANGFUSE_PUBLIC_KEY")
|
||||
if not self.secret_key:
|
||||
missing.append("LANGFUSE_SECRET_KEY")
|
||||
if missing:
|
||||
raise ValueError(f"Langfuse tracing is enabled but required settings are missing: {', '.join(missing)}")
|
||||
|
||||
|
||||
class TracingConfig(BaseModel):
|
||||
"""Tracing configuration for supported providers."""
|
||||
|
||||
langsmith: LangSmithTracingConfig = Field(...)
|
||||
langfuse: LangfuseTracingConfig = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.enabled_providers)
|
||||
|
||||
@property
|
||||
def explicitly_enabled_providers(self) -> list[str]:
|
||||
enabled: list[str] = []
|
||||
if self.langsmith.enabled:
|
||||
enabled.append("langsmith")
|
||||
if self.langfuse.enabled:
|
||||
enabled.append("langfuse")
|
||||
return enabled
|
||||
|
||||
@property
|
||||
def enabled_providers(self) -> list[str]:
|
||||
enabled: list[str] = []
|
||||
if self.langsmith.is_configured:
|
||||
enabled.append("langsmith")
|
||||
if self.langfuse.is_configured:
|
||||
enabled.append("langfuse")
|
||||
return enabled
|
||||
|
||||
def validate_enabled(self) -> None:
|
||||
self.langsmith.validate()
|
||||
self.langfuse.validate()
|
||||
|
||||
|
||||
_tracing_config: TracingConfig | None = None
|
||||
|
||||
@@ -29,12 +87,7 @@ _TRUTHY_VALUES = {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _env_flag_preferred(*names: str) -> bool:
|
||||
"""Return the boolean value of the first env var that is present and non-empty.
|
||||
|
||||
Accepted truthy values (case-insensitive): ``1``, ``true``, ``yes``, ``on``.
|
||||
Any other non-empty value is treated as falsy. If none of the named
|
||||
variables is set, returns ``False``.
|
||||
"""
|
||||
"""Return the boolean value of the first env var that is present and non-empty."""
|
||||
for name in names:
|
||||
value = os.environ.get(name)
|
||||
if value is not None and value.strip():
|
||||
@@ -52,43 +105,45 @@ def _first_env_value(*names: str) -> str | None:
|
||||
|
||||
|
||||
def get_tracing_config() -> TracingConfig:
|
||||
"""Get the current tracing configuration from environment variables.
|
||||
|
||||
``LANGSMITH_*`` variables take precedence over their legacy ``LANGCHAIN_*``
|
||||
counterparts. For boolean flags (``enabled``), the *first* variable that is
|
||||
present and non-empty in the priority list is the sole authority – its value
|
||||
is parsed and returned without consulting the remaining candidates. Accepted
|
||||
truthy values are ``1``, ``true``, ``yes``, and ``on`` (case-insensitive);
|
||||
any other non-empty value is treated as falsy.
|
||||
|
||||
Priority order:
|
||||
enabled : LANGSMITH_TRACING > LANGCHAIN_TRACING_V2 > LANGCHAIN_TRACING
|
||||
api_key : LANGSMITH_API_KEY > LANGCHAIN_API_KEY
|
||||
project : LANGSMITH_PROJECT > LANGCHAIN_PROJECT (default: "deer-flow")
|
||||
endpoint : LANGSMITH_ENDPOINT > LANGCHAIN_ENDPOINT (default: https://api.smith.langchain.com)
|
||||
|
||||
Returns:
|
||||
TracingConfig with current settings.
|
||||
"""
|
||||
"""Get the current tracing configuration from environment variables."""
|
||||
global _tracing_config
|
||||
if _tracing_config is not None:
|
||||
return _tracing_config
|
||||
with _config_lock:
|
||||
if _tracing_config is not None: # Double-check after acquiring lock
|
||||
if _tracing_config is not None:
|
||||
return _tracing_config
|
||||
_tracing_config = TracingConfig(
|
||||
# Keep compatibility with both legacy LANGCHAIN_* and newer LANGSMITH_* variables.
|
||||
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
|
||||
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
|
||||
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
|
||||
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
|
||||
langsmith=LangSmithTracingConfig(
|
||||
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
|
||||
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
|
||||
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
|
||||
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
|
||||
),
|
||||
langfuse=LangfuseTracingConfig(
|
||||
enabled=_env_flag_preferred("LANGFUSE_TRACING"),
|
||||
public_key=_first_env_value("LANGFUSE_PUBLIC_KEY"),
|
||||
secret_key=_first_env_value("LANGFUSE_SECRET_KEY"),
|
||||
host=_first_env_value("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com",
|
||||
),
|
||||
)
|
||||
return _tracing_config
|
||||
|
||||
|
||||
def get_enabled_tracing_providers() -> list[str]:
|
||||
"""Return the configured tracing providers that are enabled and complete."""
|
||||
return get_tracing_config().enabled_providers
|
||||
|
||||
|
||||
def get_explicitly_enabled_tracing_providers() -> list[str]:
|
||||
"""Return tracing providers explicitly enabled by config, even if incomplete."""
|
||||
return get_tracing_config().explicitly_enabled_providers
|
||||
|
||||
|
||||
def validate_enabled_tracing_providers() -> None:
|
||||
"""Validate that any explicitly enabled providers are fully configured."""
|
||||
get_tracing_config().validate_enabled()
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if LangSmith tracing is enabled and configured.
|
||||
Returns:
|
||||
True if tracing is enabled and has an API key.
|
||||
"""
|
||||
"""Check if any tracing provider is enabled and fully configured."""
|
||||
return get_tracing_config().is_configured
|
||||
|
||||
@@ -2,8 +2,9 @@ import logging
|
||||
|
||||
from langchain.chat_models import BaseChatModel
|
||||
|
||||
from deerflow.config import get_app_config, get_tracing_config, is_tracing_enabled
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.reflection import resolve_class
|
||||
from deerflow.tracing import build_tracing_callbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,17 +80,9 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
|
||||
model_instance = model_class(**kwargs, **model_settings_from_config)
|
||||
|
||||
if is_tracing_enabled():
|
||||
try:
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
tracing_config = get_tracing_config()
|
||||
tracer = LangChainTracer(
|
||||
project_name=tracing_config.project,
|
||||
)
|
||||
existing_callbacks = model_instance.callbacks or []
|
||||
model_instance.callbacks = [*existing_callbacks, tracer]
|
||||
logger.debug(f"LangSmith tracing attached to model '{name}' (project='{tracing_config.project}')")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to attach LangSmith tracing to model '{name}': {e}")
|
||||
callbacks = build_tracing_callbacks()
|
||||
if callbacks:
|
||||
existing_callbacks = model_instance.callbacks or []
|
||||
model_instance.callbacks = [*existing_callbacks, *callbacks]
|
||||
logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}")
|
||||
return model_instance
|
||||
|
||||
@@ -90,6 +90,11 @@ async def run_agent(
|
||||
# Inject runtime context so middlewares can access thread_id
|
||||
# (langgraph-cli does this automatically; we must do it manually)
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||
# prefers it over ``configurable`` for thread-level data), make
|
||||
# sure ``thread_id`` is available there too.
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
|
||||
@@ -25,6 +25,7 @@ class MemoryStreamBridge(StreamBridge):
|
||||
self._maxsize = queue_maxsize
|
||||
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
|
||||
self._counters: dict[str, int] = {}
|
||||
self._dropped_counts: dict[str, int] = {}
|
||||
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
@@ -32,6 +33,7 @@ class MemoryStreamBridge(StreamBridge):
|
||||
if run_id not in self._queues:
|
||||
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
|
||||
self._counters[run_id] = 0
|
||||
self._dropped_counts[run_id] = 0
|
||||
return self._queues[run_id]
|
||||
|
||||
def _next_id(self, run_id: str) -> str:
|
||||
@@ -48,14 +50,41 @@ class MemoryStreamBridge(StreamBridge):
|
||||
try:
|
||||
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
|
||||
except TimeoutError:
|
||||
logger.warning("Stream bridge queue full for run %s — dropping event %s", run_id, event)
|
||||
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
|
||||
logger.warning(
|
||||
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
|
||||
run_id,
|
||||
event,
|
||||
self._dropped_counts[run_id],
|
||||
)
|
||||
|
||||
async def publish_end(self, run_id: str) -> None:
|
||||
queue = self._get_or_create_queue(run_id)
|
||||
try:
|
||||
await asyncio.wait_for(queue.put(END_SENTINEL), timeout=_PUBLISH_TIMEOUT)
|
||||
except TimeoutError:
|
||||
logger.warning("Stream bridge queue full for run %s — dropping END sentinel", run_id)
|
||||
|
||||
# END sentinel is critical — it is the only signal that allows
|
||||
# subscribers to terminate. If the queue is full we evict the
|
||||
# oldest *regular* events to make room rather than dropping END,
|
||||
# which would cause the SSE connection to hang forever and leak
|
||||
# the queue/counter resources for this run_id.
|
||||
if queue.full():
|
||||
evicted = 0
|
||||
while queue.full():
|
||||
try:
|
||||
queue.get_nowait()
|
||||
evicted += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break # pragma: no cover – defensive
|
||||
if evicted:
|
||||
logger.warning(
|
||||
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
|
||||
run_id,
|
||||
evicted,
|
||||
)
|
||||
|
||||
# After eviction the queue is guaranteed to have space, so a
|
||||
# simple non-blocking put is safe. We still use put() (which
|
||||
# blocks until space is available) as a defensive measure.
|
||||
await queue.put(END_SENTINEL)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
@@ -84,7 +113,18 @@ class MemoryStreamBridge(StreamBridge):
|
||||
await asyncio.sleep(delay)
|
||||
self._queues.pop(run_id, None)
|
||||
self._counters.pop(run_id, None)
|
||||
self._dropped_counts.pop(run_id, None)
|
||||
|
||||
async def close(self) -> None:
|
||||
self._queues.clear()
|
||||
self._counters.clear()
|
||||
self._dropped_counts.clear()
|
||||
|
||||
def dropped_count(self, run_id: str) -> int:
|
||||
"""Return the number of events dropped for *run_id*."""
|
||||
return self._dropped_counts.get(run_id, 0)
|
||||
|
||||
@property
|
||||
def dropped_total(self) -> int:
|
||||
"""Return the total number of events dropped across all runs."""
|
||||
return sum(self._dropped_counts.values())
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import threading
|
||||
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
|
||||
_FILE_OPERATION_LOCKS: dict[tuple[str, str], threading.Lock] = {}
|
||||
_FILE_OPERATION_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def get_file_operation_lock_key(sandbox: Sandbox, path: str) -> tuple[str, str]:
|
||||
sandbox_id = getattr(sandbox, "id", None)
|
||||
if not sandbox_id:
|
||||
sandbox_id = f"instance:{id(sandbox)}"
|
||||
return sandbox_id, path
|
||||
|
||||
|
||||
def get_file_operation_lock(sandbox: Sandbox, path: str) -> threading.Lock:
|
||||
lock_key = get_file_operation_lock_key(sandbox, path)
|
||||
with _FILE_OPERATION_LOCKS_GUARD:
|
||||
lock = _FILE_OPERATION_LOCKS.get(lock_key)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_FILE_OPERATION_LOCKS[lock_key] = lock
|
||||
return lock
|
||||
@@ -1,72 +1,6 @@
|
||||
import fnmatch
|
||||
from pathlib import Path
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
# Version Control
|
||||
".git",
|
||||
".svn",
|
||||
".hg",
|
||||
".bzr",
|
||||
# Dependencies
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".env",
|
||||
"env",
|
||||
".tox",
|
||||
".nox",
|
||||
".eggs",
|
||||
"*.egg-info",
|
||||
"site-packages",
|
||||
# Build outputs
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
".output",
|
||||
".turbo",
|
||||
"target",
|
||||
"out",
|
||||
# IDE & Editor
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.swp",
|
||||
"*.swo",
|
||||
"*~",
|
||||
".project",
|
||||
".classpath",
|
||||
".settings",
|
||||
# OS generated
|
||||
".DS_Store",
|
||||
"Thumbs.db",
|
||||
"desktop.ini",
|
||||
"*.lnk",
|
||||
# Logs & temp files
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
"*.temp",
|
||||
"*.bak",
|
||||
"*.cache",
|
||||
".cache",
|
||||
"logs",
|
||||
# Coverage & test artifacts
|
||||
".coverage",
|
||||
"coverage",
|
||||
".nyc_output",
|
||||
"htmlcov",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
]
|
||||
|
||||
|
||||
def _should_ignore(name: str) -> bool:
|
||||
"""Check if a file/directory name matches any ignore pattern."""
|
||||
for pattern in IGNORE_PATTERNS:
|
||||
if fnmatch.fnmatch(name, pattern):
|
||||
return True
|
||||
return False
|
||||
from deerflow.sandbox.search import should_ignore_name
|
||||
|
||||
|
||||
def list_dir(path: str, max_depth: int = 2) -> list[str]:
|
||||
@@ -95,7 +29,7 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:
|
||||
|
||||
try:
|
||||
for item in current_path.iterdir():
|
||||
if _should_ignore(item.name):
|
||||
if should_ignore_name(item.name):
|
||||
continue
|
||||
|
||||
post_fix = "/" if item.is_dir() else ""
|
||||
|
||||
@@ -1,11 +1,23 @@
|
||||
import errno
|
||||
import ntpath
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.sandbox.local.list_dir import list_dir
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PathMapping:
|
||||
"""A path mapping from a container path to a local path with optional read-only flag."""
|
||||
|
||||
container_path: str
|
||||
local_path: str
|
||||
read_only: bool = False
|
||||
|
||||
|
||||
class LocalSandbox(Sandbox):
|
||||
@@ -39,17 +51,42 @@ class LocalSandbox(Sandbox):
|
||||
|
||||
return None
|
||||
|
||||
def __init__(self, id: str, path_mappings: dict[str, str] | None = None):
|
||||
def __init__(self, id: str, path_mappings: list[PathMapping] | None = None):
|
||||
"""
|
||||
Initialize local sandbox with optional path mappings.
|
||||
|
||||
Args:
|
||||
id: Sandbox identifier
|
||||
path_mappings: Dictionary mapping container paths to local paths
|
||||
Example: {"/mnt/skills": "/absolute/path/to/skills"}
|
||||
path_mappings: List of path mappings with optional read-only flag.
|
||||
Skills directory is read-only by default.
|
||||
"""
|
||||
super().__init__(id)
|
||||
self.path_mappings = path_mappings or {}
|
||||
self.path_mappings = path_mappings or []
|
||||
|
||||
def _is_read_only_path(self, resolved_path: str) -> bool:
|
||||
"""Check if a resolved path is under a read-only mount.
|
||||
|
||||
When multiple mappings match (nested mounts), prefer the most specific
|
||||
mapping (i.e. the one whose local_path is the longest prefix of the
|
||||
resolved path), similar to how ``_resolve_path`` handles container paths.
|
||||
"""
|
||||
resolved = str(Path(resolved_path).resolve())
|
||||
|
||||
best_mapping: PathMapping | None = None
|
||||
best_prefix_len = -1
|
||||
|
||||
for mapping in self.path_mappings:
|
||||
local_resolved = str(Path(mapping.local_path).resolve())
|
||||
if resolved == local_resolved or resolved.startswith(local_resolved + os.sep):
|
||||
prefix_len = len(local_resolved)
|
||||
if prefix_len > best_prefix_len:
|
||||
best_prefix_len = prefix_len
|
||||
best_mapping = mapping
|
||||
|
||||
if best_mapping is None:
|
||||
return False
|
||||
|
||||
return best_mapping.read_only
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
@@ -64,7 +101,9 @@ class LocalSandbox(Sandbox):
|
||||
path_str = str(path)
|
||||
|
||||
# Try each mapping (longest prefix first for more specific matches)
|
||||
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True):
|
||||
for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True):
|
||||
container_path = mapping.container_path
|
||||
local_path = mapping.local_path
|
||||
if path_str == container_path or path_str.startswith(container_path + "/"):
|
||||
# Replace the container path prefix with local path
|
||||
relative = path_str[len(container_path) :].lstrip("/")
|
||||
@@ -84,15 +123,16 @@ class LocalSandbox(Sandbox):
|
||||
Returns:
|
||||
Container path if mapping exists, otherwise original path
|
||||
"""
|
||||
path_str = str(Path(path).resolve())
|
||||
normalized_path = path.replace("\\", "/")
|
||||
path_str = str(Path(normalized_path).resolve())
|
||||
|
||||
# Try each mapping (longest local path first for more specific matches)
|
||||
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True):
|
||||
local_path_resolved = str(Path(local_path).resolve())
|
||||
if path_str.startswith(local_path_resolved):
|
||||
for mapping in sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True):
|
||||
local_path_resolved = str(Path(mapping.local_path).resolve())
|
||||
if path_str == local_path_resolved or path_str.startswith(local_path_resolved + "/"):
|
||||
# Replace the local path prefix with container path
|
||||
relative = path_str[len(local_path_resolved) :].lstrip("/")
|
||||
resolved = f"{container_path}/{relative}" if relative else container_path
|
||||
resolved = f"{mapping.container_path}/{relative}" if relative else mapping.container_path
|
||||
return resolved
|
||||
|
||||
# No mapping found, return original path
|
||||
@@ -111,7 +151,7 @@ class LocalSandbox(Sandbox):
|
||||
import re
|
||||
|
||||
# Sort mappings by local path length (longest first) for correct prefix matching
|
||||
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True)
|
||||
|
||||
if not sorted_mappings:
|
||||
return output
|
||||
@@ -119,12 +159,11 @@ class LocalSandbox(Sandbox):
|
||||
# Create pattern that matches absolute paths
|
||||
# Match paths like /Users/... or other absolute paths
|
||||
result = output
|
||||
for container_path, local_path in sorted_mappings:
|
||||
local_path_resolved = str(Path(local_path).resolve())
|
||||
for mapping in sorted_mappings:
|
||||
# Escape the local path for use in regex
|
||||
escaped_local = re.escape(local_path_resolved)
|
||||
# Match the local path followed by optional path components
|
||||
pattern = re.compile(escaped_local + r"(?:/[^\s\"';&|<>()]*)?")
|
||||
escaped_local = re.escape(str(Path(mapping.local_path).resolve()))
|
||||
# Match the local path followed by optional path components with either separator
|
||||
pattern = re.compile(escaped_local + r"(?:[/\\][^\s\"';&|<>()]*)?")
|
||||
|
||||
def replace_match(match: re.Match) -> str:
|
||||
matched_path = match.group(0)
|
||||
@@ -147,7 +186,7 @@ class LocalSandbox(Sandbox):
|
||||
import re
|
||||
|
||||
# Sort mappings by length (longest first) for correct prefix matching
|
||||
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True)
|
||||
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True)
|
||||
|
||||
# Build regex pattern to match all container paths
|
||||
# Match container path followed by optional path components
|
||||
@@ -157,7 +196,7 @@ class LocalSandbox(Sandbox):
|
||||
# Create pattern that matches any of the container paths.
|
||||
# The lookahead (?=/|$|...) ensures we only match at a path-segment boundary,
|
||||
# preventing /mnt/skills from matching inside /mnt/skills-extra.
|
||||
patterns = [re.escape(container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for container_path, _ in sorted_mappings]
|
||||
patterns = [re.escape(m.container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for m in sorted_mappings]
|
||||
pattern = re.compile("|".join(f"({p})" for p in patterns))
|
||||
|
||||
def replace_match(match: re.Match) -> str:
|
||||
@@ -248,6 +287,8 @@ class LocalSandbox(Sandbox):
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
resolved_path = self._resolve_path(path)
|
||||
if self._is_read_only_path(resolved_path):
|
||||
raise OSError(errno.EROFS, "Read-only file system", path)
|
||||
try:
|
||||
dir_path = os.path.dirname(resolved_path)
|
||||
if dir_path:
|
||||
@@ -259,8 +300,43 @@ class LocalSandbox(Sandbox):
|
||||
# Re-raise with the original path for clearer error messages, hiding internal resolved paths
|
||||
raise type(e)(e.errno, e.strerror, path) from None
|
||||
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
resolved_path = Path(self._resolve_path(path))
|
||||
matches, truncated = find_glob_matches(resolved_path, pattern, include_dirs=include_dirs, max_results=max_results)
|
||||
return [self._reverse_resolve_path(match) for match in matches], truncated
|
||||
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
resolved_path = Path(self._resolve_path(path))
|
||||
matches, truncated = find_grep_matches(
|
||||
resolved_path,
|
||||
pattern,
|
||||
glob_pattern=glob,
|
||||
literal=literal,
|
||||
case_sensitive=case_sensitive,
|
||||
max_results=max_results,
|
||||
)
|
||||
return [
|
||||
GrepMatch(
|
||||
path=self._reverse_resolve_path(match.path),
|
||||
line_number=match.line_number,
|
||||
line=match.line,
|
||||
)
|
||||
for match in matches
|
||||
], truncated
|
||||
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
resolved_path = self._resolve_path(path)
|
||||
if self._is_read_only_path(resolved_path):
|
||||
raise OSError(errno.EROFS, "Read-only file system", path)
|
||||
try:
|
||||
dir_path = os.path.dirname(resolved_path)
|
||||
if dir_path:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.sandbox.local.local_sandbox import LocalSandbox
|
||||
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
@@ -14,16 +15,17 @@ class LocalSandboxProvider(SandboxProvider):
|
||||
"""Initialize the local sandbox provider with path mappings."""
|
||||
self._path_mappings = self._setup_path_mappings()
|
||||
|
||||
def _setup_path_mappings(self) -> dict[str, str]:
|
||||
def _setup_path_mappings(self) -> list[PathMapping]:
|
||||
"""
|
||||
Setup path mappings for local sandbox.
|
||||
|
||||
Maps container paths to actual local paths, including skills directory.
|
||||
Maps container paths to actual local paths, including skills directory
|
||||
and any custom mounts configured in config.yaml.
|
||||
|
||||
Returns:
|
||||
Dictionary of path mappings
|
||||
List of path mappings
|
||||
"""
|
||||
mappings = {}
|
||||
mappings: list[PathMapping] = []
|
||||
|
||||
# Map skills container path to local skills directory
|
||||
try:
|
||||
@@ -35,10 +37,63 @@ class LocalSandboxProvider(SandboxProvider):
|
||||
|
||||
# Only add mapping if skills directory exists
|
||||
if skills_path.exists():
|
||||
mappings[container_path] = str(skills_path)
|
||||
mappings.append(
|
||||
PathMapping(
|
||||
container_path=container_path,
|
||||
local_path=str(skills_path),
|
||||
read_only=True, # Skills directory is always read-only
|
||||
)
|
||||
)
|
||||
|
||||
# Map custom mounts from sandbox config
|
||||
_RESERVED_CONTAINER_PREFIXES = [container_path, "/mnt/acp-workspace", "/mnt/user-data"]
|
||||
sandbox_config = config.sandbox
|
||||
if sandbox_config and sandbox_config.mounts:
|
||||
for mount in sandbox_config.mounts:
|
||||
host_path = Path(mount.host_path)
|
||||
container_path = mount.container_path.rstrip("/") or "/"
|
||||
|
||||
if not host_path.is_absolute():
|
||||
logger.warning(
|
||||
"Mount host_path must be absolute, skipping: %s -> %s",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
|
||||
if not container_path.startswith("/"):
|
||||
logger.warning(
|
||||
"Mount container_path must be absolute, skipping: %s -> %s",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject mounts that conflict with reserved container paths
|
||||
if any(container_path == p or container_path.startswith(p + "/") for p in _RESERVED_CONTAINER_PREFIXES):
|
||||
logger.warning(
|
||||
"Mount container_path conflicts with reserved prefix, skipping: %s",
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
# Ensure the host path exists before adding mapping
|
||||
if host_path.exists():
|
||||
mappings.append(
|
||||
PathMapping(
|
||||
container_path=container_path,
|
||||
local_path=str(host_path.resolve()),
|
||||
read_only=mount.read_only,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Mount host_path does not exist, skipping: %s -> %s",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail if config loading fails
|
||||
logger.warning("Could not setup skills path mapping: %s", e, exc_info=True)
|
||||
logger.warning("Could not setup path mappings: %s", e, exc_info=True)
|
||||
|
||||
return mappings
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
|
||||
|
||||
class Sandbox(ABC):
|
||||
"""Abstract base class for sandbox environments"""
|
||||
@@ -61,6 +63,25 @@ class Sandbox(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
"""Find paths that match a glob pattern under a root directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
"""Search for matches inside text files under a directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
"""Update a file with binary content.
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
import fnmatch
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".svn",
|
||||
".hg",
|
||||
".bzr",
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".env",
|
||||
"env",
|
||||
".tox",
|
||||
".nox",
|
||||
".eggs",
|
||||
"*.egg-info",
|
||||
"site-packages",
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
".output",
|
||||
".turbo",
|
||||
"target",
|
||||
"out",
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.swp",
|
||||
"*.swo",
|
||||
"*~",
|
||||
".project",
|
||||
".classpath",
|
||||
".settings",
|
||||
".DS_Store",
|
||||
"Thumbs.db",
|
||||
"desktop.ini",
|
||||
"*.lnk",
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
"*.temp",
|
||||
"*.bak",
|
||||
"*.cache",
|
||||
".cache",
|
||||
"logs",
|
||||
".coverage",
|
||||
"coverage",
|
||||
".nyc_output",
|
||||
"htmlcov",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
]
|
||||
|
||||
DEFAULT_MAX_FILE_SIZE_BYTES = 1_000_000
|
||||
DEFAULT_LINE_SUMMARY_LENGTH = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrepMatch:
|
||||
path: str
|
||||
line_number: int
|
||||
line: str
|
||||
|
||||
|
||||
def should_ignore_name(name: str) -> bool:
|
||||
for pattern in IGNORE_PATTERNS:
|
||||
if fnmatch.fnmatch(name, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def should_ignore_path(path: str) -> bool:
|
||||
return any(should_ignore_name(segment) for segment in path.replace("\\", "/").split("/") if segment)
|
||||
|
||||
|
||||
def path_matches(pattern: str, rel_path: str) -> bool:
|
||||
path = PurePosixPath(rel_path)
|
||||
if path.match(pattern):
|
||||
return True
|
||||
if pattern.startswith("**/"):
|
||||
return path.match(pattern[3:])
|
||||
return False
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = DEFAULT_LINE_SUMMARY_LENGTH) -> str:
|
||||
line = line.rstrip("\n\r")
|
||||
if len(line) <= max_chars:
|
||||
return line
|
||||
return line[: max_chars - 3] + "..."
|
||||
|
||||
|
||||
def is_binary_file(path: Path, sample_size: int = 8192) -> bool:
|
||||
try:
|
||||
with path.open("rb") as handle:
|
||||
return b"\0" in handle.read(sample_size)
|
||||
except OSError:
|
||||
return True
|
||||
|
||||
|
||||
def find_glob_matches(root: Path, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
matches: list[str] = []
|
||||
truncated = False
|
||||
root = root.resolve()
|
||||
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(root)
|
||||
if not root.is_dir():
|
||||
raise NotADirectoryError(root)
|
||||
|
||||
for current_root, dirs, files in os.walk(root):
|
||||
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
|
||||
# root is already resolved; os.walk builds current_root by joining under root,
|
||||
# so relative_to() works without an extra stat()/resolve() per directory.
|
||||
rel_dir = Path(current_root).relative_to(root)
|
||||
|
||||
if include_dirs:
|
||||
for name in dirs:
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(str(Path(current_root) / name))
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
for name in files:
|
||||
if should_ignore_name(name):
|
||||
continue
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(str(Path(current_root) / name))
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
return matches, truncated
|
||||
|
||||
|
||||
def find_grep_matches(
|
||||
root: Path,
|
||||
pattern: str,
|
||||
*,
|
||||
glob_pattern: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
max_file_size: int = DEFAULT_MAX_FILE_SIZE_BYTES,
|
||||
line_summary_length: int = DEFAULT_LINE_SUMMARY_LENGTH,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
matches: list[GrepMatch] = []
|
||||
truncated = False
|
||||
root = root.resolve()
|
||||
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(root)
|
||||
if not root.is_dir():
|
||||
raise NotADirectoryError(root)
|
||||
|
||||
regex_source = re.escape(pattern) if literal else pattern
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(regex_source, flags)
|
||||
|
||||
# Skip lines longer than this to prevent ReDoS on minified / no-newline files.
|
||||
_max_line_chars = line_summary_length * 10
|
||||
|
||||
for current_root, dirs, files in os.walk(root):
|
||||
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
|
||||
rel_dir = Path(current_root).relative_to(root)
|
||||
|
||||
for name in files:
|
||||
if should_ignore_name(name):
|
||||
continue
|
||||
|
||||
candidate_path = Path(current_root) / name
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
|
||||
if glob_pattern is not None and not path_matches(glob_pattern, rel_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
if candidate_path.is_symlink():
|
||||
continue
|
||||
file_path = candidate_path.resolve()
|
||||
if not file_path.is_relative_to(root):
|
||||
continue
|
||||
if file_path.stat().st_size > max_file_size or is_binary_file(file_path):
|
||||
continue
|
||||
with file_path.open(encoding="utf-8", errors="replace") as handle:
|
||||
for line_number, line in enumerate(handle, start=1):
|
||||
if len(line) > _max_line_chars:
|
||||
continue
|
||||
if regex.search(line):
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=str(file_path),
|
||||
line_number=line_number,
|
||||
line=truncate_line(line, line_summary_length),
|
||||
)
|
||||
)
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
return matches, truncated
|
||||
@@ -7,17 +7,21 @@ from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.sandbox.exceptions import (
|
||||
SandboxError,
|
||||
SandboxNotFoundError,
|
||||
SandboxRuntimeError,
|
||||
)
|
||||
from deerflow.sandbox.file_operation_lock import get_file_operation_lock
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])/(?:[^\s\"'`;&|<>()]+)")
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
||||
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
|
||||
_LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
|
||||
"/bin/",
|
||||
"/usr/bin/",
|
||||
@@ -29,6 +33,10 @@ _LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
|
||||
|
||||
_DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
|
||||
_ACP_WORKSPACE_VIRTUAL_PATH = "/mnt/acp-workspace"
|
||||
_DEFAULT_GLOB_MAX_RESULTS = 200
|
||||
_MAX_GLOB_MAX_RESULTS = 1000
|
||||
_DEFAULT_GREP_MAX_RESULTS = 100
|
||||
_MAX_GREP_MAX_RESULTS = 500
|
||||
|
||||
|
||||
def _get_skills_container_path() -> str:
|
||||
@@ -111,6 +119,54 @@ def _is_acp_workspace_path(path: str) -> bool:
|
||||
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
|
||||
|
||||
|
||||
def _get_custom_mounts():
|
||||
"""Get custom volume mounts from sandbox config.
|
||||
|
||||
Result is cached after the first successful config load. If config loading
|
||||
fails an empty list is returned *without* caching so that a later call can
|
||||
pick up the real value once the config is available.
|
||||
"""
|
||||
cached = getattr(_get_custom_mounts, "_cached", None)
|
||||
if cached is not None:
|
||||
return cached
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = get_app_config()
|
||||
mounts = []
|
||||
if config.sandbox and config.sandbox.mounts:
|
||||
# Only include mounts whose host_path exists, consistent with
|
||||
# LocalSandboxProvider._setup_path_mappings() which also filters
|
||||
# by host_path.exists().
|
||||
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
|
||||
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
|
||||
return mounts
|
||||
except Exception:
|
||||
# If config loading fails, return an empty list without caching so that
|
||||
# a later call can retry once the config is available.
|
||||
return []
|
||||
|
||||
|
||||
def _is_custom_mount_path(path: str) -> bool:
|
||||
"""Check if path is under a custom mount container_path."""
|
||||
for mount in _get_custom_mounts():
|
||||
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_custom_mount_for_path(path: str):
|
||||
"""Get the mount config matching this path (longest prefix first)."""
|
||||
best = None
|
||||
for mount in _get_custom_mounts():
|
||||
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
|
||||
if best is None or len(mount.container_path) > len(best.container_path):
|
||||
best = mount
|
||||
return best
|
||||
|
||||
|
||||
def _extract_thread_id_from_thread_data(thread_data: "ThreadDataState | None") -> str | None:
|
||||
"""Extract thread_id from thread_data by inspecting workspace_path.
|
||||
|
||||
@@ -243,16 +299,84 @@ def _get_mcp_allowed_paths() -> list[str]:
|
||||
return allowed_paths
|
||||
|
||||
|
||||
def _get_tool_config_int(name: str, key: str, default: int) -> int:
|
||||
try:
|
||||
tool_config = get_app_config().get_tool_config(name)
|
||||
if tool_config is not None and key in tool_config.model_extra:
|
||||
value = tool_config.model_extra.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
except Exception:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
|
||||
if value <= 0:
|
||||
return default
|
||||
return min(value, upper_bound)
|
||||
|
||||
|
||||
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
|
||||
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
|
||||
configured_max_results = _clamp_max_results(
|
||||
_get_tool_config_int(name, "max_results", default),
|
||||
default=default,
|
||||
upper_bound=upper_bound,
|
||||
)
|
||||
return min(requested_max_results, configured_max_results)
|
||||
|
||||
|
||||
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
|
||||
validate_local_tool_path(path, thread_data, read_only=True)
|
||||
if _is_skills_path(path):
|
||||
return _resolve_skills_path(path)
|
||||
if _is_acp_workspace_path(path):
|
||||
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
return _resolve_and_validate_user_data_path(path, thread_data)
|
||||
|
||||
|
||||
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
|
||||
if not matches:
|
||||
return f"No files matched under {root_path}"
|
||||
|
||||
lines = [f"Found {len(matches)} paths under {root_path}"]
|
||||
if truncated:
|
||||
lines[0] += f" (showing first {len(matches)})"
|
||||
lines.extend(f"{index}. {path}" for index, path in enumerate(matches, start=1))
|
||||
if truncated:
|
||||
lines.append("Results truncated. Narrow the path or pattern to see fewer matches.")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_grep_results(root_path: str, matches: list[GrepMatch], truncated: bool) -> str:
|
||||
if not matches:
|
||||
return f"No matches found under {root_path}"
|
||||
|
||||
lines = [f"Found {len(matches)} matches under {root_path}"]
|
||||
if truncated:
|
||||
lines[0] += f" (showing first {len(matches)})"
|
||||
lines.extend(f"{match.path}:{match.line_number}: {match.line}" for match in matches)
|
||||
if truncated:
|
||||
lines.append("Results truncated. Narrow the path or add a glob filter.")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _path_variants(path: str) -> set[str]:
|
||||
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
|
||||
|
||||
|
||||
def _path_separator_for_style(path: str) -> str:
|
||||
return "\\" if "\\" in path and "/" not in path else "/"
|
||||
|
||||
|
||||
def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
if not relative:
|
||||
return base
|
||||
if "/" in base and "\\" not in base:
|
||||
return f"{base.rstrip('/')}/{relative}"
|
||||
return str(Path(base) / relative)
|
||||
separator = _path_separator_for_style(base)
|
||||
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
|
||||
stripped_base = base.rstrip("/\\")
|
||||
return f"{stripped_base}{separator}{normalized_relative}"
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
@@ -297,7 +421,10 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
||||
return actual_base
|
||||
if path.startswith(f"{virtual_base}/"):
|
||||
rest = path[len(virtual_base) :].lstrip("/")
|
||||
return _join_path_preserving_style(actual_base, rest)
|
||||
result = _join_path_preserving_style(actual_base, rest)
|
||||
if path.endswith("/") and not result.endswith(("/", "\\")):
|
||||
result += _path_separator_for_style(actual_base)
|
||||
return result
|
||||
|
||||
return path
|
||||
|
||||
@@ -377,6 +504,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
|
||||
|
||||
result = pattern.sub(replace_acp, result)
|
||||
|
||||
# Custom mount host paths are masked by LocalSandbox._reverse_resolve_paths_in_output()
|
||||
|
||||
# Mask user-data host paths
|
||||
if thread_data is None:
|
||||
return result
|
||||
@@ -425,6 +554,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
|
||||
- ``/mnt/user-data/*`` — always allowed (read + write)
|
||||
- ``/mnt/skills/*`` — allowed only when *read_only* is True
|
||||
- ``/mnt/acp-workspace/*`` — allowed only when *read_only* is True
|
||||
- Custom mount paths (from config.yaml) — respects per-mount ``read_only`` flag
|
||||
|
||||
Args:
|
||||
path: The virtual path to validate.
|
||||
@@ -456,7 +586,14 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
|
||||
if path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
|
||||
return
|
||||
|
||||
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, or {_ACP_WORKSPACE_VIRTUAL_PATH}/ are allowed")
|
||||
# Custom mount paths — respect read_only config
|
||||
if _is_custom_mount_path(path):
|
||||
mount = _get_custom_mount_for_path(path)
|
||||
if mount and mount.read_only and not read_only:
|
||||
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
|
||||
return
|
||||
|
||||
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
|
||||
|
||||
|
||||
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
|
||||
@@ -506,15 +643,21 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
|
||||
boundary and must not be treated as isolation from the host filesystem.
|
||||
|
||||
In local mode, commands must use virtual paths under /mnt/user-data for
|
||||
user data access. Skills paths under /mnt/skills and ACP workspace paths
|
||||
under /mnt/acp-workspace are allowed (path-traversal checks only; write
|
||||
prevention for bash commands is not enforced here).
|
||||
user data access. Skills paths under /mnt/skills, ACP workspace paths
|
||||
under /mnt/acp-workspace, and custom mount container paths (configured in
|
||||
config.yaml) are allowed (path-traversal checks only; write prevention
|
||||
for bash commands is not enforced here).
|
||||
A small allowlist of common system path prefixes is kept for executable
|
||||
and device references (e.g. /bin/sh, /dev/null).
|
||||
"""
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
|
||||
# Block file:// URLs which bypass the absolute-path regex but allow local file exfiltration
|
||||
file_url_match = _FILE_URL_PATTERN.search(command)
|
||||
if file_url_match:
|
||||
raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")
|
||||
|
||||
unsafe_paths: list[str] = []
|
||||
allowed_paths = _get_mcp_allowed_paths()
|
||||
|
||||
@@ -538,6 +681,11 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
|
||||
_reject_path_traversal(absolute_path)
|
||||
continue
|
||||
|
||||
# Allow custom mount container paths
|
||||
if _is_custom_mount_path(absolute_path):
|
||||
_reject_path_traversal(absolute_path)
|
||||
continue
|
||||
|
||||
if any(absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
|
||||
continue
|
||||
|
||||
@@ -582,6 +730,8 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
|
||||
|
||||
result = acp_pattern.sub(replace_acp_match, result)
|
||||
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_paths_in_command()
|
||||
|
||||
# Replace user-data paths
|
||||
if VIRTUAL_PATH_PREFIX in result and thread_data is not None:
|
||||
pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")
|
||||
@@ -659,7 +809,8 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
return sandbox
|
||||
|
||||
|
||||
@@ -694,7 +845,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
if sandbox_id is not None:
|
||||
sandbox = get_sandbox_provider().get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
# Sandbox was released, fall through to acquire new one
|
||||
|
||||
@@ -716,7 +868,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
|
||||
|
||||
@@ -757,6 +910,59 @@ def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState]
|
||||
runtime.state["thread_directories_created"] = True
|
||||
|
||||
|
||||
def _truncate_bash_output(output: str, max_chars: int) -> str:
|
||||
"""Middle-truncate bash output, preserving head and tail (50/50 split).
|
||||
|
||||
bash output may have errors at either end (stderr/stdout ordering is
|
||||
non-deterministic), so both ends are preserved equally.
|
||||
|
||||
The returned string (including the truncation marker) is guaranteed to be
|
||||
no longer than max_chars characters. Pass max_chars=0 to disable truncation
|
||||
and return the full output unchanged.
|
||||
"""
|
||||
if max_chars == 0:
|
||||
return output
|
||||
if len(output) <= max_chars:
|
||||
return output
|
||||
total_len = len(output)
|
||||
# Compute the exact worst-case marker length: skipped chars is at most
|
||||
# total_len, so this is a tight upper bound.
|
||||
marker_max_len = len(f"\n... [middle truncated: {total_len} chars skipped] ...\n")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return output[:max_chars]
|
||||
head_len = kept // 2
|
||||
tail_len = kept - head_len
|
||||
skipped = total_len - kept
|
||||
marker = f"\n... [middle truncated: {skipped} chars skipped] ...\n"
|
||||
return f"{output[:head_len]}{marker}{output[-tail_len:] if tail_len > 0 else ''}"
|
||||
|
||||
|
||||
def _truncate_read_file_output(output: str, max_chars: int) -> str:
|
||||
"""Head-truncate read_file output, preserving the beginning of the file.
|
||||
|
||||
Source code and documents are read top-to-bottom; the head contains the
|
||||
most context (imports, class definitions, function signatures).
|
||||
|
||||
The returned string (including the truncation marker) is guaranteed to be
|
||||
no longer than max_chars characters. Pass max_chars=0 to disable truncation
|
||||
and return the full output unchanged.
|
||||
"""
|
||||
if max_chars == 0:
|
||||
return output
|
||||
if len(output) <= max_chars:
|
||||
return output
|
||||
total = len(output)
|
||||
# Compute the exact worst-case marker length: both numeric fields are at
|
||||
# their maximum (total chars), so this is a tight upper bound.
|
||||
marker_max_len = len(f"\n... [truncated: showing first {total} of {total} chars. Use start_line/end_line to read a specific range] ...")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return output[:max_chars]
|
||||
marker = f"\n... [truncated: showing first {kept} of {total} chars. Use start_line/end_line to read a specific range] ..."
|
||||
return f"{output[:kept]}{marker}"
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
@@ -781,9 +987,23 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
command = replace_virtual_paths_in_command(command, thread_data)
|
||||
command = _apply_cwd_prefix(command, thread_data)
|
||||
output = sandbox.execute_command(command)
|
||||
return mask_local_paths_in_output(output, thread_data)
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
return sandbox.execute_command(command)
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except PermissionError as e:
|
||||
@@ -811,8 +1031,9 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
path = _resolve_skills_path(path)
|
||||
elif _is_acp_workspace_path(path):
|
||||
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
else:
|
||||
elif not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
children = sandbox.list_dir(path)
|
||||
if not children:
|
||||
return "(empty)"
|
||||
@@ -827,6 +1048,126 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
include_dirs: bool = False,
|
||||
max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
|
||||
) -> str:
|
||||
"""Find files or directories that match a glob pattern under a root directory.
|
||||
|
||||
Args:
|
||||
description: Explain why you are searching for these paths in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
pattern: The glob pattern to match relative to the root path, for example `**/*.py`.
|
||||
path: The **absolute** root directory to search under.
|
||||
include_dirs: Whether matching directories should also be returned. Default is False.
|
||||
max_results: Maximum number of paths to return. Default is 200.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
effective_max_results = _resolve_max_results(
|
||||
"glob",
|
||||
max_results,
|
||||
default=_DEFAULT_GLOB_MAX_RESULTS,
|
||||
upper_bound=_MAX_GLOB_MAX_RESULTS,
|
||||
)
|
||||
thread_data = None
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
path = _resolve_local_read_path(path, thread_data)
|
||||
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
|
||||
if thread_data is not None:
|
||||
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
|
||||
return _format_glob_results(requested_path, matches, truncated)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except NotADirectoryError:
|
||||
return f"Error: Path is not a directory: {requested_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = _DEFAULT_GREP_MAX_RESULTS,
|
||||
) -> str:
|
||||
"""Search for matching lines inside text files under a root directory.
|
||||
|
||||
Args:
|
||||
description: Explain why you are searching file contents in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
pattern: The string or regex pattern to search for.
|
||||
path: The **absolute** root directory to search under.
|
||||
glob: Optional glob filter for candidate files, for example `**/*.py`.
|
||||
literal: Whether to treat `pattern` as a plain string. Default is False.
|
||||
case_sensitive: Whether matching is case-sensitive. Default is False.
|
||||
max_results: Maximum number of matching lines to return. Default is 100.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
effective_max_results = _resolve_max_results(
|
||||
"grep",
|
||||
max_results,
|
||||
default=_DEFAULT_GREP_MAX_RESULTS,
|
||||
upper_bound=_MAX_GREP_MAX_RESULTS,
|
||||
)
|
||||
thread_data = None
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
path = _resolve_local_read_path(path, thread_data)
|
||||
matches, truncated = sandbox.grep(
|
||||
path,
|
||||
pattern,
|
||||
glob=glob,
|
||||
literal=literal,
|
||||
case_sensitive=case_sensitive,
|
||||
max_results=effective_max_results,
|
||||
)
|
||||
if thread_data is not None:
|
||||
matches = [
|
||||
GrepMatch(
|
||||
path=mask_local_paths_in_output(match.path, thread_data),
|
||||
line_number=match.line_number,
|
||||
line=match.line,
|
||||
)
|
||||
for match in matches
|
||||
]
|
||||
return _format_grep_results(requested_path, matches, truncated)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except NotADirectoryError:
|
||||
return f"Error: Path is not a directory: {requested_path}"
|
||||
except re.error as e:
|
||||
return f"Error: Invalid regex pattern: {e}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
@@ -854,14 +1195,22 @@ def read_file_tool(
|
||||
path = _resolve_skills_path(path)
|
||||
elif _is_acp_workspace_path(path):
|
||||
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
else:
|
||||
elif not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "(empty)"
|
||||
if start_line is not None and end_line is not None:
|
||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||
return content
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
||||
except Exception:
|
||||
max_chars = 50000
|
||||
return _truncate_read_file_output(content, max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
@@ -896,8 +1245,11 @@ def write_file_tool(
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
sandbox.write_file(path, content, append)
|
||||
if not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
with get_file_operation_lock(sandbox, path):
|
||||
sandbox.write_file(path, content, append)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
@@ -937,17 +1289,20 @@ def str_replace_tool(
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "OK"
|
||||
if old_str not in content:
|
||||
return f"Error: String to replace not found in file: {requested_path}"
|
||||
if replace_all:
|
||||
content = content.replace(old_str, new_str)
|
||||
else:
|
||||
content = content.replace(old_str, new_str, 1)
|
||||
sandbox.write_file(path, content)
|
||||
if not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
with get_file_operation_lock(sandbox, path):
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "OK"
|
||||
if old_str not in content:
|
||||
return f"Error: String to replace not found in file: {requested_path}"
|
||||
if replace_all:
|
||||
content = content.replace(old_str, new_str)
|
||||
else:
|
||||
content = content.replace(old_str, new_str, 1)
|
||||
sandbox.write_file(path, content)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
@@ -33,15 +33,72 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
|
||||
|
||||
front_matter = front_matter_match.group(1)
|
||||
|
||||
# Parse YAML front matter (simple key-value parsing)
|
||||
# Parse YAML front matter with basic multiline string support
|
||||
metadata = {}
|
||||
for line in front_matter.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
lines = front_matter.split("\n")
|
||||
current_key = None
|
||||
current_value = []
|
||||
is_multiline = False
|
||||
multiline_style = None
|
||||
indent_level = None
|
||||
|
||||
for line in lines:
|
||||
if is_multiline:
|
||||
if not line.strip():
|
||||
current_value.append("")
|
||||
continue
|
||||
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
|
||||
if indent_level is None:
|
||||
if current_indent > 0:
|
||||
indent_level = current_indent
|
||||
current_value.append(line[indent_level:])
|
||||
continue
|
||||
elif current_indent >= indent_level:
|
||||
current_value.append(line[indent_level:])
|
||||
continue
|
||||
|
||||
# If we reach here, it's either a new key or the end of multiline
|
||||
if current_key and is_multiline:
|
||||
if multiline_style == "|":
|
||||
metadata[current_key] = "\n".join(current_value).rstrip()
|
||||
else:
|
||||
text = "\n".join(current_value).rstrip()
|
||||
# Replace single newlines with spaces for folded blocks
|
||||
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
|
||||
|
||||
current_key = None
|
||||
current_value = []
|
||||
is_multiline = False
|
||||
multiline_style = None
|
||||
indent_level = None
|
||||
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
if ":" in line:
|
||||
# Handle nested dicts simply by ignoring indentation for now,
|
||||
# or just extracting top-level keys
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip()
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
if value in (">", "|"):
|
||||
current_key = key
|
||||
is_multiline = True
|
||||
multiline_style = value
|
||||
current_value = []
|
||||
indent_level = None
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
if current_key and is_multiline:
|
||||
if multiline_style == "|":
|
||||
metadata[current_key] = "\n".join(current_value).rstrip()
|
||||
else:
|
||||
text = "\n".join(current_value).rstrip()
|
||||
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
|
||||
|
||||
# Extract required fields
|
||||
name = metadata.get("name")
|
||||
|
||||
@@ -43,5 +43,5 @@ You have access to the sandbox environment:
|
||||
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
|
||||
disallowed_tools=["task", "ask_clarification", "present_files"],
|
||||
model="inherit",
|
||||
max_turns=30,
|
||||
max_turns=60,
|
||||
)
|
||||
|
||||
@@ -44,5 +44,5 @@ You have access to the same sandbox environment as the parent agent:
|
||||
tools=None, # Inherit all tools from parent
|
||||
disallowed_tools=["task", "ask_clarification", "present_files"], # Prevent nesting and clarification
|
||||
model="inherit",
|
||||
max_turns=50,
|
||||
max_turns=100,
|
||||
)
|
||||
|
||||
@@ -28,9 +28,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
|
||||
|
||||
app_config = get_subagents_app_config()
|
||||
effective_timeout = app_config.get_timeout_for(name)
|
||||
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
|
||||
|
||||
overrides = {}
|
||||
if effective_timeout != config.timeout_seconds:
|
||||
logger.debug(f"Subagent '{name}': timeout overridden by config.yaml ({config.timeout_seconds}s -> {effective_timeout}s)")
|
||||
config = replace(config, timeout_seconds=effective_timeout)
|
||||
logger.debug(
|
||||
"Subagent '%s': timeout overridden by config.yaml (%ss -> %ss)",
|
||||
name,
|
||||
config.timeout_seconds,
|
||||
effective_timeout,
|
||||
)
|
||||
overrides["timeout_seconds"] = effective_timeout
|
||||
if effective_max_turns != config.max_turns:
|
||||
logger.debug(
|
||||
"Subagent '%s': max_turns overridden by config.yaml (%s -> %s)",
|
||||
name,
|
||||
config.max_turns,
|
||||
effective_max_turns,
|
||||
)
|
||||
overrides["max_turns"] = effective_max_turns
|
||||
if overrides:
|
||||
config = replace(config, **overrides)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -57,6 +57,42 @@ def _build_mcp_servers() -> dict[str, dict[str, Any]]:
|
||||
return build_servers_config(ExtensionsConfig.from_file())
|
||||
|
||||
|
||||
def _build_acp_mcp_servers() -> list[dict[str, Any]]:
|
||||
"""Build ACP ``mcpServers`` payload for ``new_session``.
|
||||
|
||||
The ACP client expects a list of server objects, while DeerFlow's MCP helper
|
||||
returns a name -> config mapping for the LangChain MCP adapter. This helper
|
||||
converts the enabled servers into the ACP wire format.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
enabled_servers = extensions_config.get_enabled_mcp_servers()
|
||||
|
||||
mcp_servers: list[dict[str, Any]] = []
|
||||
for name, server_config in enabled_servers.items():
|
||||
transport_type = server_config.type or "stdio"
|
||||
payload: dict[str, Any] = {"name": name, "type": transport_type}
|
||||
|
||||
if transport_type == "stdio":
|
||||
if not server_config.command:
|
||||
raise ValueError(f"MCP server '{name}' with stdio transport requires 'command' field")
|
||||
payload["command"] = server_config.command
|
||||
payload["args"] = server_config.args
|
||||
payload["env"] = [{"name": key, "value": value} for key, value in server_config.env.items()]
|
||||
elif transport_type in ("http", "sse"):
|
||||
if not server_config.url:
|
||||
raise ValueError(f"MCP server '{name}' with {transport_type} transport requires 'url' field")
|
||||
payload["url"] = server_config.url
|
||||
payload["headers"] = [{"name": key, "value": value} for key, value in server_config.headers.items()]
|
||||
else:
|
||||
raise ValueError(f"MCP server '{name}' has unsupported transport type: {transport_type}")
|
||||
|
||||
mcp_servers.append(payload)
|
||||
|
||||
return mcp_servers
|
||||
|
||||
|
||||
def _build_permission_response(options: list[Any], *, auto_approve: bool) -> Any:
|
||||
"""Build an ACP permission response.
|
||||
|
||||
@@ -173,7 +209,15 @@ def build_invoke_acp_agent_tool(agents: dict) -> BaseTool:
|
||||
cmd = agent_config.command
|
||||
args = agent_config.args or []
|
||||
physical_cwd = _get_work_dir(thread_id)
|
||||
mcp_servers = _build_mcp_servers()
|
||||
try:
|
||||
mcp_servers = _build_acp_mcp_servers()
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Invalid MCP server configuration for ACP agent '%s'; continuing without MCP servers: %s",
|
||||
agent,
|
||||
exc,
|
||||
)
|
||||
mcp_servers = []
|
||||
agent_env: dict[str, str] | None = None
|
||||
if agent_config.env:
|
||||
agent_env = {k: (os.environ.get(v[1:], "") if v.startswith("$") else v) for k, v in agent_config.env.items()}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .factory import build_tracing_callbacks
|
||||
|
||||
__all__ = ["build_tracing_callbacks"]
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config import (
|
||||
get_enabled_tracing_providers,
|
||||
get_tracing_config,
|
||||
validate_enabled_tracing_providers,
|
||||
)
|
||||
|
||||
|
||||
def _create_langsmith_tracer(config) -> Any:
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
return LangChainTracer(project_name=config.project)
|
||||
|
||||
|
||||
def _create_langfuse_handler(config) -> Any:
|
||||
from langfuse import Langfuse
|
||||
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
|
||||
|
||||
# langfuse>=4 initializes project-specific credentials through the client
|
||||
# singleton; the LangChain callback then attaches to that configured client.
|
||||
Langfuse(
|
||||
secret_key=config.secret_key,
|
||||
public_key=config.public_key,
|
||||
host=config.host,
|
||||
)
|
||||
return LangfuseCallbackHandler(public_key=config.public_key)
|
||||
|
||||
|
||||
def build_tracing_callbacks() -> list[Any]:
|
||||
"""Build callbacks for all explicitly enabled tracing providers."""
|
||||
validate_enabled_tracing_providers()
|
||||
enabled_providers = get_enabled_tracing_providers()
|
||||
if not enabled_providers:
|
||||
return []
|
||||
|
||||
tracing_config = get_tracing_config()
|
||||
callbacks: list[Any] = []
|
||||
|
||||
for provider in enabled_providers:
|
||||
if provider == "langsmith":
|
||||
try:
|
||||
callbacks.append(_create_langsmith_tracer(tracing_config.langsmith))
|
||||
except Exception as exc: # pragma: no cover - exercised via tests with monkeypatch
|
||||
raise RuntimeError(f"LangSmith tracing initialization failed: {exc}") from exc
|
||||
elif provider == "langfuse":
|
||||
try:
|
||||
callbacks.append(_create_langfuse_handler(tracing_config.langfuse))
|
||||
except Exception as exc: # pragma: no cover - exercised via tests with monkeypatch
|
||||
raise RuntimeError(f"Langfuse tracing initialization failed: {exc}") from exc
|
||||
|
||||
return callbacks
|
||||
@@ -1,10 +1,22 @@
|
||||
"""File conversion utilities.
|
||||
|
||||
Converts document files (PDF, PPT, Excel, Word) to Markdown using markitdown.
|
||||
Converts document files (PDF, PPT, Excel, Word) to Markdown.
|
||||
|
||||
PDF conversion strategy (auto mode):
|
||||
1. Try pymupdf4llm if installed — better heading detection, faster on most files.
|
||||
2. If output is suspiciously short (< _MIN_CHARS_PER_PAGE chars/page, or < 200 chars
|
||||
total when page count is unavailable), treat as image-based and fall back to MarkItDown.
|
||||
3. If pymupdf4llm is not installed, use MarkItDown directly (existing behaviour).
|
||||
|
||||
Large files (> ASYNC_THRESHOLD_BYTES) are converted in a thread pool via
|
||||
asyncio.to_thread() to avoid blocking the event loop (fixes #1569).
|
||||
|
||||
No FastAPI or HTTP dependencies — pure utility functions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,28 +32,278 @@ CONVERTIBLE_EXTENSIONS = {
|
||||
".docx",
|
||||
}
|
||||
|
||||
# Files larger than this threshold are converted in a background thread.
|
||||
# Small files complete in < 1s synchronously; spawning a thread adds unnecessary
|
||||
# scheduling overhead for them.
|
||||
_ASYNC_THRESHOLD_BYTES = 1 * 1024 * 1024 # 1 MB
|
||||
|
||||
# If pymupdf4llm produces fewer characters *per page* than this threshold,
|
||||
# the PDF is likely image-based or encrypted — fall back to MarkItDown.
|
||||
# Rationale: normal text PDFs yield 200-2000 chars/page; image-based PDFs
|
||||
# yield close to 0. 50 chars/page gives a wide safety margin.
|
||||
# Falls back to absolute 200-char check when page count is unavailable.
|
||||
_MIN_CHARS_PER_PAGE = 50
|
||||
|
||||
|
||||
def _pymupdf_output_too_sparse(text: str, file_path: Path) -> bool:
|
||||
"""Return True if pymupdf4llm output is suspiciously short (image-based PDF).
|
||||
|
||||
Uses chars-per-page rather than an absolute threshold so that both short
|
||||
documents (few pages, few chars) and long documents (many pages, many chars)
|
||||
are handled correctly.
|
||||
"""
|
||||
chars = len(text.strip())
|
||||
doc = None
|
||||
pages: int | None = None
|
||||
try:
|
||||
import pymupdf
|
||||
|
||||
doc = pymupdf.open(str(file_path))
|
||||
pages = len(doc)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if doc is not None:
|
||||
try:
|
||||
doc.close()
|
||||
except Exception:
|
||||
pass
|
||||
if pages is not None and pages > 0:
|
||||
return (chars / pages) < _MIN_CHARS_PER_PAGE
|
||||
# Fallback: absolute threshold when page count is unavailable
|
||||
return chars < 200
|
||||
|
||||
|
||||
def _convert_pdf_with_pymupdf4llm(file_path: Path) -> str | None:
|
||||
"""Attempt PDF conversion with pymupdf4llm.
|
||||
|
||||
Returns the markdown text, or None if pymupdf4llm is not installed or
|
||||
if conversion fails (e.g. encrypted/corrupt PDF).
|
||||
"""
|
||||
try:
|
||||
import pymupdf4llm
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
return pymupdf4llm.to_markdown(str(file_path))
|
||||
except Exception:
|
||||
logger.exception("pymupdf4llm failed to convert %s; falling back to MarkItDown", file_path.name)
|
||||
return None
|
||||
|
||||
|
||||
def _convert_with_markitdown(file_path: Path) -> str:
|
||||
"""Convert any supported file to markdown text using MarkItDown."""
|
||||
from markitdown import MarkItDown
|
||||
|
||||
md = MarkItDown()
|
||||
return md.convert(str(file_path)).text_content
|
||||
|
||||
|
||||
def _do_convert(file_path: Path, pdf_converter: str) -> str:
|
||||
"""Synchronous conversion — called directly or via asyncio.to_thread.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
pdf_converter: "auto" | "pymupdf4llm" | "markitdown"
|
||||
"""
|
||||
is_pdf = file_path.suffix.lower() == ".pdf"
|
||||
|
||||
if is_pdf and pdf_converter != "markitdown":
|
||||
# Try pymupdf4llm first (auto or explicit)
|
||||
pymupdf_text = _convert_pdf_with_pymupdf4llm(file_path)
|
||||
|
||||
if pymupdf_text is not None:
|
||||
# pymupdf4llm is installed
|
||||
if pdf_converter == "pymupdf4llm":
|
||||
# Explicit — use as-is regardless of output length
|
||||
return pymupdf_text
|
||||
# auto mode: fall back if output looks like a failed parse.
|
||||
# Use chars-per-page to distinguish image-based PDFs (near 0) from
|
||||
# legitimately short documents.
|
||||
if not _pymupdf_output_too_sparse(pymupdf_text, file_path):
|
||||
return pymupdf_text
|
||||
logger.warning(
|
||||
"pymupdf4llm produced only %d chars for %s (likely image-based PDF); falling back to MarkItDown",
|
||||
len(pymupdf_text.strip()),
|
||||
file_path.name,
|
||||
)
|
||||
# pymupdf4llm not installed or fallback triggered → use MarkItDown
|
||||
|
||||
return _convert_with_markitdown(file_path)
|
||||
|
||||
|
||||
async def convert_file_to_markdown(file_path: Path) -> Path | None:
|
||||
"""Convert a file to markdown using markitdown.
|
||||
"""Convert a supported document file to Markdown.
|
||||
|
||||
PDF files are handled with a two-converter strategy (see module docstring).
|
||||
Large files (> 1 MB) are offloaded to a thread pool to avoid blocking the
|
||||
event loop.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to convert.
|
||||
|
||||
Returns:
|
||||
Path to the markdown file if conversion was successful, None otherwise.
|
||||
Path to the generated .md file, or None if conversion failed.
|
||||
"""
|
||||
try:
|
||||
from markitdown import MarkItDown
|
||||
pdf_converter = _get_pdf_converter()
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
md = MarkItDown()
|
||||
result = md.convert(str(file_path))
|
||||
if file_size > _ASYNC_THRESHOLD_BYTES:
|
||||
text = await asyncio.to_thread(_do_convert, file_path, pdf_converter)
|
||||
else:
|
||||
text = _do_convert(file_path, pdf_converter)
|
||||
|
||||
# Save as .md file with same name
|
||||
md_path = file_path.with_suffix(".md")
|
||||
md_path.write_text(result.text_content, encoding="utf-8")
|
||||
md_path.write_text(text, encoding="utf-8")
|
||||
|
||||
logger.info(f"Converted {file_path.name} to markdown: {md_path.name}")
|
||||
logger.info("Converted %s to markdown: %s (%d chars)", file_path.name, md_path.name, len(text))
|
||||
return md_path
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert {file_path.name} to markdown: {e}")
|
||||
logger.error("Failed to convert %s to markdown: %s", file_path.name, e)
|
||||
return None
|
||||
|
||||
|
||||
# Regex for bold-only lines that look like section headings.
|
||||
# Targets SEC filing structural headings that pymupdf4llm renders as **bold**
|
||||
# rather than # Markdown headings (because they use same font size as body text,
|
||||
# distinguished only by bold+caps formatting).
|
||||
#
|
||||
# Pattern requires ALL of:
|
||||
# 1. Entire line is a single **...** block (no surrounding prose)
|
||||
# 2. Starts with a recognised structural keyword:
|
||||
# - ITEM / PART / SECTION (with optional number/letter after)
|
||||
# - SCHEDULE, EXHIBIT, APPENDIX, ANNEX, CHAPTER
|
||||
# All-caps addresses, boilerplate ("CURRENT REPORT", "SIGNATURES",
|
||||
# "WASHINGTON, DC 20549") do NOT start with these keywords and are excluded.
|
||||
#
|
||||
# Chinese headings (第三节...) are already captured as standard # headings
|
||||
# by pymupdf4llm, so they don't need this pattern.
|
||||
_BOLD_HEADING_RE = re.compile(r"^\*\*((ITEM|PART|SECTION|SCHEDULE|EXHIBIT|APPENDIX|ANNEX|CHAPTER)\b[A-Z0-9 .,\-]*)\*\*\s*$")
|
||||
|
||||
# Regex for split-bold headings produced by pymupdf4llm when a heading spans
|
||||
# multiple text spans in the PDF (e.g. section number and title are separate spans).
|
||||
# Matches lines like: **1** **Introduction** or **3.2** **Multi-Head Attention**
|
||||
# Requirements:
|
||||
# 1. Entire line consists only of **...** blocks separated by whitespace (no prose)
|
||||
# 2. First block is a section number (digits and dots, e.g. "1", "3.2", "A.1")
|
||||
# 3. Second block must not be purely numeric/punctuation — excludes financial table
|
||||
# headers like **2023** **2022** **2021** while allowing non-ASCII titles such as
|
||||
# **1** **概述** or accented words (negative lookahead instead of [A-Za-z])
|
||||
# 4. At most two additional blocks (four total) with [^*]+ (no * inside) to keep
|
||||
# the regex linear and avoid ReDoS on attacker-controlled content
|
||||
_SPLIT_BOLD_HEADING_RE = re.compile(r"^\*\*[\dA-Z][\d\.]*\*\*\s+\*\*(?!\d[\d\s.,\-–—/:()%]*\*\*)[^*]+\*\*(?:\s+\*\*[^*]+\*\*){0,2}\s*$")
|
||||
|
||||
# Maximum number of outline entries injected into the agent context.
|
||||
# Keeps prompt size bounded even for very long documents.
|
||||
MAX_OUTLINE_ENTRIES = 50
|
||||
|
||||
_ALLOWED_PDF_CONVERTERS = {"auto", "pymupdf4llm", "markitdown"}
|
||||
|
||||
|
||||
def _clean_bold_title(raw: str) -> str:
|
||||
"""Normalise a title string that may contain pymupdf4llm bold artefacts.
|
||||
|
||||
pymupdf4llm sometimes emits adjacent bold spans as ``**A** **B**`` instead
|
||||
of a single ``**A B**`` block. This helper merges those fragments and then
|
||||
strips the outermost ``**...**`` wrapper so the caller gets plain text.
|
||||
|
||||
Examples::
|
||||
|
||||
"**Overview**" → "Overview"
|
||||
"**UNITED STATES** **SECURITIES**" → "UNITED STATES SECURITIES"
|
||||
"plain text" → "plain text" (unchanged)
|
||||
"""
|
||||
# Merge adjacent bold spans: "** **" → " "
|
||||
merged = re.sub(r"\*\*\s*\*\*", " ", raw).strip()
|
||||
# Strip outermost **...** if the whole string is wrapped
|
||||
if m := re.fullmatch(r"\*\*(.+?)\*\*", merged, re.DOTALL):
|
||||
return m.group(1).strip()
|
||||
return merged
|
||||
|
||||
|
||||
def extract_outline(md_path: Path) -> list[dict]:
|
||||
"""Extract document outline (headings) from a Markdown file.
|
||||
|
||||
Recognises three heading styles produced by pymupdf4llm:
|
||||
|
||||
1. Standard Markdown headings: lines starting with one or more '#'.
|
||||
Inline ``**...**`` wrappers and adjacent bold spans (``** **``) are
|
||||
cleaned so the title is plain text.
|
||||
|
||||
2. Bold-only structural headings: ``**ITEM 1. BUSINESS**``, ``**PART II**``,
|
||||
etc. SEC filings use bold+caps for section headings with the same font
|
||||
size as body text, so pymupdf4llm cannot promote them to # headings.
|
||||
|
||||
3. Split-bold headings: ``**1** **Introduction**``, ``**3.2** **Attention**``.
|
||||
pymupdf4llm emits these when the section number and title text are
|
||||
separate spans in the underlying PDF (common in academic papers).
|
||||
|
||||
Args:
|
||||
md_path: Path to the .md file.
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: title (str), line (int, 1-based).
|
||||
When the outline is truncated at MAX_OUTLINE_ENTRIES, a sentinel entry
|
||||
``{"truncated": True}`` is appended as the last element so callers can
|
||||
render a "showing first N headings" hint without re-scanning the file.
|
||||
Returns an empty list if the file cannot be read or has no headings.
|
||||
"""
|
||||
outline: list[dict] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for lineno, line in enumerate(f, 1):
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
|
||||
# Style 1: standard Markdown heading
|
||||
if stripped.startswith("#"):
|
||||
title = _clean_bold_title(stripped.lstrip("#").strip())
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
# Style 2: single bold block with SEC structural keyword
|
||||
elif m := _BOLD_HEADING_RE.match(stripped):
|
||||
title = m.group(1).strip()
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
# Style 3: split-bold heading — **<num>** **<title>**
|
||||
# Regex already enforces max 4 blocks and non-numeric second block.
|
||||
elif _SPLIT_BOLD_HEADING_RE.match(stripped):
|
||||
title = " ".join(re.findall(r"\*\*([^*]+)\*\*", stripped))
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
if len(outline) >= MAX_OUTLINE_ENTRIES:
|
||||
outline.append({"truncated": True})
|
||||
break
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return outline
|
||||
|
||||
|
||||
def _get_pdf_converter() -> str:
|
||||
"""Read pdf_converter setting from app config, defaulting to 'auto'.
|
||||
|
||||
Normalizes the value to lowercase and validates it against the allowed set
|
||||
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
|
||||
fall through to unexpected behaviour.
|
||||
"""
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
cfg = get_app_config()
|
||||
uploads_cfg = getattr(cfg, "uploads", None)
|
||||
if uploads_cfg is not None:
|
||||
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
|
||||
if raw not in _ALLOWED_PDF_CONVERTERS:
|
||||
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
|
||||
return "auto"
|
||||
return raw
|
||||
except Exception:
|
||||
pass
|
||||
return "auto"
|
||||
|
||||
@@ -9,15 +9,17 @@ dependencies = [
|
||||
"dotenv>=0.9.9",
|
||||
"httpx>=0.28.0",
|
||||
"kubernetes>=30.0.0",
|
||||
"langchain>=1.2.3",
|
||||
"langchain>=1.2.3,<1.2.10",
|
||||
"langchain-anthropic>=1.3.4",
|
||||
"langchain-deepseek>=1.0.1",
|
||||
"langchain-mcp-adapters>=0.1.0",
|
||||
"langchain-openai>=1.1.7",
|
||||
"langfuse>=3.4.1",
|
||||
"langgraph>=1.0.6,<1.0.10",
|
||||
"langgraph-prebuilt>=1.0.6,<1.0.9",
|
||||
"langgraph-api>=0.7.0,<0.8.0",
|
||||
"langgraph-cli>=0.4.14",
|
||||
"langgraph-runtime-inmem>=0.22.1",
|
||||
"langgraph-runtime-inmem>=0.22.1,<0.27.0",
|
||||
"markdownify>=1.2.2",
|
||||
"markitdown[all,xlsx]>=0.0.1a2",
|
||||
"pydantic>=2.12.5",
|
||||
@@ -33,6 +35,9 @@ dependencies = [
|
||||
"langgraph-sdk>=0.1.51",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
pymupdf = ["pymupdf4llm>=0.0.17"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
@@ -16,6 +16,10 @@ dependencies = [
|
||||
"python-telegram-bot>=21.0",
|
||||
"langgraph-sdk>=0.1.51",
|
||||
"markdown-to-mrkdwn>=0.3.1",
|
||||
"wecom-aibot-python-sdk>=0.1.6",
|
||||
"bcrypt>=4.0.0",
|
||||
"pyjwt>=2.9.0",
|
||||
"email-validator>=2.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -131,3 +131,53 @@ class TestListDirSerialization:
|
||||
result = sandbox.list_dir("/test")
|
||||
assert result == ["/a", "/b"]
|
||||
assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
|
||||
|
||||
|
||||
class TestConcurrentFileWrites:
|
||||
"""Verify file write paths do not lose concurrent updates."""
|
||||
|
||||
def test_append_should_preserve_both_parallel_writes(self, sandbox):
|
||||
storage = {"content": "seed\n"}
|
||||
active_reads = 0
|
||||
state_lock = threading.Lock()
|
||||
overlap_detected = threading.Event()
|
||||
|
||||
def overlapping_read_file(path):
|
||||
nonlocal active_reads
|
||||
with state_lock:
|
||||
active_reads += 1
|
||||
snapshot = storage["content"]
|
||||
if active_reads == 2:
|
||||
overlap_detected.set()
|
||||
|
||||
overlap_detected.wait(0.05)
|
||||
|
||||
with state_lock:
|
||||
active_reads -= 1
|
||||
|
||||
return snapshot
|
||||
|
||||
def write_back(*, file, content, **kwargs):
|
||||
storage["content"] = content
|
||||
return SimpleNamespace(data=SimpleNamespace())
|
||||
|
||||
sandbox.read_file = overlapping_read_file
|
||||
sandbox._client.file.write_file = write_back
|
||||
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def writer(payload: str):
|
||||
barrier.wait()
|
||||
sandbox.write_file("/tmp/shared.log", payload, append=True)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=writer, args=("A\n",)),
|
||||
threading.Thread(target=writer, args=("B\n",)),
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
|
||||
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
get_auth_context,
|
||||
require_auth,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hash_password_and_verify():
|
||||
"""Hashing and verification round-trip."""
|
||||
password = "s3cr3tP@ssw0rd!"
|
||||
hashed = hash_password(password)
|
||||
assert hashed != password
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("wrongpassword", hashed) is False
|
||||
|
||||
|
||||
def test_hash_password_different_each_time():
|
||||
"""bcrypt generates unique salts, so same password has different hashes."""
|
||||
password = "testpassword"
|
||||
h1 = hash_password(password)
|
||||
h2 = hash_password(password)
|
||||
assert h1 != h2 # Different salts
|
||||
# But both verify correctly
|
||||
assert verify_password(password, h1) is True
|
||||
assert verify_password(password, h2) is True
|
||||
|
||||
|
||||
def test_verify_password_rejects_empty():
|
||||
"""Empty password should not verify."""
|
||||
hashed = hash_password("nonempty")
|
||||
assert verify_password("", hashed) is False
|
||||
|
||||
|
||||
# ── JWT ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
"""JWT creation and decoding round-trip."""
|
||||
user_id = str(uuid4())
|
||||
# Set a valid JWT secret for this test
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(user_id)
|
||||
assert isinstance(token, str)
|
||||
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
def test_decode_token_expired():
|
||||
"""Expired token returns TokenError.EXPIRED."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
user_id = str(uuid4())
|
||||
# Create token that expires immediately
|
||||
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
|
||||
payload = decode_token(token)
|
||||
assert payload == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Invalid token returns TokenError."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
assert isinstance(decode_token("not.a.valid.token"), TokenError)
|
||||
assert isinstance(decode_token(""), TokenError)
|
||||
assert isinstance(decode_token("completely-wrong"), TokenError)
|
||||
|
||||
|
||||
def test_create_token_custom_expiry():
|
||||
"""Custom expiry is respected."""
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
# ── AuthContext ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_context_unauthenticated():
|
||||
"""AuthContext with no user."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
assert ctx.is_authenticated is False
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_authenticated_no_perms():
|
||||
"""AuthContext with user but no permissions."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
assert ctx.is_authenticated is True
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_has_permission():
|
||||
"""AuthContext permission checking."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
|
||||
ctx = AuthContext(user=user, permissions=perms)
|
||||
assert ctx.has_permission("threads", "read") is True
|
||||
assert ctx.has_permission("threads", "write") is True
|
||||
assert ctx.has_permission("threads", "delete") is False
|
||||
assert ctx.has_permission("runs", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_require_user_raises():
|
||||
"""require_user raises 401 when not authenticated."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
ctx.require_user()
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_auth_context_require_user_returns_user():
|
||||
"""require_user returns user when authenticated."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
returned = ctx.require_user()
|
||||
assert returned == user
|
||||
|
||||
|
||||
# ── get_auth_context helper ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_auth_context_not_set():
|
||||
"""get_auth_context returns None when auth not set on request."""
|
||||
mock_request = MagicMock()
|
||||
# Make getattr return None (simulating attribute not set)
|
||||
mock_request.state = MagicMock()
|
||||
del mock_request.state.auth
|
||||
assert get_auth_context(mock_request) is None
|
||||
|
||||
|
||||
def test_get_auth_context_set():
|
||||
"""get_auth_context returns the AuthContext from request."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.auth = ctx
|
||||
|
||||
assert get_auth_context(mock_request) == ctx
|
||||
|
||||
|
||||
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_auth_sets_auth_context():
|
||||
"""require_auth sets auth context on request from cookie."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_auth
|
||||
async def endpoint(request: Request):
|
||||
ctx = get_auth_context(request)
|
||||
return {"authenticated": ctx.is_authenticated}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# No cookie → anonymous
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["authenticated"] is False
|
||||
|
||||
|
||||
def test_require_auth_requires_request_param():
|
||||
"""require_auth raises ValueError if request parameter is missing."""
|
||||
import asyncio
|
||||
|
||||
@require_auth
|
||||
async def bad_endpoint(): # Missing `request` parameter
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
|
||||
asyncio.run(bad_endpoint())
|
||||
|
||||
|
||||
# ── require_permission decorator ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_permission_requires_auth():
|
||||
"""require_permission raises 401 when not authenticated."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "read")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_require_permission_denies_wrong_permission():
|
||||
"""User without required permission gets 403."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "delete")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 403
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── User Model Fields ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_model_has_needs_setup_default_false():
|
||||
"""New users default to needs_setup=False."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.needs_setup is False
|
||||
|
||||
|
||||
def test_user_model_has_token_version_default_zero():
|
||||
"""New users default to token_version=0."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.token_version == 0
|
||||
|
||||
|
||||
def test_user_model_needs_setup_true():
|
||||
"""Auto-created admin has needs_setup=True."""
|
||||
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
|
||||
assert user.needs_setup is True
|
||||
|
||||
|
||||
def test_sqlite_round_trip_new_fields():
|
||||
"""needs_setup and token_version survive create → read round-trip."""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.gateway.auth.repositories import sqlite as sqlite_mod
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_users.db")
|
||||
old_path = sqlite_mod._resolved_db_path
|
||||
old_init = sqlite_mod._table_initialized
|
||||
sqlite_mod._resolved_db_path = Path(db_path)
|
||||
sqlite_mod._table_initialized = False
|
||||
try:
|
||||
repo = sqlite_mod.SQLiteUserRepository()
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = asyncio.run(repo.create_user(user))
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
|
||||
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
asyncio.run(repo.update_user(fetched))
|
||||
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
sqlite_mod._resolved_db_path = old_path
|
||||
sqlite_mod._table_initialized = old_init
|
||||
|
||||
|
||||
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_jwt_encodes_ver():
|
||||
"""JWT payload includes ver field."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()), token_version=3)
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 3
|
||||
|
||||
|
||||
def test_jwt_default_ver_zero():
|
||||
"""JWT ver defaults to 0."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()))
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 0
|
||||
|
||||
|
||||
def test_token_version_mismatch_rejects():
|
||||
"""Token with stale ver is rejected by get_current_user_from_request."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, token_version=0)
|
||||
|
||||
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {"access_token": token}
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_user = AsyncMock(return_value=mock_user)
|
||||
mock_provider_fn.return_value = mock_provider
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
# ── change-password extension ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_change_password_request_accepts_new_email():
|
||||
"""ChangePasswordRequest model accepts optional new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(
|
||||
current_password="old",
|
||||
new_password="newpassword",
|
||||
new_email="new@example.com",
|
||||
)
|
||||
assert req.new_email == "new@example.com"
|
||||
|
||||
|
||||
def test_change_password_request_new_email_optional():
|
||||
"""ChangePasswordRequest model works without new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
|
||||
assert req.new_email is None
|
||||
|
||||
|
||||
def test_login_response_includes_needs_setup():
|
||||
"""LoginResponse includes needs_setup field."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=3600, needs_setup=True)
|
||||
assert resp.needs_setup is True
|
||||
resp2 = LoginResponse(expires_in=3600)
|
||||
assert resp2.needs_setup is False
|
||||
|
||||
|
||||
# ── Rate Limiting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_rate_limiter_allows_under_limit():
|
||||
"""Requests under the limit are allowed."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
|
||||
|
||||
_login_attempts.clear()
|
||||
_check_rate_limit("192.168.1.1") # Should not raise
|
||||
|
||||
|
||||
def test_rate_limiter_blocks_after_max_failures():
|
||||
"""IP is blocked after 5 consecutive failures."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.1"
|
||||
for _ in range(5):
|
||||
_record_login_failure(ip)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_check_rate_limit(ip)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
def test_rate_limiter_resets_on_success():
|
||||
"""Successful login clears the failure counter."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.2"
|
||||
for _ in range(4):
|
||||
_record_login_failure(ip)
|
||||
_record_login_success(ip)
|
||||
_check_rate_limit(ip) # Should not raise
|
||||
|
||||
|
||||
# ── Client IP extraction ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_client_ip_direct_connection():
|
||||
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.42"
|
||||
req.headers = {}
|
||||
assert _get_client_ip(req) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_uses_x_real_ip():
|
||||
"""X-Real-IP (set by nginx) is used when present."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
|
||||
req.headers = {"x-real-ip": "203.0.113.42"}
|
||||
assert _get_client_ip(req) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_xff_ignored():
|
||||
"""X-Forwarded-For is never used; only X-Real-IP matters."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1"
|
||||
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
|
||||
assert _get_client_ip(req) == "198.51.100.5"
|
||||
|
||||
|
||||
def test_get_client_ip_no_real_ip_fallback():
|
||||
"""No X-Real-IP → falls back to client.host (direct connection)."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "127.0.0.1"
|
||||
req.headers = {}
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
|
||||
def test_get_client_ip_x_real_ip_always_preferred():
|
||||
"""X-Real-IP is always preferred over client.host regardless of IP."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.99"
|
||||
req.headers = {"x-real-ip": "198.51.100.7"}
|
||||
assert _get_client_ip(req) == "198.51.100.7"
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
|
||||
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as config_module
|
||||
|
||||
config_module._auth_config = None
|
||||
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = config_module.get_auth_config()
|
||||
|
||||
assert config.jwt_secret # non-empty ephemeral secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
|
||||
# Cleanup
|
||||
config_module._auth_config = None
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests for AuthConfig typed configuration."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.gateway.auth.config import AuthConfig
|
||||
|
||||
|
||||
def test_auth_config_defaults():
|
||||
config = AuthConfig(jwt_secret="test-secret-key-123")
|
||||
assert config.token_expiry_days == 7
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_range():
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=1)
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=30)
|
||||
with pytest.raises(Exception):
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=0)
|
||||
with pytest.raises(Exception):
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_from_env():
|
||||
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret == "test-jwt-secret-from-env"
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_missing_secret_generates_ephemeral(caplog):
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Tests for auth error types and typed decode_token."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
|
||||
|
||||
def test_auth_error_code_values():
|
||||
assert AuthErrorCode.INVALID_CREDENTIALS == "invalid_credentials"
|
||||
assert AuthErrorCode.TOKEN_EXPIRED == "token_expired"
|
||||
assert AuthErrorCode.NOT_AUTHENTICATED == "not_authenticated"
|
||||
|
||||
|
||||
def test_token_error_values():
|
||||
assert TokenError.EXPIRED == "expired"
|
||||
assert TokenError.INVALID_SIGNATURE == "invalid_signature"
|
||||
assert TokenError.MALFORMED == "malformed"
|
||||
|
||||
|
||||
def test_auth_error_response_serialization():
|
||||
err = AuthErrorResponse(
|
||||
code=AuthErrorCode.TOKEN_EXPIRED,
|
||||
message="Token has expired",
|
||||
)
|
||||
d = err.model_dump()
|
||||
assert d == {"code": "token_expired", "message": "Token has expired"}
|
||||
|
||||
|
||||
def test_auth_error_response_from_dict():
|
||||
d = {"code": "invalid_credentials", "message": "Wrong password"}
|
||||
err = AuthErrorResponse(**d)
|
||||
assert err.code == AuthErrorCode.INVALID_CREDENTIALS
|
||||
|
||||
|
||||
# ── decode_token typed failure tests ──────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"
|
||||
|
||||
|
||||
def _setup_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_expired():
|
||||
_setup_config()
|
||||
expired_payload = {"sub": "user-1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired_payload, _TEST_SECRET, algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_bad_signature():
|
||||
_setup_config()
|
||||
payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_malformed():
|
||||
_setup_config()
|
||||
result = decode_token("not-a-jwt")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_returns_payload_on_valid():
|
||||
_setup_config()
|
||||
token = create_access_token("user-123")
|
||||
result = decode_token(token)
|
||||
assert not isinstance(result, TokenError)
|
||||
assert result.sub == "user-123"
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||
|
||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/health",
|
||||
"/health/",
|
||||
"/docs",
|
||||
"/docs/",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
],
|
||||
)
|
||||
def test_public_paths(path: str):
|
||||
assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/models",
|
||||
"/api/mcp/config",
|
||||
"/api/memory",
|
||||
"/api/skills",
|
||||
"/api/threads/123",
|
||||
"/api/threads/123/uploads",
|
||||
"/api/agents",
|
||||
"/api/channels",
|
||||
"/api/runs/stream",
|
||||
"/api/threads/123/runs",
|
||||
"/api/v1/auth/me",
|
||||
"/api/v1/auth/change-password",
|
||||
],
|
||||
)
|
||||
def test_protected_paths(path: str):
|
||||
assert _is_public(path) is False
|
||||
|
||||
|
||||
# ── Trailing slash / normalization edge cases ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/v1/auth/login/local/",
|
||||
"/api/v1/auth/register/",
|
||||
"/api/v1/auth/logout/",
|
||||
"/api/v1/auth/setup-status/",
|
||||
],
|
||||
)
|
||||
def test_public_auth_paths_with_trailing_slash(path: str):
|
||||
assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/models/",
|
||||
"/api/v1/auth/me/",
|
||||
"/api/v1/auth/change-password/",
|
||||
],
|
||||
)
|
||||
def test_protected_paths_with_trailing_slash(path: str):
|
||||
assert _is_public(path) is False
|
||||
|
||||
|
||||
def test_unknown_api_path_is_protected():
|
||||
"""Fail-closed: any new /api/* path is protected by default."""
|
||||
assert _is_public("/api/new-feature") is False
|
||||
assert _is_public("/api/v2/something") is False
|
||||
assert _is_public("/api/v1/auth/new-endpoint") is False
|
||||
|
||||
|
||||
# ── Middleware integration tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_app():
|
||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/v1/auth/me")
|
||||
async def auth_me():
|
||||
return {"id": "1", "email": "test@test.com"}
|
||||
|
||||
@app.get("/api/v1/auth/setup-status")
|
||||
async def setup_status():
|
||||
return {"needs_setup": False}
|
||||
|
||||
@app.get("/api/models")
|
||||
async def models_get():
|
||||
return {"models": []}
|
||||
|
||||
@app.put("/api/mcp/config")
|
||||
async def mcp_put():
|
||||
return {"ok": True}
|
||||
|
||||
@app.delete("/api/threads/abc")
|
||||
async def thread_delete():
|
||||
return {"ok": True}
|
||||
|
||||
@app.patch("/api/threads/abc")
|
||||
async def thread_patch():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def stream():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/future-endpoint")
|
||||
async def future():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(_make_app())
|
||||
|
||||
|
||||
def test_public_path_no_cookie(client):
|
||||
res = client.get("/health")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_public_auth_path_no_cookie(client):
|
||||
"""Public auth endpoints (login/register) pass without cookie."""
|
||||
res = client.get("/api/v1/auth/setup-status")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_auth_path_no_cookie(client):
|
||||
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
|
||||
res = client.get("/api/v1/auth/me")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_path_no_cookie_returns_401(client):
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
body = res.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_protected_path_with_cookie_passes(client):
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_post_no_cookie_returns_401(client):
|
||||
res = client.post("/api/threads/abc/runs/stream")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
def test_protected_put_no_cookie(client):
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_delete_no_cookie(client):
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_patch_no_cookie(client):
|
||||
res = client.patch("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_put_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_delete_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_endpoint_no_cookie_returns_401(client):
|
||||
"""Any new /api/* endpoint is blocked by default without cookie."""
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_unknown_endpoint_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 200
|
||||
@@ -0,0 +1,675 @@
|
||||
"""Tests for auth type system hardening.
|
||||
|
||||
Covers structured error responses, typed decode_token callers,
|
||||
CSRF middleware path matching, config-driven cookie security,
|
||||
and unhappy paths / edge cases for all auth boundaries.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.csrf_middleware import (
|
||||
CSRF_COOKIE_NAME,
|
||||
CSRF_HEADER_NAME,
|
||||
CSRFMiddleware,
|
||||
is_auth_endpoint,
|
||||
should_check_csrf,
|
||||
)
|
||||
|
||||
# ── Setup ────────────────────────────────────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||
|
||||
|
||||
def _setup_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
# ── CSRF Middleware Path Matching ────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
"""Minimal request mock for CSRF path matching tests."""
|
||||
|
||||
def __init__(self, path: str, method: str = "POST"):
|
||||
self.method = method
|
||||
|
||||
class _URL:
|
||||
def __init__(self, p):
|
||||
self.path = p
|
||||
|
||||
self.url = _URL(path)
|
||||
self.cookies = {}
|
||||
self.headers = {}
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local():
|
||||
"""login/local (actual route) should be exempt from CSRF."""
|
||||
req = _FakeRequest("/api/v1/auth/login/local")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local_trailing_slash():
|
||||
"""Trailing slash should also be exempt."""
|
||||
req = _FakeRequest("/api/v1/auth/login/local/")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_logout():
|
||||
req = _FakeRequest("/api/v1/auth/logout")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_register():
|
||||
req = _FakeRequest("/api/v1/auth/register")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_old_login_path():
|
||||
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
|
||||
req = _FakeRequest("/api/v1/auth/login")
|
||||
assert is_auth_endpoint(req) is False
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_me():
|
||||
req = _FakeRequest("/api/v1/auth/me")
|
||||
assert is_auth_endpoint(req) is False
|
||||
|
||||
|
||||
def test_csrf_skips_get_requests():
|
||||
req = _FakeRequest("/api/v1/auth/me", method="GET")
|
||||
assert should_check_csrf(req) is False
|
||||
|
||||
|
||||
def test_csrf_checks_post_to_protected():
|
||||
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
|
||||
assert should_check_csrf(req) is True
|
||||
|
||||
|
||||
# ── Structured Error Response Format ────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_error_response_has_code_and_message():
|
||||
"""All auth errors should have structured {code, message} format."""
|
||||
err = AuthErrorResponse(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Wrong password",
|
||||
)
|
||||
d = err.model_dump()
|
||||
assert "code" in d
|
||||
assert "message" in d
|
||||
assert d["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_auth_error_response_all_codes_serializable():
|
||||
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
|
||||
for code in AuthErrorCode:
|
||||
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
|
||||
d = err.model_dump()
|
||||
assert d["code"] == code.value
|
||||
|
||||
|
||||
# ── decode_token Caller Pattern ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_expired_maps_to_token_expired_code():
|
||||
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
|
||||
_setup_config()
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.EXPIRED
|
||||
|
||||
# Verify the mapping pattern used in route handlers
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
|
||||
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
|
||||
_setup_config()
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
def test_decode_token_malformed_maps_to_token_invalid_code():
|
||||
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
|
||||
_setup_config()
|
||||
result = decode_token("garbage")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
# ── Login Response Format ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_login_response_model_has_no_access_token():
|
||||
"""LoginResponse should NOT contain access_token field (RFC-001)."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=604800)
|
||||
d = resp.model_dump()
|
||||
assert "access_token" not in d
|
||||
assert "expires_in" in d
|
||||
assert d["expires_in"] == 604800
|
||||
|
||||
|
||||
def test_login_response_model_fields():
|
||||
"""LoginResponse has expires_in and needs_setup."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
fields = set(LoginResponse.model_fields.keys())
|
||||
assert fields == {"expires_in", "needs_setup"}
|
||||
|
||||
|
||||
# ── AuthConfig in Route ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_used_in_login_response():
|
||||
"""LoginResponse.expires_in should come from config.token_expiry_days."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
expected_seconds = 14 * 24 * 3600
|
||||
resp = LoginResponse(expires_in=expected_seconds)
|
||||
assert resp.expires_in == expected_seconds
|
||||
|
||||
|
||||
# ── UserResponse Type Preservation ───────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_system_role_literal():
|
||||
"""UserResponse.system_role should only accept 'admin' or 'user'."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
# Valid roles
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
|
||||
assert resp.system_role == "admin"
|
||||
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="user")
|
||||
assert resp.system_role == "user"
|
||||
|
||||
|
||||
def test_user_response_rejects_invalid_role():
|
||||
"""UserResponse should reject invalid system_role values."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="superadmin")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# UNHAPPY PATHS / EDGE CASES
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
# ── get_current_user structured 401 responses ────────────────────────
|
||||
|
||||
|
||||
def test_get_current_user_no_cookie_returns_not_authenticated():
|
||||
"""No cookie → 401 with code=not_authenticated."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_get_current_user_expired_token_returns_token_expired():
|
||||
"""Expired token → 401 with code=token_expired."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token_returns_token_invalid():
|
||||
"""Bad signature → 401 with code=token_invalid."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_get_current_user_malformed_token_returns_token_invalid():
|
||||
"""Garbage token → 401 with code=token_invalid."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
# ── decode_token edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_empty_string_returns_malformed():
|
||||
_setup_config()
|
||||
result = decode_token("")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_whitespace_returns_malformed():
|
||||
_setup_config()
|
||||
result = decode_token(" ")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
# ── AuthConfig validation edge cases ─────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_missing_jwt_secret_raises():
|
||||
"""AuthConfig requires jwt_secret — no default allowed."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig()
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_zero_raises():
|
||||
"""token_expiry_days must be >= 1."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig(jwt_secret="secret", token_expiry_days=0)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_31_raises():
|
||||
"""token_expiry_days must be <= 30."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig(jwt_secret="secret", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_1_ok():
|
||||
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
|
||||
assert config.token_expiry_days == 1
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_30_ok():
|
||||
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
|
||||
assert config.token_expiry_days == 30
|
||||
|
||||
|
||||
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
|
||||
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
# ── CSRF middleware integration (unhappy paths) ──────────────────────
|
||||
|
||||
|
||||
def _make_csrf_app():
|
||||
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
|
||||
from fastapi import HTTPException as _HTTPException
|
||||
from fastapi.responses import JSONResponse as _JSONResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.exception_handler(_HTTPException)
|
||||
async def _http_exc_handler(request, exc):
|
||||
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/api/v1/test/protected")
|
||||
async def protected():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/v1/auth/login/local")
|
||||
async def login():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/v1/test/read")
|
||||
async def read_endpoint():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_without_token():
|
||||
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/test/protected")
|
||||
assert resp.status_code == 403
|
||||
assert "CSRF" in resp.json()["detail"]
|
||||
assert "missing" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_with_mismatched_token():
|
||||
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
|
||||
resp = client.post(
|
||||
"/api/v1/test/protected",
|
||||
headers={CSRF_HEADER_NAME: "token-b"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "mismatch" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_post_with_matching_token():
|
||||
"""POST with matching CSRF cookie/header → 200."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
token = secrets.token_urlsafe(64)
|
||||
client.cookies.set(CSRF_COOKIE_NAME, token)
|
||||
resp = client.post(
|
||||
"/api/v1/test/protected",
|
||||
headers={CSRF_HEADER_NAME: token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_get_without_token():
|
||||
"""GET requests bypass CSRF check."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.get("/api/v1/test/read")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_exempts_login_local():
|
||||
"""POST to login/local is exempt from CSRF (no token yet)."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/auth/login/local")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
|
||||
"""Auth endpoints should receive a CSRF cookie in response."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/auth/login/local")
|
||||
assert CSRF_COOKIE_NAME in resp.cookies
|
||||
|
||||
|
||||
# ── UserResponse edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_missing_required_fields():
|
||||
"""UserResponse with missing fields → ValidationError."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1") # missing email, system_role
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com") # missing system_role
|
||||
|
||||
|
||||
def test_user_response_empty_string_role_rejected():
|
||||
"""Empty string is not a valid role."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# HTTP-LEVEL API CONTRACT TESTS
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_auth_app():
|
||||
"""Create FastAPI app with auth routes for contract testing."""
|
||||
from app.gateway.app import create_app
|
||||
|
||||
return create_app()
|
||||
|
||||
|
||||
def _get_auth_client():
|
||||
"""Get TestClient for auth API contract tests."""
|
||||
return TestClient(_make_auth_app())
|
||||
|
||||
|
||||
def test_api_auth_me_no_cookie_returns_structured_401():
|
||||
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
assert "message" in body["detail"]
|
||||
|
||||
|
||||
def test_api_auth_me_expired_token_returns_structured_401():
|
||||
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_api_auth_me_invalid_sig_returns_structured_401():
|
||||
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_api_login_bad_credentials_returns_structured_401():
|
||||
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_api_login_success_no_token_in_body():
|
||||
"""Successful login → response body has expires_in but NOT access_token."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
# Register first
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
# Login
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "expires_in" in body
|
||||
assert "access_token" not in body
|
||||
# Token should be in cookie, not body
|
||||
assert "access_token" in resp.cookies
|
||||
|
||||
|
||||
def test_api_register_duplicate_returns_structured_400():
|
||||
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = "dup-contract-test@test.com"
|
||||
# First register
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||
# Duplicate
|
||||
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
|
||||
assert resp.status_code == 400
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "email_already_exists"
|
||||
|
||||
|
||||
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
|
||||
|
||||
|
||||
def _unique_email(prefix: str) -> str:
|
||||
return f"{prefix}-{secrets.token_hex(4)}@test.com"
|
||||
|
||||
|
||||
def _get_set_cookie_headers(resp) -> list[str]:
|
||||
"""Extract all set-cookie header values from a TestClient response."""
|
||||
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
|
||||
|
||||
|
||||
def test_register_http_cookie_httponly_true_secure_false():
|
||||
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("http-cookie"), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" not in cookie_header.lower().replace("samesite", "")
|
||||
|
||||
|
||||
def test_register_https_cookie_httponly_true_secure_true():
|
||||
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("https-cookie"), "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
assert "max-age" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_login_https_sets_secure_cookie():
|
||||
"""HTTPS login → access_token cookie has secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = _unique_email("https-login")
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": email, "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_secure_on_https():
|
||||
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-https"), "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" in csrf_header.lower()
|
||||
assert "httponly" not in csrf_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_not_secure_on_http():
|
||||
"""HTTP register → csrf_token cookie does NOT have secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-http"), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" not in csrf_header.lower().replace("samesite", "")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user