Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2b33bfd78f | |||
| 745bf4324e | |||
| e7a881b577 | |||
| ea73db6fc1 | |||
| ceeccabc98 | |||
| f0b065bef6 | |||
| 3aa3e37532 | |||
| e5ad92474c | |||
| 4b139fb689 | |||
| 2531cce0d1 | |||
| f942e4e597 | |||
| 03c3b18565 | |||
| 00e0e9a49a |
@@ -1,181 +0,0 @@
|
||||
---
|
||||
name: smoke-test
|
||||
description: End-to-end smoke test skill for DeerFlow. Guides through: 1) Pulling latest code, 2) Docker OR Local installation and deployment (user preference, default to Local if Docker network issues), 3) Service availability verification, 4) Health check, 5) Final test report. Use when the user says "run smoke test", "smoke test deployment", "verify installation", "test service availability", "end-to-end test", or similar.
|
||||
---
|
||||
|
||||
# DeerFlow Smoke Test Skill
|
||||
|
||||
This skill guides the Agent through DeerFlow's full end-to-end smoke test workflow, including code updates, deployment (supporting both Docker and local installation modes), service availability verification, and health checks.
|
||||
|
||||
## Deployment Mode Selection
|
||||
|
||||
This skill supports two deployment modes:
|
||||
- **Local installation mode** (recommended, especially when network issues occur) - Run all services directly on the local machine
|
||||
- **Docker mode** - Run all services inside Docker containers
|
||||
|
||||
**Selection strategy**:
|
||||
- If the user explicitly asks for Docker mode, use Docker
|
||||
- If network issues occur (such as slow image pulls), automatically switch to local mode
|
||||
- Default to local mode whenever possible
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
smoke-test/
|
||||
├── SKILL.md ← You are here - core workflow and logic
|
||||
├── scripts/
|
||||
│ ├── check_docker.sh ← Check the Docker environment
|
||||
│ ├── check_local_env.sh ← Check local environment dependencies
|
||||
│ ├── frontend_check.sh ← Frontend page smoke check
|
||||
│ ├── pull_code.sh ← Pull the latest code
|
||||
│ ├── deploy_docker.sh ← Docker deployment
|
||||
│ ├── deploy_local.sh ← Local deployment
|
||||
│ └── health_check.sh ← Service health check
|
||||
├── references/
|
||||
│ ├── SOP.md ← Standard operating procedure
|
||||
│ └── troubleshooting.md ← Troubleshooting guide
|
||||
└── templates/
|
||||
├── report.local.template.md ← Local mode smoke test report template
|
||||
└── report.docker.template.md ← Docker mode smoke test report template
|
||||
```
|
||||
|
||||
## Standard Operating Procedure (SOP)
|
||||
|
||||
### Phase 1: Code Update Check
|
||||
|
||||
1. **Confirm current directory** - Verify that the current working directory is the DeerFlow project root
|
||||
2. **Check Git status** - See whether there are uncommitted changes
|
||||
3. **Pull the latest code** - Use `git pull origin main` to get the latest updates
|
||||
4. **Confirm code update** - Verify that the latest code was pulled successfully
|
||||
|
||||
### Phase 2: Deployment Mode Selection and Environment Check
|
||||
|
||||
**Choose deployment mode**:
|
||||
- Ask for user preference, or choose automatically based on network conditions
|
||||
- Default to local installation mode
|
||||
|
||||
**Local mode environment check**:
|
||||
1. **Check Node.js version** - Requires 22+
|
||||
2. **Check pnpm** - Package manager
|
||||
3. **Check uv** - Python package manager
|
||||
4. **Check nginx** - Reverse proxy
|
||||
5. **Check required ports** - Confirm that ports 2026, 3000, 8001, and 2024 are not occupied
|
||||
|
||||
**Docker mode environment check** (if Docker is selected):
|
||||
1. **Check whether Docker is installed** - Run `docker --version`
|
||||
2. **Check Docker daemon status** - Run `docker info`
|
||||
3. **Check Docker Compose availability** - Run `docker compose version`
|
||||
4. **Check required ports** - Confirm that port 2026 is not occupied
|
||||
|
||||
### Phase 3: Configuration Preparation
|
||||
|
||||
1. **Check whether config.yaml exists**
|
||||
- If it does not exist, run `make config` to generate it
|
||||
- If it already exists, check whether it needs an upgrade with `make config-upgrade`
|
||||
2. **Check the .env file**
|
||||
- Verify that required environment variables are configured
|
||||
- Especially model API keys such as `OPENAI_API_KEY`
|
||||
|
||||
### Phase 4: Deployment Execution
|
||||
|
||||
**Local mode deployment**:
|
||||
1. **Check dependencies** - Run `make check`
|
||||
2. **Install dependencies** - Run `make install`
|
||||
3. **(Optional) Pre-pull the sandbox image** - If needed, run `make setup-sandbox`
|
||||
4. **Start services** - Run `make dev-daemon` (background mode, recommended) or `make dev` (foreground mode)
|
||||
5. **Wait for startup** - Give all services enough time to start completely (90-120 seconds recommended)
|
||||
|
||||
**Docker mode deployment** (if Docker is selected):
|
||||
1. **Initialize Docker environment** - Run `make docker-init`
|
||||
2. **Start Docker services** - Run `make docker-start`
|
||||
3. **Wait for startup** - Give all containers enough time to start completely (60 seconds recommended)
|
||||
|
||||
### Phase 5: Service Health Check
|
||||
|
||||
**Local mode health check**:
|
||||
1. **Check process status** - Confirm that LangGraph, Gateway, Frontend, and Nginx processes are all running
|
||||
2. **Check frontend service** - Visit `http://localhost:2026` and verify that the page loads
|
||||
3. **Check API Gateway** - Verify the `http://localhost:2026/health` endpoint
|
||||
4. **Check LangGraph service** - Verify the availability of relevant endpoints
|
||||
5. **Frontend route smoke check** - Run `bash .agent/skills/smoke-test/scripts/frontend_check.sh` to verify key routes under `/workspace`
|
||||
|
||||
**Docker mode health check** (when using Docker):
|
||||
1. **Check container status** - Run `docker ps` and confirm that all containers are running
|
||||
2. **Check frontend service** - Visit `http://localhost:2026` and verify that the page loads
|
||||
3. **Check API Gateway** - Verify the `http://localhost:2026/health` endpoint
|
||||
4. **Check LangGraph service** - Verify the availability of relevant endpoints
|
||||
5. **Frontend route smoke check** - Run `bash .agent/skills/smoke-test/scripts/frontend_check.sh` to verify key routes under `/workspace`
|
||||
|
||||
### Optional Functional Verification
|
||||
|
||||
1. **List available models** - Verify that model configuration loads correctly
|
||||
2. **List available skills** - Verify that the skill directory is mounted correctly
|
||||
3. **Simple chat test** - Send a simple message to verify the end-to-end flow
|
||||
|
||||
### Phase 6: Generate Test Report
|
||||
|
||||
1. **Collect all test results** - Summarize execution status for each phase
|
||||
2. **Record encountered issues** - If anything fails, record the error details
|
||||
3. **Generate the final report** - Use the template that matches the selected deployment mode to create the complete test report, including overall conclusion, detailed key test cases, and explicit frontend page / route results
|
||||
4. **Provide follow-up recommendations** - Offer suggestions based on the test results
|
||||
|
||||
## Execution Rules
|
||||
|
||||
- **Follow the sequence** - Execute strictly in the order described above
|
||||
- **Idempotency** - Every step should be safe to repeat
|
||||
- **Error handling** - If a step fails, stop and report the issue, then provide troubleshooting suggestions
|
||||
- **Detailed logging** - Record the execution result and status of each step
|
||||
- **User confirmation** - Ask for confirmation before potentially risky operations such as overwriting config
|
||||
- **Mode preference** - Prefer local mode to avoid network-related issues
|
||||
- **Template requirement** - The final report must use the matching template under `templates/`; do not output a free-form summary instead of the template-based report
|
||||
- **Report clarity** - The execution summary must include the overall pass/fail conclusion plus per-case result explanations, and frontend smoke check results must be listed explicitly in the report
|
||||
- **Optional phase handling** - If functional verification is not executed, do not present it as a separate skipped phase in the final report
|
||||
|
||||
## Known Acceptable Warnings
|
||||
|
||||
The following warnings can appear during smoke testing and do not block a successful result:
|
||||
- Feishu/Lark SSL errors in Gateway logs (certificate verification failure) can be ignored if that channel is not enabled
|
||||
- Warnings in LangGraph logs about missing methods in the custom checkpointer, such as `adelete_for_runs` or `aprune`, do not affect the core functionality
|
||||
|
||||
## Key Tools
|
||||
|
||||
Use the following tools during execution:
|
||||
|
||||
1. **bash** - Run shell commands
|
||||
2. **present_file** - Show generated reports and important files
|
||||
3. **task_tool** - Organize complex steps with subtasks when needed
|
||||
|
||||
## Success Criteria
|
||||
|
||||
Smoke test pass criteria (local mode):
|
||||
- [x] Latest code is pulled successfully
|
||||
- [x] Local environment check passes (Node.js 22+, pnpm, uv, nginx)
|
||||
- [x] Configuration files are set up correctly
|
||||
- [x] `make check` passes
|
||||
- [x] `make install` completes successfully
|
||||
- [x] `make dev` starts successfully
|
||||
- [x] All service processes run normally
|
||||
- [x] Frontend page is accessible
|
||||
- [x] Frontend route smoke check passes (`/workspace` key routes)
|
||||
- [x] API Gateway health check passes
|
||||
- [x] Test report is generated completely
|
||||
|
||||
Smoke test pass criteria (Docker mode):
|
||||
- [x] Latest code is pulled successfully
|
||||
- [x] Docker environment check passes
|
||||
- [x] Configuration files are set up correctly
|
||||
- [x] `make docker-init` completes successfully
|
||||
- [x] `make docker-start` completes successfully
|
||||
- [x] All Docker containers run normally
|
||||
- [x] Frontend page is accessible
|
||||
- [x] Frontend route smoke check passes (`/workspace` key routes)
|
||||
- [x] API Gateway health check passes
|
||||
- [x] Test report is generated completely
|
||||
|
||||
## Read Reference Files
|
||||
|
||||
Before starting execution, read the following reference files:
|
||||
1. `references/SOP.md` - Detailed step-by-step operating instructions
|
||||
2. `references/troubleshooting.md` - Common issues and solutions
|
||||
3. `templates/report.local.template.md` - Local mode test report template
|
||||
4. `templates/report.docker.template.md` - Docker mode test report template
|
||||
@@ -1,452 +0,0 @@
|
||||
# DeerFlow Smoke Test Standard Operating Procedure (SOP)
|
||||
|
||||
This document describes the detailed operating steps for each phase of the DeerFlow smoke test.
|
||||
|
||||
## Phase 1: Code Update Check
|
||||
|
||||
### 1.1 Confirm Current Directory
|
||||
|
||||
**Objective**: Verify that the current working directory is the DeerFlow project root.
|
||||
|
||||
**Steps**:
|
||||
1. Run `pwd` to view the current working directory
|
||||
2. Check whether the directory contains the following files/directories:
|
||||
- `Makefile`
|
||||
- `backend/`
|
||||
- `frontend/`
|
||||
- `config.example.yaml`
|
||||
|
||||
**Success Criteria**: The current directory contains all of the files/directories listed above.
|
||||
|
||||
---
|
||||
|
||||
### 1.2 Check Git Status
|
||||
|
||||
**Objective**: Check whether there are uncommitted changes.
|
||||
|
||||
**Steps**:
|
||||
1. Run `git status`
|
||||
2. Check whether the output includes "Changes not staged for commit" or "Untracked files"
|
||||
|
||||
**Notes**:
|
||||
- If there are uncommitted changes, recommend that the user commit or stash them first to avoid conflicts while pulling
|
||||
- If the user confirms that they want to continue, this step can be skipped
|
||||
|
||||
---
|
||||
|
||||
### 1.3 Pull the Latest Code
|
||||
|
||||
**Objective**: Fetch the latest code updates.
|
||||
|
||||
**Steps**:
|
||||
1. Run `git fetch origin main`
|
||||
2. Run `git pull origin main`
|
||||
|
||||
**Success Criteria**:
|
||||
- The commands succeed without errors
|
||||
- The output shows "Already up to date" or indicates that new commits were pulled successfully
|
||||
|
||||
---
|
||||
|
||||
### 1.4 Confirm Code Update
|
||||
|
||||
**Objective**: Verify that the latest code was pulled successfully.
|
||||
|
||||
**Steps**:
|
||||
1. Run `git log -1 --oneline` to view the latest commit
|
||||
2. Record the commit hash and message
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Deployment Mode Selection and Environment Check
|
||||
|
||||
### 2.1 Choose Deployment Mode
|
||||
|
||||
**Objective**: Decide whether to use local mode or Docker mode.
|
||||
|
||||
**Decision Flow**:
|
||||
1. Prefer local mode first to avoid network-related issues
|
||||
2. If the user explicitly requests Docker, use Docker
|
||||
3. If Docker network issues occur, switch to local mode automatically
|
||||
|
||||
---
|
||||
|
||||
### 2.2 Local Mode Environment Check
|
||||
|
||||
**Objective**: Verify that local development environment dependencies are satisfied.
|
||||
|
||||
#### 2.2.1 Check Node.js Version
|
||||
|
||||
**Steps**:
|
||||
1. If nvm is used, run `nvm use 22` to switch to Node 22+
|
||||
2. Run `node --version`
|
||||
|
||||
**Success Criteria**: Version >= 22.x
|
||||
|
||||
**Failure Handling**:
|
||||
- If the version is too low, ask the user to install/switch Node.js with nvm:
|
||||
```bash
|
||||
nvm install 22
|
||||
nvm use 22
|
||||
```
|
||||
- Or install it from the official website: https://nodejs.org/
|
||||
|
||||
---
|
||||
|
||||
#### 2.2.2 Check pnpm
|
||||
|
||||
**Steps**:
|
||||
1. Run `pnpm --version`
|
||||
|
||||
**Success Criteria**: The command returns pnpm version information.
|
||||
|
||||
**Failure Handling**:
|
||||
- If pnpm is not installed, ask the user to install it with `npm install -g pnpm`
|
||||
|
||||
---
|
||||
|
||||
#### 2.2.3 Check uv
|
||||
|
||||
**Steps**:
|
||||
1. Run `uv --version`
|
||||
|
||||
**Success Criteria**: The command returns uv version information.
|
||||
|
||||
**Failure Handling**:
|
||||
- If uv is not installed, ask the user to install uv
|
||||
|
||||
---
|
||||
|
||||
#### 2.2.4 Check nginx
|
||||
|
||||
**Steps**:
|
||||
1. Run `nginx -v`
|
||||
|
||||
**Success Criteria**: The command returns nginx version information.
|
||||
|
||||
**Failure Handling**:
|
||||
- macOS: install with Homebrew using `brew install nginx`
|
||||
- Linux: install using the system package manager
|
||||
|
||||
---
|
||||
|
||||
#### 2.2.5 Check Required Ports
|
||||
|
||||
**Steps**:
|
||||
1. Run the following commands to check ports:
|
||||
```bash
|
||||
lsof -i :2026 # Main port
|
||||
lsof -i :3000 # Frontend
|
||||
lsof -i :8001 # Gateway
|
||||
lsof -i :2024 # LangGraph
|
||||
```
|
||||
|
||||
**Success Criteria**: All ports are free, or they are occupied only by DeerFlow-related processes.
|
||||
|
||||
**Failure Handling**:
|
||||
- If a port is occupied, ask the user to stop the related process
|
||||
|
||||
---
|
||||
|
||||
### 2.3 Docker Mode Environment Check (If Docker Is Selected)
|
||||
|
||||
#### 2.3.1 Check Whether Docker Is Installed
|
||||
|
||||
**Steps**:
|
||||
1. Run `docker --version`
|
||||
|
||||
**Success Criteria**: The command returns Docker version information, such as "Docker version 24.x.x".
|
||||
|
||||
---
|
||||
|
||||
#### 2.3.2 Check Docker Daemon Status
|
||||
|
||||
**Steps**:
|
||||
1. Run `docker info`
|
||||
|
||||
**Success Criteria**: The command runs successfully and shows Docker system information.
|
||||
|
||||
**Failure Handling**:
|
||||
- If it fails, ask the user to start Docker Desktop or the Docker service
|
||||
|
||||
---
|
||||
|
||||
#### 2.3.3 Check Docker Compose Availability
|
||||
|
||||
**Steps**:
|
||||
1. Run `docker compose version`
|
||||
|
||||
**Success Criteria**: The command returns Docker Compose version information.
|
||||
|
||||
---
|
||||
|
||||
#### 2.3.4 Check Required Ports
|
||||
|
||||
**Steps**:
|
||||
1. Run `lsof -i :2026` (macOS/Linux) or `netstat -ano | findstr :2026` (Windows)
|
||||
|
||||
**Success Criteria**: Port 2026 is free, or it is occupied only by a DeerFlow-related process.
|
||||
|
||||
**Failure Handling**:
|
||||
- If the port is occupied by another process, ask the user to stop that process or change the configuration
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Configuration Preparation
|
||||
|
||||
### 3.1 Check config.yaml
|
||||
|
||||
**Steps**:
|
||||
1. Check whether `config.yaml` exists
|
||||
2. If it does not exist, run `make config`
|
||||
3. If it already exists, consider running `make config-upgrade` to merge new fields
|
||||
|
||||
**Validation**:
|
||||
- Check whether at least one model is configured in config.yaml
|
||||
- Check whether the model configuration references the correct environment variables
|
||||
|
||||
---
|
||||
|
||||
### 3.2 Check the .env File
|
||||
|
||||
**Steps**:
|
||||
1. Check whether the `.env` file exists
|
||||
2. If it does not exist, copy it from `.env.example`
|
||||
3. Check whether the following environment variables are configured:
|
||||
- `OPENAI_API_KEY` (or other model API keys)
|
||||
- Other required settings
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Deployment Execution
|
||||
|
||||
### 4.1 Local Mode Deployment
|
||||
|
||||
#### 4.1.1 Check Dependencies
|
||||
|
||||
**Steps**:
|
||||
1. Run `make check`
|
||||
|
||||
**Description**: This command validates all required tools (Node.js 22+, pnpm, uv, nginx).
|
||||
|
||||
---
|
||||
|
||||
#### 4.1.2 Install Dependencies
|
||||
|
||||
**Steps**:
|
||||
1. Run `make install`
|
||||
|
||||
**Description**: This command installs both backend and frontend dependencies.
|
||||
|
||||
**Notes**:
|
||||
- This step may take some time
|
||||
- If network issues cause failures, try using a closer or mirrored package registry
|
||||
|
||||
---
|
||||
|
||||
#### 4.1.3 (Optional) Pre-pull the Sandbox Image
|
||||
|
||||
**Steps**:
|
||||
1. If Docker / Container sandbox is used, run `make setup-sandbox`
|
||||
|
||||
**Description**: This step is optional and not needed for local sandbox mode.
|
||||
|
||||
---
|
||||
|
||||
#### 4.1.4 Start Services
|
||||
|
||||
**Steps**:
|
||||
1. Run `make dev-daemon` (background mode)
|
||||
|
||||
**Description**: This command starts all services (LangGraph, Gateway, Frontend, Nginx).
|
||||
|
||||
**Notes**:
|
||||
- `make dev` runs in the foreground and stops with Ctrl+C
|
||||
- `make dev-daemon` runs in the background
|
||||
- Use `make stop` to stop services
|
||||
|
||||
---
|
||||
|
||||
#### 4.1.5 Wait for Services to Start
|
||||
|
||||
**Steps**:
|
||||
1. Wait 90-120 seconds for all services to start completely
|
||||
2. You can monitor startup progress by checking these log files:
|
||||
- `logs/langgraph.log`
|
||||
- `logs/gateway.log`
|
||||
- `logs/frontend.log`
|
||||
- `logs/nginx.log`
|
||||
|
||||
---
|
||||
|
||||
### 4.2 Docker Mode Deployment (If Docker Is Selected)
|
||||
|
||||
#### 4.2.1 Initialize the Docker Environment
|
||||
|
||||
**Steps**:
|
||||
1. Run `make docker-init`
|
||||
|
||||
**Description**: This command pulls the sandbox image if needed.
|
||||
|
||||
---
|
||||
|
||||
#### 4.2.2 Start Docker Services
|
||||
|
||||
**Steps**:
|
||||
1. Run `make docker-start`
|
||||
|
||||
**Description**: This command builds and starts all required Docker containers.
|
||||
|
||||
---
|
||||
|
||||
#### 4.2.3 Wait for Services to Start
|
||||
|
||||
**Steps**:
|
||||
1. Wait 60-90 seconds for all services to start completely
|
||||
2. You can run `make docker-logs` to monitor startup progress
|
||||
|
||||
---
|
||||
|
||||
## Phase 5: Service Health Check
|
||||
|
||||
### 5.1 Local Mode Health Check
|
||||
|
||||
#### 5.1.1 Check Process Status
|
||||
|
||||
**Steps**:
|
||||
1. Run the following command to check processes:
|
||||
```bash
|
||||
ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
|
||||
```
|
||||
|
||||
**Success Criteria**: Confirm that the following processes are running:
|
||||
- LangGraph (`langgraph dev`)
|
||||
- Gateway (`uvicorn app.gateway.app:app`)
|
||||
- Frontend (`next dev` or `next start`)
|
||||
- Nginx (`nginx`)
|
||||
|
||||
---
|
||||
|
||||
#### 5.1.2 Check Frontend Service
|
||||
|
||||
**Steps**:
|
||||
1. Use curl or a browser to visit `http://localhost:2026`
|
||||
2. Verify that the page loads normally
|
||||
|
||||
**Example curl command**:
|
||||
```bash
|
||||
curl -I http://localhost:2026
|
||||
```
|
||||
|
||||
**Success Criteria**: Returns an HTTP 200 status code.
|
||||
|
||||
---
|
||||
|
||||
#### 5.1.3 Check API Gateway
|
||||
|
||||
**Steps**:
|
||||
1. Visit `http://localhost:2026/health`
|
||||
|
||||
**Example curl command**:
|
||||
```bash
|
||||
curl http://localhost:2026/health
|
||||
```
|
||||
|
||||
**Success Criteria**: Returns health status JSON.
|
||||
|
||||
---
|
||||
|
||||
#### 5.1.4 Check LangGraph Service
|
||||
|
||||
**Steps**:
|
||||
1. Visit relevant LangGraph endpoints to verify availability
|
||||
|
||||
---
|
||||
|
||||
### 5.2 Docker Mode Health Check (When Using Docker)
|
||||
|
||||
#### 5.2.1 Check Container Status
|
||||
|
||||
**Steps**:
|
||||
1. Run `docker ps`
|
||||
2. Confirm that the following containers are running:
|
||||
- `deer-flow-nginx`
|
||||
- `deer-flow-frontend`
|
||||
- `deer-flow-gateway`
|
||||
- `deer-flow-langgraph` (if not in gateway mode)
|
||||
|
||||
---
|
||||
|
||||
#### 5.2.2 Check Frontend Service
|
||||
|
||||
**Steps**:
|
||||
1. Use curl or a browser to visit `http://localhost:2026`
|
||||
2. Verify that the page loads normally
|
||||
|
||||
**Example curl command**:
|
||||
```bash
|
||||
curl -I http://localhost:2026
|
||||
```
|
||||
|
||||
**Success Criteria**: Returns an HTTP 200 status code.
|
||||
|
||||
---
|
||||
|
||||
#### 5.2.3 Check API Gateway
|
||||
|
||||
**Steps**:
|
||||
1. Visit `http://localhost:2026/health`
|
||||
|
||||
**Example curl command**:
|
||||
```bash
|
||||
curl http://localhost:2026/health
|
||||
```
|
||||
|
||||
**Success Criteria**: Returns health status JSON.
|
||||
|
||||
---
|
||||
|
||||
#### 5.2.4 Check LangGraph Service
|
||||
|
||||
**Steps**:
|
||||
1. Visit relevant LangGraph endpoints to verify availability
|
||||
|
||||
---
|
||||
|
||||
## Optional Functional Verification
|
||||
|
||||
### 6.1 List Available Models
|
||||
|
||||
**Steps**: Verify the model list through the API or UI.
|
||||
|
||||
---
|
||||
|
||||
### 6.2 List Available Skills
|
||||
|
||||
**Steps**: Verify the skill list through the API or UI.
|
||||
|
||||
---
|
||||
|
||||
### 6.3 Simple Chat Test
|
||||
|
||||
**Steps**: Send a simple message to test the complete workflow.
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: Generate the Test Report
|
||||
|
||||
### 6.1 Collect Test Results
|
||||
|
||||
Summarize the execution status of each phase and record successful and failed items.
|
||||
|
||||
### 6.2 Record Issues
|
||||
|
||||
If anything fails, record detailed error information.
|
||||
|
||||
### 6.3 Generate the Report
|
||||
|
||||
Use the template to create a complete test report.
|
||||
|
||||
### 6.4 Provide Recommendations
|
||||
|
||||
Provide follow-up recommendations based on the test results.
|
||||
@@ -1,612 +0,0 @@
|
||||
# Troubleshooting Guide
|
||||
|
||||
This document lists common issues encountered during DeerFlow smoke testing and how to resolve them.
|
||||
|
||||
## Code Update Issues
|
||||
|
||||
### Issue: `git pull` Fails with a Merge Conflict Warning
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
error: Your local changes to the following files would be overwritten by merge
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Option A: Commit local changes first
|
||||
```bash
|
||||
git add .
|
||||
git commit -m "Save local changes"
|
||||
git pull origin main
|
||||
```
|
||||
|
||||
2. Option B: Stash local changes
|
||||
```bash
|
||||
git stash
|
||||
git pull origin main
|
||||
git stash pop # Restore changes later if needed
|
||||
```
|
||||
|
||||
3. Option C: Discard local changes (use with caution)
|
||||
```bash
|
||||
git reset --hard HEAD
|
||||
git pull origin main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Local Mode Environment Issues
|
||||
|
||||
### Issue: Node.js Version Is Too Old
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
Node.js version is too old. Requires 22+, got x.x.x
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Install or upgrade Node.js with nvm:
|
||||
```bash
|
||||
nvm install 22
|
||||
nvm use 22
|
||||
```
|
||||
|
||||
2. Or download and install it from the official website: https://nodejs.org/
|
||||
|
||||
3. Verify the version:
|
||||
```bash
|
||||
node --version
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: pnpm Is Not Installed
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
command not found: pnpm
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Install pnpm with npm:
|
||||
```bash
|
||||
npm install -g pnpm
|
||||
```
|
||||
|
||||
2. Or use the official installation script:
|
||||
```bash
|
||||
curl -fsSL https://get.pnpm.io/install.sh | sh -
|
||||
```
|
||||
|
||||
3. Verify the installation:
|
||||
```bash
|
||||
pnpm --version
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: uv Is Not Installed
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
command not found: uv
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Use the official installation script:
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
2. macOS users can also install it with Homebrew:
|
||||
```bash
|
||||
brew install uv
|
||||
```
|
||||
|
||||
3. Verify the installation:
|
||||
```bash
|
||||
uv --version
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: nginx Is Not Installed
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
command not found: nginx
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. macOS (Homebrew):
|
||||
```bash
|
||||
brew install nginx
|
||||
```
|
||||
|
||||
2. Ubuntu/Debian:
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install nginx
|
||||
```
|
||||
|
||||
3. CentOS/RHEL:
|
||||
```bash
|
||||
sudo yum install nginx
|
||||
```
|
||||
|
||||
4. Verify the installation:
|
||||
```bash
|
||||
nginx -v
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Port Is Already in Use
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
Error: listen EADDRINUSE: address already in use :::2026
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Find the process using the port:
|
||||
```bash
|
||||
lsof -i :2026 # macOS/Linux
|
||||
netstat -ano | findstr :2026 # Windows
|
||||
```
|
||||
|
||||
2. Stop that process:
|
||||
```bash
|
||||
kill -9 <PID> # macOS/Linux
|
||||
taskkill /PID <PID> /F # Windows
|
||||
```
|
||||
|
||||
3. Or stop DeerFlow services first:
|
||||
```bash
|
||||
make stop
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Local Mode Dependency Installation Issues
|
||||
|
||||
### Issue: `make install` Fails Due to Network Timeout
|
||||
|
||||
**Symptoms**:
|
||||
Network timeouts or connection failures occur during dependency installation.
|
||||
|
||||
**Solutions**:
|
||||
1. Configure pnpm to use a mirror registry:
|
||||
```bash
|
||||
pnpm config set registry https://registry.npmmirror.com
|
||||
```
|
||||
|
||||
2. Configure uv to use a mirror registry:
|
||||
```bash
|
||||
uv pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
3. Retry the installation:
|
||||
```bash
|
||||
make install
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Python Dependency Installation Fails
|
||||
|
||||
**Symptoms**:
|
||||
Errors occur during `uv sync`.
|
||||
|
||||
**Solutions**:
|
||||
1. Clean the uv cache:
|
||||
```bash
|
||||
cd backend
|
||||
uv cache clean
|
||||
```
|
||||
|
||||
2. Resync dependencies:
|
||||
```bash
|
||||
cd backend
|
||||
uv sync
|
||||
```
|
||||
|
||||
3. View detailed error logs:
|
||||
```bash
|
||||
cd backend
|
||||
uv sync --verbose
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Frontend Dependency Installation Fails
|
||||
|
||||
**Symptoms**:
|
||||
Errors occur during `pnpm install`.
|
||||
|
||||
**Solutions**:
|
||||
1. Clean the pnpm cache:
|
||||
```bash
|
||||
cd frontend
|
||||
pnpm store prune
|
||||
```
|
||||
|
||||
2. Remove node_modules and the lock file:
|
||||
```bash
|
||||
cd frontend
|
||||
rm -rf node_modules pnpm-lock.yaml
|
||||
```
|
||||
|
||||
3. Reinstall:
|
||||
```bash
|
||||
cd frontend
|
||||
pnpm install
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Local Mode Service Startup Issues
|
||||
|
||||
### Issue: Services Exit Immediately After Startup
|
||||
|
||||
**Symptoms**:
|
||||
Processes exit quickly after running `make dev-daemon`.
|
||||
|
||||
**Solutions**:
|
||||
1. Check log files:
|
||||
```bash
|
||||
tail -f logs/langgraph.log
|
||||
tail -f logs/gateway.log
|
||||
tail -f logs/frontend.log
|
||||
tail -f logs/nginx.log
|
||||
```
|
||||
|
||||
2. Check whether config.yaml is configured correctly
|
||||
3. Check environment variables in the .env file
|
||||
4. Confirm that required ports are not occupied
|
||||
5. Stop all services and restart:
|
||||
```bash
|
||||
make stop
|
||||
make dev-daemon
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Nginx Fails to Start Because Temp Directories Do Not Exist
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
nginx: [emerg] mkdir() "/opt/homebrew/var/run/nginx/client_body_temp" failed (2: No such file or directory)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
Add local temp directory configuration to `docker/nginx/nginx.local.conf` so nginx uses the repository's temp directory.
|
||||
|
||||
Add the following at the beginning of the `http` block:
|
||||
```nginx
|
||||
client_body_temp_path temp/client_body_temp;
|
||||
proxy_temp_path temp/proxy_temp;
|
||||
fastcgi_temp_path temp/fastcgi_temp;
|
||||
uwsgi_temp_path temp/uwsgi_temp;
|
||||
scgi_temp_path temp/scgi_temp;
|
||||
```
|
||||
|
||||
Note: The `temp/` directory under the repository root is created automatically by `make dev` or `make dev-daemon`.
|
||||
|
||||
---
|
||||
|
||||
### Issue: Nginx Fails to Start (General)
|
||||
|
||||
**Symptoms**:
|
||||
The nginx process fails to start or reports an error.
|
||||
|
||||
**Solutions**:
|
||||
1. Check the nginx configuration:
|
||||
```bash
|
||||
nginx -t -c docker/nginx/nginx.local.conf -p .
|
||||
```
|
||||
|
||||
2. Check nginx logs:
|
||||
```bash
|
||||
tail -f logs/nginx.log
|
||||
```
|
||||
|
||||
3. Ensure no other nginx process is running:
|
||||
```bash
|
||||
ps aux | grep nginx
|
||||
```
|
||||
|
||||
4. If needed, stop existing nginx processes:
|
||||
```bash
|
||||
pkill -9 nginx
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Frontend Compilation Fails
|
||||
|
||||
**Symptoms**:
|
||||
Compilation errors appear in `frontend.log`.
|
||||
|
||||
**Solutions**:
|
||||
1. Check frontend logs:
|
||||
```bash
|
||||
tail -f logs/frontend.log
|
||||
```
|
||||
|
||||
2. Check whether Node.js version is 22+
|
||||
3. Reinstall frontend dependencies:
|
||||
```bash
|
||||
cd frontend
|
||||
rm -rf node_modules .next
|
||||
pnpm install
|
||||
```
|
||||
|
||||
4. Restart services:
|
||||
```bash
|
||||
make stop
|
||||
make dev-daemon
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Gateway Fails to Start
|
||||
|
||||
**Symptoms**:
|
||||
Errors appear in `gateway.log`.
|
||||
|
||||
**Solutions**:
|
||||
1. Check gateway logs:
|
||||
```bash
|
||||
tail -f logs/gateway.log
|
||||
```
|
||||
|
||||
2. Check whether config.yaml exists and has valid formatting
|
||||
3. Check whether Python dependencies are complete:
|
||||
```bash
|
||||
cd backend
|
||||
uv sync
|
||||
```
|
||||
|
||||
4. Confirm that the LangGraph service is running normally (if not in gateway mode)
|
||||
|
||||
---
|
||||
|
||||
### Issue: LangGraph Fails to Start
|
||||
|
||||
**Symptoms**:
|
||||
Errors appear in `langgraph.log`.
|
||||
|
||||
**Solutions**:
|
||||
1. Check LangGraph logs:
|
||||
```bash
|
||||
tail -f logs/langgraph.log
|
||||
```
|
||||
|
||||
2. Check config.yaml
|
||||
3. Check whether Python dependencies are complete
|
||||
4. Confirm that port 2024 is not occupied
|
||||
|
||||
---
|
||||
|
||||
## Docker-Related Issues
|
||||
|
||||
### Issue: Docker Commands Cannot Run
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
Cannot connect to the Docker daemon
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Confirm that Docker Desktop is running
|
||||
2. macOS: check whether the Docker icon appears in the top menu bar
|
||||
3. Linux: run `sudo systemctl start docker`
|
||||
4. Run `docker info` again to verify
|
||||
|
||||
---
|
||||
|
||||
### Issue: `make docker-init` Fails to Pull the Image
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
Error pulling image: connection refused
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Check network connectivity
|
||||
2. Configure a Docker image mirror if needed
|
||||
3. Check whether a proxy is required
|
||||
4. Switch to local installation mode if necessary (recommended)
|
||||
|
||||
---
|
||||
|
||||
## Configuration File Issues
|
||||
|
||||
### Issue: config.yaml Is Missing or Invalid
|
||||
|
||||
**Symptoms**:
|
||||
```
|
||||
Error: could not read config.yaml
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Regenerate the configuration file:
|
||||
```bash
|
||||
make config
|
||||
```
|
||||
|
||||
2. Check YAML syntax:
|
||||
- Make sure indentation is correct (use 2 spaces)
|
||||
- Make sure there are no tab characters
|
||||
- Check that there is a space after each colon
|
||||
|
||||
3. Use a YAML validation tool to check the format
|
||||
|
||||
---
|
||||
|
||||
### Issue: Model API Key Is Not Configured
|
||||
|
||||
**Symptoms**:
|
||||
After services start, API requests fail with authentication errors.
|
||||
|
||||
**Solutions**:
|
||||
1. Edit the .env file and add the API key:
|
||||
```bash
|
||||
OPENAI_API_KEY=your-actual-api-key-here
|
||||
```
|
||||
|
||||
2. Restart services (local mode):
|
||||
```bash
|
||||
make stop
|
||||
make dev-daemon
|
||||
```
|
||||
|
||||
3. Restart services (Docker mode):
|
||||
```bash
|
||||
make docker-stop
|
||||
make docker-start
|
||||
```
|
||||
|
||||
4. Confirm that the model configuration in config.yaml references the environment variable correctly
|
||||
|
||||
---
|
||||
|
||||
## Service Health Check Issues
|
||||
|
||||
### Issue: Frontend Page Is Not Accessible
|
||||
|
||||
**Symptoms**:
|
||||
The browser shows a connection failure when visiting http://localhost:2026.
|
||||
|
||||
**Solutions** (local mode):
|
||||
1. Confirm that the nginx process is running:
|
||||
```bash
|
||||
ps aux | grep nginx
|
||||
```
|
||||
|
||||
2. Check nginx logs:
|
||||
```bash
|
||||
tail -f logs/nginx.log
|
||||
```
|
||||
|
||||
3. Check firewall settings
|
||||
|
||||
**Solutions** (Docker mode):
|
||||
1. Confirm that the nginx container is running:
|
||||
```bash
|
||||
docker ps | grep nginx
|
||||
```
|
||||
|
||||
2. Check nginx logs:
|
||||
```bash
|
||||
cd docker && docker compose -p deer-flow-dev -f docker-compose-dev.yaml logs nginx
|
||||
```
|
||||
|
||||
3. Check firewall settings
|
||||
|
||||
---
|
||||
|
||||
### Issue: API Gateway Health Check Fails
|
||||
|
||||
**Symptoms**:
|
||||
Accessing `/health` returns an error or times out.
|
||||
|
||||
**Solutions** (local mode):
|
||||
1. Check gateway logs:
|
||||
```bash
|
||||
tail -f logs/gateway.log
|
||||
```
|
||||
|
||||
2. Confirm that config.yaml exists and has valid formatting
|
||||
3. Check whether Python dependencies are complete
|
||||
4. Confirm that the LangGraph service is running normally
|
||||
|
||||
**Solutions** (Docker mode):
|
||||
1. Check gateway container logs:
|
||||
```bash
|
||||
make docker-logs-gateway
|
||||
```
|
||||
|
||||
2. Confirm that config.yaml is mounted correctly
|
||||
3. Check whether Python dependencies are complete
|
||||
4. Confirm that the LangGraph service is running normally
|
||||
|
||||
---
|
||||
|
||||
## Common Diagnostic Commands
|
||||
|
||||
### Local Mode Diagnostics
|
||||
|
||||
#### View All Service Processes
|
||||
```bash
|
||||
ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
|
||||
```
|
||||
|
||||
#### View Service Logs
|
||||
```bash
|
||||
# View all logs
|
||||
tail -f logs/*.log
|
||||
|
||||
# View specific service logs
|
||||
tail -f logs/langgraph.log
|
||||
tail -f logs/gateway.log
|
||||
tail -f logs/frontend.log
|
||||
tail -f logs/nginx.log
|
||||
```
|
||||
|
||||
#### Stop All Services
|
||||
```bash
|
||||
make stop
|
||||
```
|
||||
|
||||
#### Fully Reset the Local Environment
|
||||
```bash
|
||||
make stop
|
||||
make clean
|
||||
make config
|
||||
make install
|
||||
make dev-daemon
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Docker Mode Diagnostics
|
||||
|
||||
#### View All Container Status
|
||||
```bash
|
||||
docker ps -a
|
||||
```
|
||||
|
||||
#### View Container Resource Usage
|
||||
```bash
|
||||
docker stats
|
||||
```
|
||||
|
||||
#### Enter a Container for Debugging
|
||||
```bash
|
||||
docker exec -it deer-flow-gateway sh
|
||||
```
|
||||
|
||||
#### Clean Up All DeerFlow-Related Containers and Images
|
||||
```bash
|
||||
make docker-stop
|
||||
cd docker && docker compose -p deer-flow-dev -f docker-compose-dev.yaml down -v
|
||||
```
|
||||
|
||||
#### Fully Reset the Docker Environment
|
||||
```bash
|
||||
make docker-stop
|
||||
make clean
|
||||
make config
|
||||
make docker-init
|
||||
make docker-start
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Get More Help
|
||||
|
||||
If the solutions above do not resolve the issue:
|
||||
1. Check the GitHub issues for the project: https://github.com/bytedance/deer-flow/issues
|
||||
2. Review the project documentation: README.md and the `backend/docs/` directory
|
||||
3. Open a new issue and include detailed error logs
|
||||
@@ -1,80 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Checking Docker Environment"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check whether Docker is installed
|
||||
if command -v docker >/dev/null 2>&1; then
|
||||
echo "✓ Docker is installed"
|
||||
docker --version
|
||||
else
|
||||
echo "✗ Docker is not installed"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check the Docker daemon
|
||||
if docker info >/dev/null 2>&1; then
|
||||
echo "✓ Docker daemon is running normally"
|
||||
else
|
||||
echo "✗ Docker daemon is not running"
|
||||
echo " Please start Docker Desktop or the Docker service"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check Docker Compose
|
||||
if docker compose version >/dev/null 2>&1; then
|
||||
echo "✓ Docker Compose is available"
|
||||
docker compose version
|
||||
else
|
||||
echo "✗ Docker Compose is not available"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check port 2026
|
||||
if ! command -v lsof >/dev/null 2>&1; then
|
||||
echo "✗ lsof is required to check whether port 2026 is available"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
port_2026_usage="$(lsof -nP -iTCP:2026 -sTCP:LISTEN 2>/dev/null || true)"
|
||||
if [ -n "$port_2026_usage" ]; then
|
||||
echo "⚠ Port 2026 is already in use"
|
||||
echo " Occupying process:"
|
||||
echo "$port_2026_usage"
|
||||
|
||||
deerflow_process_found=0
|
||||
while IFS= read -r pid; do
|
||||
if [ -z "$pid" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
process_command="$(ps -p "$pid" -o command= 2>/dev/null || true)"
|
||||
case "$process_command" in
|
||||
*[Dd]eer[Ff]low*|*[Dd]eerflow*|*[Nn]ginx*deerflow*|*deerflow/*[Nn]ginx*)
|
||||
deerflow_process_found=1
|
||||
;;
|
||||
esac
|
||||
done <<EOF
|
||||
$(printf '%s\n' "$port_2026_usage" | awk 'NR > 1 {print $2}')
|
||||
EOF
|
||||
|
||||
if [ "$deerflow_process_found" -eq 1 ]; then
|
||||
echo "✓ Port 2026 is occupied by DeerFlow"
|
||||
else
|
||||
echo "✗ Port 2026 must be free before starting DeerFlow"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✓ Port 2026 is available"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
echo " Docker Environment Check Complete"
|
||||
echo "=========================================="
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Checking Local Development Environment"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
all_passed=true
|
||||
|
||||
# Check Node.js
|
||||
echo "1. Checking Node.js..."
|
||||
if command -v node >/dev/null 2>&1; then
|
||||
NODE_VERSION=$(node --version | sed 's/v//')
|
||||
NODE_MAJOR=$(echo "$NODE_VERSION" | cut -d. -f1)
|
||||
if [ "$NODE_MAJOR" -ge 22 ]; then
|
||||
echo "✓ Node.js is installed (version: $NODE_VERSION)"
|
||||
else
|
||||
echo "✗ Node.js version is too old (current: $NODE_VERSION, required: 22+)"
|
||||
all_passed=false
|
||||
fi
|
||||
else
|
||||
echo "✗ Node.js is not installed"
|
||||
all_passed=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check pnpm
|
||||
echo "2. Checking pnpm..."
|
||||
if command -v pnpm >/dev/null 2>&1; then
|
||||
echo "✓ pnpm is installed (version: $(pnpm --version))"
|
||||
else
|
||||
echo "✗ pnpm is not installed"
|
||||
echo " Install command: npm install -g pnpm"
|
||||
all_passed=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check uv
|
||||
echo "3. Checking uv..."
|
||||
if command -v uv >/dev/null 2>&1; then
|
||||
echo "✓ uv is installed (version: $(uv --version))"
|
||||
else
|
||||
echo "✗ uv is not installed"
|
||||
all_passed=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check nginx
|
||||
echo "4. Checking nginx..."
|
||||
if command -v nginx >/dev/null 2>&1; then
|
||||
echo "✓ nginx is installed (version: $(nginx -v 2>&1))"
|
||||
else
|
||||
echo "✗ nginx is not installed"
|
||||
echo " macOS: brew install nginx"
|
||||
echo " Linux: install it with the system package manager"
|
||||
all_passed=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check ports
|
||||
echo "5. Checking ports..."
|
||||
if ! command -v lsof >/dev/null 2>&1; then
|
||||
echo "✗ lsof is not installed, so port availability cannot be verified"
|
||||
echo " Install lsof and rerun this check"
|
||||
all_passed=false
|
||||
else
|
||||
for port in 2026 3000 8001 2024; do
|
||||
if lsof -i :$port >/dev/null 2>&1; then
|
||||
echo "⚠ Port $port is already in use:"
|
||||
lsof -i :$port | head -2
|
||||
all_passed=false
|
||||
else
|
||||
echo "✓ Port $port is available"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Summary
|
||||
echo "=========================================="
|
||||
echo " Environment Check Summary"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
if [ "$all_passed" = true ]; then
|
||||
echo "✅ All environment checks passed!"
|
||||
echo ""
|
||||
echo "Next step: run make install to install dependencies"
|
||||
exit 0
|
||||
else
|
||||
echo "❌ Some checks failed. Please fix the issues above first"
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,65 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Docker Deployment"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check config.yaml
|
||||
if [ ! -f "config.yaml" ]; then
|
||||
echo "config.yaml does not exist. Generating it..."
|
||||
make config
|
||||
echo ""
|
||||
echo "⚠ Please edit config.yaml to configure your models and API keys"
|
||||
echo " Then run this script again"
|
||||
exit 1
|
||||
else
|
||||
echo "✓ config.yaml exists"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check the .env file
|
||||
if [ ! -f ".env" ]; then
|
||||
echo ".env does not exist. Copying it from the example..."
|
||||
if [ -f ".env.example" ]; then
|
||||
cp .env.example .env
|
||||
echo "✓ Created the .env file"
|
||||
else
|
||||
echo "⚠ .env.example does not exist. Please create the .env file manually"
|
||||
fi
|
||||
else
|
||||
echo "✓ .env file exists"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check the frontend .env file
|
||||
if [ ! -f "frontend/.env" ]; then
|
||||
echo "frontend/.env does not exist. Copying it from the example..."
|
||||
if [ -f "frontend/.env.example" ]; then
|
||||
cp frontend/.env.example frontend/.env
|
||||
echo "✓ Created the frontend/.env file"
|
||||
else
|
||||
echo "⚠ frontend/.env.example does not exist. Please create frontend/.env manually"
|
||||
fi
|
||||
else
|
||||
echo "✓ frontend/.env file exists"
|
||||
fi
|
||||
echo ""
|
||||
# Initialize the Docker environment
|
||||
echo "Initializing the Docker environment..."
|
||||
make docker-init
|
||||
echo ""
|
||||
|
||||
# Start Docker services
|
||||
echo "Starting Docker services..."
|
||||
make docker-start
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
echo " Deployment Complete"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "🌐 Access URL: http://localhost:2026"
|
||||
echo "📋 View logs: make docker-logs"
|
||||
echo "🛑 Stop services: make docker-stop"
|
||||
@@ -1,63 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Local Mode Deployment"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check config.yaml
|
||||
if [ ! -f "config.yaml" ]; then
|
||||
echo "config.yaml does not exist. Generating it..."
|
||||
make config
|
||||
echo ""
|
||||
echo "⚠ Please edit config.yaml to configure your models and API keys"
|
||||
echo " Then run this script again"
|
||||
exit 1
|
||||
else
|
||||
echo "✓ config.yaml exists"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check the .env file
|
||||
if [ ! -f ".env" ]; then
|
||||
echo ".env does not exist. Copying it from the example..."
|
||||
if [ -f ".env.example" ]; then
|
||||
cp .env.example .env
|
||||
echo "✓ Created the .env file"
|
||||
else
|
||||
echo "⚠ .env.example does not exist. Please create the .env file manually"
|
||||
fi
|
||||
else
|
||||
echo "✓ .env file exists"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check dependencies
|
||||
echo "Checking dependencies..."
|
||||
make check
|
||||
echo ""
|
||||
|
||||
# Install dependencies
|
||||
echo "Installing dependencies..."
|
||||
make install
|
||||
echo ""
|
||||
|
||||
# Start services
|
||||
echo "Starting services (background mode)..."
|
||||
make dev-daemon
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
echo " Deployment Complete"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "🌐 Access URL: http://localhost:2026"
|
||||
echo "📋 View logs:"
|
||||
echo " - logs/langgraph.log"
|
||||
echo " - logs/gateway.log"
|
||||
echo " - logs/frontend.log"
|
||||
echo " - logs/nginx.log"
|
||||
echo "🛑 Stop services: make stop"
|
||||
echo ""
|
||||
echo "Please wait 90-120 seconds for all services to start completely, then run the health check"
|
||||
@@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set +e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Frontend Page Smoke Check"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
BASE_URL="${BASE_URL:-http://localhost:2026}"
|
||||
DOC_PATH="${DOC_PATH:-/en/docs}"
|
||||
|
||||
all_passed=true
|
||||
|
||||
check_status() {
|
||||
local name="$1"
|
||||
local url="$2"
|
||||
local expected_re="$3"
|
||||
|
||||
local status
|
||||
status="$(curl -s -o /dev/null -w "%{http_code}" -L "$url")"
|
||||
if echo "$status" | grep -Eq "$expected_re"; then
|
||||
echo "✓ $name ($url) -> $status"
|
||||
else
|
||||
echo "✗ $name ($url) -> $status (expected: $expected_re)"
|
||||
all_passed=false
|
||||
fi
|
||||
}
|
||||
|
||||
check_final_url() {
|
||||
local name="$1"
|
||||
local url="$2"
|
||||
local expected_path_re="$3"
|
||||
|
||||
local effective
|
||||
effective="$(curl -s -o /dev/null -w "%{url_effective}" -L "$url")"
|
||||
if echo "$effective" | grep -Eq "$expected_path_re"; then
|
||||
echo "✓ $name redirect target -> $effective"
|
||||
else
|
||||
echo "✗ $name redirect target -> $effective (expected path: $expected_path_re)"
|
||||
all_passed=false
|
||||
fi
|
||||
}
|
||||
|
||||
echo "1. Checking entry pages..."
|
||||
check_status "Landing page" "${BASE_URL}/" "200"
|
||||
check_status "Workspace redirect" "${BASE_URL}/workspace" "200|301|302|307|308"
|
||||
check_final_url "Workspace redirect" "${BASE_URL}/workspace" "/workspace/chats/"
|
||||
echo ""
|
||||
|
||||
echo "2. Checking key workspace routes..."
|
||||
check_status "New chat page" "${BASE_URL}/workspace/chats/new" "200"
|
||||
check_status "Chats list page" "${BASE_URL}/workspace/chats" "200"
|
||||
check_status "Agents gallery page" "${BASE_URL}/workspace/agents" "200"
|
||||
echo ""
|
||||
|
||||
echo "3. Checking docs route (optional)..."
|
||||
check_status "Docs page" "${BASE_URL}${DOC_PATH}" "200|404"
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
echo " Frontend Smoke Check Summary"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
if [ "$all_passed" = true ]; then
|
||||
echo "✅ Frontend smoke checks passed!"
|
||||
exit 0
|
||||
else
|
||||
echo "❌ Frontend smoke checks failed"
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,125 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set +e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Service Health Check"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
all_passed=true
|
||||
mode="${SMOKE_TEST_MODE:-auto}"
|
||||
summary_hint="make logs"
|
||||
|
||||
print_step() {
|
||||
echo "$1"
|
||||
}
|
||||
|
||||
check_http_status() {
|
||||
local name="$1"
|
||||
local url="$2"
|
||||
local expected_re="$3"
|
||||
local status
|
||||
|
||||
status="$(curl -s -o /dev/null -w "%{http_code}" "$url" 2>/dev/null)"
|
||||
if echo "$status" | grep -Eq "$expected_re"; then
|
||||
echo "✓ $name is accessible ($url -> $status)"
|
||||
else
|
||||
echo "✗ $name is not accessible ($url -> ${status:-000})"
|
||||
all_passed=false
|
||||
fi
|
||||
}
|
||||
|
||||
check_listen_port() {
|
||||
local name="$1"
|
||||
local port="$2"
|
||||
|
||||
if lsof -nP -iTCP:"$port" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "✓ $name is listening on port $port"
|
||||
else
|
||||
echo "✗ $name is not listening on port $port"
|
||||
all_passed=false
|
||||
fi
|
||||
}
|
||||
|
||||
docker_available() {
|
||||
command -v docker >/dev/null 2>&1 && docker info >/dev/null 2>&1
|
||||
}
|
||||
|
||||
detect_mode() {
|
||||
case "$mode" in
|
||||
local|docker)
|
||||
echo "$mode"
|
||||
return
|
||||
;;
|
||||
esac
|
||||
|
||||
if docker_available && docker ps --format "{{.Names}}" | grep -q "deer-flow"; then
|
||||
echo "docker"
|
||||
else
|
||||
echo "local"
|
||||
fi
|
||||
}
|
||||
|
||||
mode="$(detect_mode)"
|
||||
|
||||
echo "Deployment mode: $mode"
|
||||
echo ""
|
||||
|
||||
if [ "$mode" = "docker" ]; then
|
||||
summary_hint="make docker-logs"
|
||||
print_step "1. Checking container status..."
|
||||
if docker ps --format "{{.Names}}" | grep -q "deer-flow"; then
|
||||
echo "✓ Containers are running:"
|
||||
docker ps --format " - {{.Names}} ({{.Status}})"
|
||||
else
|
||||
echo "✗ No DeerFlow-related containers are running"
|
||||
all_passed=false
|
||||
fi
|
||||
else
|
||||
summary_hint="logs/{langgraph,gateway,frontend,nginx}.log"
|
||||
print_step "1. Checking local service ports..."
|
||||
check_listen_port "Nginx" 2026
|
||||
check_listen_port "Frontend" 3000
|
||||
check_listen_port "Gateway" 8001
|
||||
check_listen_port "LangGraph" 2024
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "2. Waiting for services to fully start (30 seconds)..."
|
||||
sleep 30
|
||||
echo ""
|
||||
|
||||
echo "3. Checking frontend service..."
|
||||
check_http_status "Frontend service" "http://localhost:2026" "200|301|302|307|308"
|
||||
echo ""
|
||||
|
||||
echo "4. Checking API Gateway..."
|
||||
health_response=$(curl -s http://localhost:2026/health 2>/dev/null)
|
||||
if [ $? -eq 0 ] && [ -n "$health_response" ]; then
|
||||
echo "✓ API Gateway health check passed"
|
||||
echo " Response: $health_response"
|
||||
else
|
||||
echo "✗ API Gateway health check failed"
|
||||
all_passed=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "5. Checking LangGraph service..."
|
||||
check_http_status "LangGraph service" "http://localhost:2024/" "200|301|302|307|308|404"
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
echo " Health Check Summary"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
if [ "$all_passed" = true ]; then
|
||||
echo "✅ All checks passed!"
|
||||
echo ""
|
||||
echo "🌐 Application URL: http://localhost:2026"
|
||||
exit 0
|
||||
else
|
||||
echo "❌ Some checks failed"
|
||||
echo ""
|
||||
echo "Please review: $summary_hint"
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,49 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Pulling the Latest Code"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check whether the current directory is a Git repository
|
||||
if [ ! -d ".git" ]; then
|
||||
echo "✗ The current directory is not a Git repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check Git status
|
||||
echo "Checking Git status..."
|
||||
if git status --porcelain | grep -q .; then
|
||||
echo "⚠ Uncommitted changes detected:"
|
||||
git status --short
|
||||
echo ""
|
||||
echo "Please commit or stash your changes before continuing"
|
||||
echo "Options:"
|
||||
echo " 1. git add . && git commit -m 'Save changes'"
|
||||
echo " 2. git stash (stash changes and restore them later)"
|
||||
echo " 3. git reset --hard HEAD (discard local changes - use with caution)"
|
||||
exit 1
|
||||
else
|
||||
echo "✓ Working tree is clean"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Fetch remote updates
|
||||
echo "Fetching remote updates..."
|
||||
git fetch origin main
|
||||
echo ""
|
||||
|
||||
# Pull the latest code
|
||||
echo "Pulling the latest code..."
|
||||
git pull origin main
|
||||
echo ""
|
||||
|
||||
# Show the latest commit
|
||||
echo "Latest commit:"
|
||||
git log -1 --oneline
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
echo " Code Update Complete"
|
||||
echo "=========================================="
|
||||
@@ -1,180 +0,0 @@
|
||||
# DeerFlow Smoke Test Report
|
||||
|
||||
**Test Date**: {{test_date}}
|
||||
**Test Environment**: {{test_environment}}
|
||||
**Deployment Mode**: Docker
|
||||
**Test Version**: {{git_commit}}
|
||||
|
||||
---
|
||||
|
||||
## Execution Summary
|
||||
|
||||
| Metric | Status |
|
||||
|------|------|
|
||||
| Total Test Phases | 6 |
|
||||
| Passed Phases | {{passed_stages}} |
|
||||
| Failed Phases | {{failed_stages}} |
|
||||
| Overall Conclusion | **{{overall_status}}** |
|
||||
|
||||
### Key Test Cases
|
||||
|
||||
| Case | Result | Details |
|
||||
|------|--------|---------|
|
||||
| Code update check | {{case_code_update}} | {{case_code_update_details}} |
|
||||
| Environment check | {{case_env_check}} | {{case_env_check_details}} |
|
||||
| Configuration preparation | {{case_config_prep}} | {{case_config_prep_details}} |
|
||||
| Deployment | {{case_deploy}} | {{case_deploy_details}} |
|
||||
| Health check | {{case_health_check}} | {{case_health_check_details}} |
|
||||
| Frontend routes | {{case_frontend_routes_overall}} | {{case_frontend_routes_details}} |
|
||||
|
||||
---
|
||||
|
||||
## Detailed Test Results
|
||||
|
||||
### Phase 1: Code Update Check
|
||||
|
||||
- [x] Confirm current directory - {{status_dir_check}}
|
||||
- [x] Check Git status - {{status_git_status}}
|
||||
- [x] Pull latest code - {{status_git_pull}}
|
||||
- [x] Confirm code update - {{status_git_verify}}
|
||||
|
||||
**Phase Status**: {{stage1_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Docker Environment Check
|
||||
|
||||
- [x] Docker version - {{status_docker_version}}
|
||||
- [x] Docker daemon - {{status_docker_daemon}}
|
||||
- [x] Docker Compose - {{status_docker_compose}}
|
||||
- [x] Port check - {{status_port_check}}
|
||||
|
||||
**Phase Status**: {{stage2_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: Configuration Preparation
|
||||
|
||||
- [x] config.yaml - {{status_config_yaml}}
|
||||
- [x] .env file - {{status_env_file}}
|
||||
- [x] Model configuration - {{status_model_config}}
|
||||
|
||||
**Phase Status**: {{stage3_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 4: Docker Deployment
|
||||
|
||||
- [x] docker-init - {{status_docker_init}}
|
||||
- [x] docker-start - {{status_docker_start}}
|
||||
- [x] Service startup wait - {{status_wait_startup}}
|
||||
|
||||
**Phase Status**: {{stage4_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 5: Service Health Check
|
||||
|
||||
- [x] Container status - {{status_containers}}
|
||||
- [x] Frontend service - {{status_frontend}}
|
||||
- [x] API Gateway - {{status_api_gateway}}
|
||||
- [x] LangGraph service - {{status_langgraph}}
|
||||
|
||||
**Phase Status**: {{stage5_status}}
|
||||
|
||||
---
|
||||
|
||||
### Frontend Routes Smoke Results
|
||||
|
||||
| Route | Status | Details |
|
||||
|-------|--------|---------|
|
||||
| Landing `/` | {{landing_status}} | {{landing_details}} |
|
||||
| Workspace redirect `/workspace` | {{workspace_redirect_status}} | target {{workspace_redirect_target}} |
|
||||
| New chat `/workspace/chats/new` | {{new_chat_status}} | {{new_chat_details}} |
|
||||
| Chats list `/workspace/chats` | {{chats_list_status}} | {{chats_list_details}} |
|
||||
| Agents gallery `/workspace/agents` | {{agents_gallery_status}} | {{agents_gallery_details}} |
|
||||
| Docs `{{docs_path}}` | {{docs_status}} | {{docs_details}} |
|
||||
|
||||
**Summary**: {{frontend_routes_summary}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 6: Test Report Generation
|
||||
|
||||
- [x] Result summary - {{status_summary}}
|
||||
- [x] Issue log - {{status_issues}}
|
||||
- [x] Report generation - {{status_report}}
|
||||
|
||||
**Phase Status**: {{stage6_status}}
|
||||
|
||||
---
|
||||
|
||||
## Issue Log
|
||||
|
||||
### Issue 1
|
||||
**Description**: {{issue1_description}}
|
||||
**Severity**: {{issue1_severity}}
|
||||
**Solution**: {{issue1_solution}}
|
||||
|
||||
---
|
||||
|
||||
## Environment Information
|
||||
|
||||
### Docker Version
|
||||
```text
|
||||
{{docker_version_output}}
|
||||
```
|
||||
|
||||
### Git Information
|
||||
```text
|
||||
Repository: {{git_repo}}
|
||||
Branch: {{git_branch}}
|
||||
Commit: {{git_commit}}
|
||||
Commit Message: {{git_commit_message}}
|
||||
```
|
||||
|
||||
### Configuration Summary
|
||||
- config.yaml exists: {{config_exists}}
|
||||
- .env file exists: {{env_exists}}
|
||||
- Number of configured models: {{model_count}}
|
||||
|
||||
---
|
||||
|
||||
## Container Status
|
||||
|
||||
| Container Name | Status | Uptime |
|
||||
|----------|------|----------|
|
||||
| deer-flow-nginx | {{nginx_status}} | {{nginx_uptime}} |
|
||||
| deer-flow-frontend | {{frontend_status}} | {{frontend_uptime}} |
|
||||
| deer-flow-gateway | {{gateway_status}} | {{gateway_uptime}} |
|
||||
| deer-flow-langgraph | {{langgraph_status}} | {{langgraph_uptime}} |
|
||||
|
||||
---
|
||||
|
||||
## Recommendations and Next Steps
|
||||
|
||||
### If the Test Passes
|
||||
1. [ ] Visit http://localhost:2026 to start using DeerFlow
|
||||
2. [ ] Configure your preferred model if it is not configured yet
|
||||
3. [ ] Explore available skills
|
||||
4. [ ] Refer to the documentation to learn more features
|
||||
|
||||
### If the Test Fails
|
||||
1. [ ] Review references/troubleshooting.md for common solutions
|
||||
2. [ ] Check Docker logs: `make docker-logs`
|
||||
3. [ ] Verify configuration file format and content
|
||||
4. [ ] If needed, fully reset the environment: `make clean && make config && make docker-init && make docker-start`
|
||||
|
||||
---
|
||||
|
||||
## Appendix
|
||||
|
||||
### Full Logs
|
||||
{{full_logs}}
|
||||
|
||||
### Tester
|
||||
{{tester_name}}
|
||||
|
||||
---
|
||||
|
||||
*Report generated at: {{report_time}}*
|
||||
@@ -1,185 +0,0 @@
|
||||
# DeerFlow Smoke Test Report
|
||||
|
||||
**Test Date**: {{test_date}}
|
||||
**Test Environment**: {{test_environment}}
|
||||
**Deployment Mode**: Local
|
||||
**Test Version**: {{git_commit}}
|
||||
|
||||
---
|
||||
|
||||
## Execution Summary
|
||||
|
||||
| Metric | Status |
|
||||
|------|------|
|
||||
| Total Test Phases | 6 |
|
||||
| Passed Phases | {{passed_stages}} |
|
||||
| Failed Phases | {{failed_stages}} |
|
||||
| Overall Conclusion | **{{overall_status}}** |
|
||||
|
||||
### Key Test Cases
|
||||
|
||||
| Case | Result | Details |
|
||||
|------|--------|---------|
|
||||
| Code update check | {{case_code_update}} | {{case_code_update_details}} |
|
||||
| Environment check | {{case_env_check}} | {{case_env_check_details}} |
|
||||
| Configuration preparation | {{case_config_prep}} | {{case_config_prep_details}} |
|
||||
| Deployment | {{case_deploy}} | {{case_deploy_details}} |
|
||||
| Health check | {{case_health_check}} | {{case_health_check_details}} |
|
||||
| Frontend routes | {{case_frontend_routes_overall}} | {{case_frontend_routes_details}} |
|
||||
|
||||
---
|
||||
|
||||
## Detailed Test Results
|
||||
|
||||
### Phase 1: Code Update Check
|
||||
|
||||
- [x] Confirm current directory - {{status_dir_check}}
|
||||
- [x] Check Git status - {{status_git_status}}
|
||||
- [x] Pull latest code - {{status_git_pull}}
|
||||
- [x] Confirm code update - {{status_git_verify}}
|
||||
|
||||
**Phase Status**: {{stage1_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Local Environment Check
|
||||
|
||||
- [x] Node.js version - {{status_node_version}}
|
||||
- [x] pnpm - {{status_pnpm}}
|
||||
- [x] uv - {{status_uv}}
|
||||
- [x] nginx - {{status_nginx}}
|
||||
- [x] Port check - {{status_port_check}}
|
||||
|
||||
**Phase Status**: {{stage2_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: Configuration Preparation
|
||||
|
||||
- [x] config.yaml - {{status_config_yaml}}
|
||||
- [x] .env file - {{status_env_file}}
|
||||
- [x] Model configuration - {{status_model_config}}
|
||||
|
||||
**Phase Status**: {{stage3_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 4: Local Deployment
|
||||
|
||||
- [x] make check - {{status_make_check}}
|
||||
- [x] make install - {{status_make_install}}
|
||||
- [x] make dev-daemon / make dev - {{status_local_start}}
|
||||
- [x] Service startup wait - {{status_wait_startup}}
|
||||
|
||||
**Phase Status**: {{stage4_status}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 5: Service Health Check
|
||||
|
||||
- [x] Process status - {{status_processes}}
|
||||
- [x] Frontend service - {{status_frontend}}
|
||||
- [x] API Gateway - {{status_api_gateway}}
|
||||
- [x] LangGraph service - {{status_langgraph}}
|
||||
|
||||
**Phase Status**: {{stage5_status}}
|
||||
|
||||
---
|
||||
|
||||
### Frontend Routes Smoke Results
|
||||
|
||||
| Route | Status | Details |
|
||||
|-------|--------|---------|
|
||||
| Landing `/` | {{landing_status}} | {{landing_details}} |
|
||||
| Workspace redirect `/workspace` | {{workspace_redirect_status}} | target {{workspace_redirect_target}} |
|
||||
| New chat `/workspace/chats/new` | {{new_chat_status}} | {{new_chat_details}} |
|
||||
| Chats list `/workspace/chats` | {{chats_list_status}} | {{chats_list_details}} |
|
||||
| Agents gallery `/workspace/agents` | {{agents_gallery_status}} | {{agents_gallery_details}} |
|
||||
| Docs `{{docs_path}}` | {{docs_status}} | {{docs_details}} |
|
||||
|
||||
**Summary**: {{frontend_routes_summary}}
|
||||
|
||||
---
|
||||
|
||||
### Phase 6: Test Report Generation
|
||||
|
||||
- [x] Result summary - {{status_summary}}
|
||||
- [x] Issue log - {{status_issues}}
|
||||
- [x] Report generation - {{status_report}}
|
||||
|
||||
**Phase Status**: {{stage6_status}}
|
||||
|
||||
---
|
||||
|
||||
## Issue Log
|
||||
|
||||
### Issue 1
|
||||
**Description**: {{issue1_description}}
|
||||
**Severity**: {{issue1_severity}}
|
||||
**Solution**: {{issue1_solution}}
|
||||
|
||||
---
|
||||
|
||||
## Environment Information
|
||||
|
||||
### Local Dependency Versions
|
||||
```text
|
||||
Node.js: {{node_version_output}}
|
||||
pnpm: {{pnpm_version_output}}
|
||||
uv: {{uv_version_output}}
|
||||
nginx: {{nginx_version_output}}
|
||||
```
|
||||
|
||||
### Git Information
|
||||
```text
|
||||
Repository: {{git_repo}}
|
||||
Branch: {{git_branch}}
|
||||
Commit: {{git_commit}}
|
||||
Commit Message: {{git_commit_message}}
|
||||
```
|
||||
|
||||
### Configuration Summary
|
||||
- config.yaml exists: {{config_exists}}
|
||||
- .env file exists: {{env_exists}}
|
||||
- Number of configured models: {{model_count}}
|
||||
|
||||
---
|
||||
|
||||
## Local Service Status
|
||||
|
||||
| Service | Status | Endpoint |
|
||||
|---------|--------|----------|
|
||||
| Nginx | {{nginx_status}} | {{nginx_endpoint}} |
|
||||
| Frontend | {{frontend_status}} | {{frontend_endpoint}} |
|
||||
| Gateway | {{gateway_status}} | {{gateway_endpoint}} |
|
||||
| LangGraph | {{langgraph_status}} | {{langgraph_endpoint}} |
|
||||
|
||||
---
|
||||
|
||||
## Recommendations and Next Steps
|
||||
|
||||
### If the Test Passes
|
||||
1. [ ] Visit http://localhost:2026 to start using DeerFlow
|
||||
2. [ ] Configure your preferred model if it is not configured yet
|
||||
3. [ ] Explore available skills
|
||||
4. [ ] Refer to the documentation to learn more features
|
||||
|
||||
### If the Test Fails
|
||||
1. [ ] Review references/troubleshooting.md for common solutions
|
||||
2. [ ] Check local logs: `logs/{langgraph,gateway,frontend,nginx}.log`
|
||||
3. [ ] Verify configuration file format and content
|
||||
4. [ ] If needed, fully reset the environment: `make stop && make clean && make install && make dev-daemon`
|
||||
|
||||
---
|
||||
|
||||
## Appendix
|
||||
|
||||
### Full Logs
|
||||
{{full_logs}}
|
||||
|
||||
### Tester
|
||||
{{tester_name}}
|
||||
|
||||
---
|
||||
|
||||
*Report generated at: {{report_time}}*
|
||||
@@ -1,128 +0,0 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
willem.jiang@gmail.com.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
@@ -77,18 +77,6 @@ export UV_INDEX_URL=https://pypi.org/simple
|
||||
export NPM_REGISTRY=https://registry.npmjs.org
|
||||
```
|
||||
|
||||
#### Recommended host resources
|
||||
|
||||
Use these as practical starting points for development and review environments:
|
||||
|
||||
| Scenario | Starting point | Recommended | Notes |
|
||||
|---------|-----------|------------|-------|
|
||||
| `make dev` on one machine | 4 vCPU, 8 GB RAM | 8 vCPU, 16 GB RAM | Best when DeerFlow uses hosted model APIs. |
|
||||
| `make docker-start` review environment | 4 vCPU, 8 GB RAM | 8 vCPU, 16 GB RAM | Docker image builds and sandbox containers need extra headroom. |
|
||||
| Shared Linux test server | 8 vCPU, 16 GB RAM | 16 vCPU, 32 GB RAM | Prefer this for heavier multi-agent runs or multiple reviewers. |
|
||||
|
||||
`2 vCPU / 4 GB` environments often fail to start reliably or become unresponsive under normal DeerFlow workloads.
|
||||
|
||||
#### Linux: Docker daemon permission denied
|
||||
|
||||
If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket:
|
||||
|
||||
@@ -1,25 +1,19 @@
|
||||
# DeerFlow - Unified Development Environment
|
||||
|
||||
.PHONY: help config config-upgrade check install setup doctor dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
||||
.PHONY: help config config-upgrade check install dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
||||
|
||||
BASH ?= bash
|
||||
BACKEND_UV_RUN = cd backend && uv run
|
||||
|
||||
# Detect OS for Windows compatibility
|
||||
ifeq ($(OS),Windows_NT)
|
||||
SHELL := cmd.exe
|
||||
PYTHON ?= python
|
||||
# Run repo shell scripts through Git Bash when Make is launched from cmd.exe / PowerShell.
|
||||
RUN_WITH_GIT_BASH = call scripts\run-with-git-bash.cmd
|
||||
else
|
||||
PYTHON ?= python3
|
||||
RUN_WITH_GIT_BASH =
|
||||
endif
|
||||
|
||||
help:
|
||||
@echo "DeerFlow Development Commands:"
|
||||
@echo " make setup - Interactive setup wizard (recommended for new users)"
|
||||
@echo " make doctor - Check configuration and system requirements"
|
||||
@echo " make config - Generate local config files (aborts if config already exists)"
|
||||
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
|
||||
@echo " make check - Check if all required tools are installed"
|
||||
@@ -50,18 +44,11 @@ help:
|
||||
@echo " make docker-logs-frontend - View Docker frontend logs"
|
||||
@echo " make docker-logs-gateway - View Docker gateway logs"
|
||||
|
||||
## Setup & Diagnosis
|
||||
setup:
|
||||
@$(BACKEND_UV_RUN) python ../scripts/setup_wizard.py
|
||||
|
||||
doctor:
|
||||
@$(BACKEND_UV_RUN) python ../scripts/doctor.py
|
||||
|
||||
config:
|
||||
@$(PYTHON) ./scripts/configure.py
|
||||
|
||||
config-upgrade:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/config-upgrade.sh
|
||||
@./scripts/config-upgrade.sh
|
||||
|
||||
# Check required tools
|
||||
check:
|
||||
@@ -119,46 +106,78 @@ setup-sandbox:
|
||||
# Start all services in development mode (with hot-reloading)
|
||||
dev:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev
|
||||
else
|
||||
@./scripts/serve.sh --dev
|
||||
endif
|
||||
|
||||
# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
|
||||
dev-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway
|
||||
else
|
||||
@./scripts/serve.sh --dev --gateway
|
||||
endif
|
||||
|
||||
# Start all services in production mode (with optimizations)
|
||||
start:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod
|
||||
else
|
||||
@./scripts/serve.sh --prod
|
||||
endif
|
||||
|
||||
# Start all services in prod + Gateway mode (experimental)
|
||||
start-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway
|
||||
else
|
||||
@./scripts/serve.sh --prod --gateway
|
||||
endif
|
||||
|
||||
# Start all services in daemon mode (background)
|
||||
dev-daemon:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --daemon
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --daemon
|
||||
else
|
||||
@./scripts/serve.sh --dev --daemon
|
||||
endif
|
||||
|
||||
# Start daemon + Gateway mode (experimental)
|
||||
dev-daemon-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway --daemon
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway --daemon
|
||||
else
|
||||
@./scripts/serve.sh --dev --gateway --daemon
|
||||
endif
|
||||
|
||||
# Start prod services in daemon mode (background)
|
||||
start-daemon:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --daemon
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --daemon
|
||||
else
|
||||
@./scripts/serve.sh --prod --daemon
|
||||
endif
|
||||
|
||||
# Start prod daemon + Gateway mode (experimental)
|
||||
start-daemon-pro:
|
||||
@$(PYTHON) ./scripts/check.py
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway --daemon
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway --daemon
|
||||
else
|
||||
@./scripts/serve.sh --prod --gateway --daemon
|
||||
endif
|
||||
|
||||
# Stop all services
|
||||
stop:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --stop
|
||||
@./scripts/serve.sh --stop
|
||||
|
||||
# Clean up
|
||||
clean: stop
|
||||
@@ -174,29 +193,29 @@ clean: stop
|
||||
|
||||
# Initialize Docker containers and install dependencies
|
||||
docker-init:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh init
|
||||
@./scripts/docker.sh init
|
||||
|
||||
# Start Docker development environment
|
||||
docker-start:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start
|
||||
@./scripts/docker.sh start
|
||||
|
||||
# Start Docker in Gateway mode (experimental)
|
||||
docker-start-pro:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start --gateway
|
||||
@./scripts/docker.sh start --gateway
|
||||
|
||||
# Stop Docker development environment
|
||||
docker-stop:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh stop
|
||||
@./scripts/docker.sh stop
|
||||
|
||||
# View Docker development logs
|
||||
docker-logs:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh logs
|
||||
@./scripts/docker.sh logs
|
||||
|
||||
# View Docker development logs
|
||||
docker-logs-frontend:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh logs --frontend
|
||||
@./scripts/docker.sh logs --frontend
|
||||
docker-logs-gateway:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh logs --gateway
|
||||
@./scripts/docker.sh logs --gateway
|
||||
|
||||
# ==========================================
|
||||
# Production Docker Commands
|
||||
@@ -204,12 +223,12 @@ docker-logs-gateway:
|
||||
|
||||
# Build and start production services
|
||||
up:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh
|
||||
@./scripts/deploy.sh
|
||||
|
||||
# Build and start production services in Gateway mode
|
||||
up-pro:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh --gateway
|
||||
@./scripts/deploy.sh --gateway
|
||||
|
||||
# Stop and remove production containers
|
||||
down:
|
||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh down
|
||||
@./scripts/deploy.sh down
|
||||
|
||||
@@ -53,7 +53,6 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe
|
||||
- [Quick Start](#quick-start)
|
||||
- [Configuration](#configuration)
|
||||
- [Running the Application](#running-the-application)
|
||||
- [Deployment Sizing](#deployment-sizing)
|
||||
- [Option 1: Docker (Recommended)](#option-1-docker-recommended)
|
||||
- [Option 2: Local Development](#option-2-local-development)
|
||||
- [Advanced](#advanced)
|
||||
@@ -104,38 +103,35 @@ That prompt is intended for coding agents. It tells the agent to clone the repo
|
||||
cd deer-flow
|
||||
```
|
||||
|
||||
2. **Run the setup wizard**
|
||||
2. **Generate local configuration files**
|
||||
|
||||
From the project root directory (`deer-flow/`), run:
|
||||
|
||||
```bash
|
||||
make setup
|
||||
make config
|
||||
```
|
||||
|
||||
This launches an interactive wizard that guides you through choosing an LLM provider, optional web search, and execution/safety preferences such as sandbox mode, bash access, and file-write tools. It generates a minimal `config.yaml` and writes your keys to `.env`. Takes about 2 minutes.
|
||||
This command creates local configuration files based on the provided example templates.
|
||||
|
||||
The wizard also lets you configure an optional web search provider, or skip it for now.
|
||||
3. **Configure your preferred model(s)**
|
||||
|
||||
Run `make doctor` at any time to verify your setup and get actionable fix hints.
|
||||
|
||||
> **Advanced / manual configuration**: If you prefer to edit `config.yaml` directly, run `make config` instead to copy the full template. See `config.example.yaml` for the complete reference including CLI-backed providers (Codex CLI, Claude Code OAuth), OpenRouter, Responses API, and more.
|
||||
|
||||
<details>
|
||||
<summary>Manual model configuration examples</summary>
|
||||
Edit `config.yaml` and define at least one model:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- name: gpt-4o
|
||||
display_name: GPT-4o
|
||||
use: langchain_openai:ChatOpenAI
|
||||
model: gpt-4o
|
||||
api_key: $OPENAI_API_KEY
|
||||
- name: gpt-4 # Internal identifier
|
||||
display_name: GPT-4 # Human-readable name
|
||||
use: langchain_openai:ChatOpenAI # LangChain class path
|
||||
model: gpt-4 # Model identifier for API
|
||||
api_key: $OPENAI_API_KEY # API key (recommended: use env var)
|
||||
max_tokens: 4096 # Maximum tokens per request
|
||||
temperature: 0.7 # Sampling temperature
|
||||
|
||||
- name: openrouter-gemini-2.5-flash
|
||||
display_name: Gemini 2.5 Flash (OpenRouter)
|
||||
use: langchain_openai:ChatOpenAI
|
||||
model: google/gemini-2.5-flash-preview
|
||||
api_key: $OPENROUTER_API_KEY
|
||||
api_key: $OPENAI_API_KEY # OpenRouter still uses the OpenAI-compatible field name here
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
|
||||
- name: gpt-5-responses
|
||||
@@ -185,39 +181,50 @@ That prompt is intended for coding agents. It tells the agent to clone the repo
|
||||
```
|
||||
|
||||
- Codex CLI reads `~/.codex/auth.json`
|
||||
- Claude Code accepts `CLAUDE_CODE_OAUTH_TOKEN`, `ANTHROPIC_AUTH_TOKEN`, `CLAUDE_CODE_CREDENTIALS_PATH`, or `~/.claude/.credentials.json`
|
||||
- ACP agent entries are separate from model providers — if you configure `acp_agents.codex`, point it at a Codex ACP adapter such as `npx -y @zed-industries/codex-acp`
|
||||
- On macOS, export Claude Code auth explicitly if needed:
|
||||
- The Codex Responses endpoint currently rejects `max_tokens` and `max_output_tokens`, so `CodexChatModel` does not expose a request-level token cap
|
||||
- Claude Code accepts `CLAUDE_CODE_OAUTH_TOKEN`, `ANTHROPIC_AUTH_TOKEN`, `CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR`, `CLAUDE_CODE_CREDENTIALS_PATH`, or plaintext `~/.claude/.credentials.json`
|
||||
- ACP agent entries are separate from model providers. If you configure `acp_agents.codex`, point it at a Codex ACP adapter such as `npx -y @zed-industries/codex-acp`; the standard `codex` CLI binary is not ACP-compatible by itself
|
||||
- On macOS, DeerFlow does not probe Keychain automatically. Export Claude Code auth explicitly if needed:
|
||||
|
||||
```bash
|
||||
eval "$(python3 scripts/export_claude_code_oauth.py --print-export)"
|
||||
```
|
||||
|
||||
4. **Set API keys for your configured model(s)**
|
||||
|
||||
Choose one of the following methods:
|
||||
|
||||
- Option A: Edit the `.env` file in the project root (Recommended)
|
||||
|
||||
API keys can also be set manually in `.env` (recommended) or exported in your shell:
|
||||
|
||||
```bash
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
TAVILY_API_KEY=your-tavily-api-key
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
# OpenRouter also uses OPENAI_API_KEY when your config uses langchain_openai:ChatOpenAI + base_url.
|
||||
# Add other provider keys as needed
|
||||
INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
```
|
||||
|
||||
</details>
|
||||
- Option B: Export environment variables in your shell
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=your-openai-api-key
|
||||
```
|
||||
|
||||
For CLI-backed providers:
|
||||
- Codex CLI: `~/.codex/auth.json`
|
||||
- Claude Code OAuth: explicit env/file handoff or `~/.claude/.credentials.json`
|
||||
|
||||
- Option C: Edit `config.yaml` directly (Not recommended for production)
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- name: gpt-4
|
||||
api_key: your-actual-api-key-here # Replace placeholder
|
||||
```
|
||||
|
||||
### Running the Application
|
||||
|
||||
#### Deployment Sizing
|
||||
|
||||
Use the table below as a practical starting point when choosing how to run DeerFlow:
|
||||
|
||||
| Deployment target | Starting point | Recommended | Notes |
|
||||
|---------|-----------|------------|-------|
|
||||
| Local evaluation / `make dev` | 4 vCPU, 8 GB RAM, 20 GB free SSD | 8 vCPU, 16 GB RAM | Good for one developer or one light session with hosted model APIs. `2 vCPU / 4 GB` is usually not enough. |
|
||||
| Docker development / `make docker-start` | 4 vCPU, 8 GB RAM, 25 GB free SSD | 8 vCPU, 16 GB RAM | Image builds, bind mounts, and sandbox containers need more headroom than pure local dev. |
|
||||
| Long-running server / `make up` | 8 vCPU, 16 GB RAM, 40 GB free SSD | 16 vCPU, 32 GB RAM | Preferred for shared use, multi-agent runs, report generation, or heavier sandbox workloads. |
|
||||
|
||||
- These numbers cover DeerFlow itself. If you also host a local LLM, size that service separately.
|
||||
- Linux plus Docker is the recommended deployment target for a persistent server. macOS and Windows are best treated as development or evaluation environments.
|
||||
- If CPU or memory usage stays pinned, reduce concurrent runs first, then move to the next sizing tier.
|
||||
|
||||
#### Option 1: Docker (Recommended)
|
||||
|
||||
**Development** (hot-reload, source mounts):
|
||||
@@ -254,7 +261,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||
|
||||
If you prefer running services locally:
|
||||
|
||||
Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root (can be overridden via `DEER_FLOW_CONFIG_PATH`). Run `make doctor` to verify your setup before starting.
|
||||
Prerequisite: complete the "Configuration" steps above first (`make config` and model API keys). `make dev` requires a valid configuration file (defaults to `config.yaml` in the project root; can be overridden via `DEER_FLOW_CONFIG_PATH`).
|
||||
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
|
||||
|
||||
1. **Check prerequisites**:
|
||||
@@ -368,7 +375,6 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
|
||||
| Telegram | Bot API (long-polling) | Easy |
|
||||
| Slack | Socket Mode | Moderate |
|
||||
| Feishu / Lark | WebSocket | Moderate |
|
||||
| WeChat | Tencent iLink (long-polling) | Moderate |
|
||||
| WeCom | WebSocket | Moderate |
|
||||
|
||||
**Configuration in `config.yaml`:**
|
||||
@@ -413,19 +419,6 @@ channels:
|
||||
bot_token: $TELEGRAM_BOT_TOKEN
|
||||
allowed_users: [] # empty = allow all
|
||||
|
||||
wechat:
|
||||
enabled: false
|
||||
bot_token: $WECHAT_BOT_TOKEN
|
||||
ilink_bot_id: $WECHAT_ILINK_BOT_ID
|
||||
qrcode_login_enabled: true # optional: allow first-time QR bootstrap when bot_token is absent
|
||||
allowed_users: [] # empty = allow all
|
||||
polling_timeout: 35
|
||||
state_dir: ./.deer-flow/wechat/state
|
||||
max_inbound_image_bytes: 20971520
|
||||
max_outbound_image_bytes: 20971520
|
||||
max_inbound_file_bytes: 52428800
|
||||
max_outbound_file_bytes: 52428800
|
||||
|
||||
# Optional: per-channel / per-user session settings
|
||||
session:
|
||||
assistant_id: mobile-agent # custom agent names are also supported here
|
||||
@@ -459,10 +452,6 @@ SLACK_APP_TOKEN=xapp-...
|
||||
FEISHU_APP_ID=cli_xxxx
|
||||
FEISHU_APP_SECRET=your_app_secret
|
||||
|
||||
# WeChat iLink
|
||||
WECHAT_BOT_TOKEN=your_ilink_bot_token
|
||||
WECHAT_ILINK_BOT_ID=your_ilink_bot_id
|
||||
|
||||
# WeCom
|
||||
WECOM_BOT_ID=your_bot_id
|
||||
WECOM_BOT_SECRET=your_bot_secret
|
||||
@@ -488,14 +477,6 @@ WECOM_BOT_SECRET=your_bot_secret
|
||||
3. Under **Events**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
|
||||
4. Copy the App ID and App Secret. Set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env` and enable the channel in `config.yaml`.
|
||||
|
||||
**WeChat Setup**
|
||||
|
||||
1. Enable the `wechat` channel in `config.yaml`.
|
||||
2. Either set `WECHAT_BOT_TOKEN` in `.env`, or set `qrcode_login_enabled: true` for first-time QR bootstrap.
|
||||
3. When `bot_token` is absent and QR bootstrap is enabled, watch backend logs for the QR content returned by iLink and complete the binding flow.
|
||||
4. After the QR flow succeeds, DeerFlow persists the acquired token under `state_dir` for later restarts.
|
||||
5. For Docker Compose deployments, keep `state_dir` on a persistent volume so the `get_updates_buf` cursor and saved auth state survive restarts.
|
||||
|
||||
**WeCom Setup**
|
||||
|
||||
1. Create a bot on the WeCom AI Bot platform and obtain the `bot_id` and `bot_secret`.
|
||||
|
||||
@@ -40,7 +40,6 @@ https://github.com/user-attachments/assets/a8bcadc4-e040-4cf2-8fda-dd768b999c18
|
||||
- [快速开始](#快速开始)
|
||||
- [配置](#配置)
|
||||
- [运行应用](#运行应用)
|
||||
- [部署建议与资源规划](#部署建议与资源规划)
|
||||
- [方式一:Docker(推荐)](#方式一docker推荐)
|
||||
- [方式二:本地开发](#方式二本地开发)
|
||||
- [进阶配置](#进阶配置)
|
||||
@@ -151,20 +150,6 @@ https://github.com/user-attachments/assets/a8bcadc4-e040-4cf2-8fda-dd768b999c18
|
||||
|
||||
### 运行应用
|
||||
|
||||
#### 部署建议与资源规划
|
||||
|
||||
可以先按下面的资源档位来选择 DeerFlow 的运行方式:
|
||||
|
||||
| 部署场景 | 起步配置 | 推荐配置 | 说明 |
|
||||
|---------|-----------|------------|-------|
|
||||
| 本地体验 / `make dev` | 4 vCPU、8 GB 内存、20 GB SSD 可用空间 | 8 vCPU、16 GB 内存 | 适合单个开发者或单个轻量会话,且模型走外部 API。`2 核 / 4 GB` 通常跑不稳。 |
|
||||
| Docker 开发 / `make docker-start` | 4 vCPU、8 GB 内存、25 GB SSD 可用空间 | 8 vCPU、16 GB 内存 | 镜像构建、源码挂载和 sandbox 容器都会比纯本地模式更吃资源。 |
|
||||
| 长期运行服务 / `make up` | 8 vCPU、16 GB 内存、40 GB SSD 可用空间 | 16 vCPU、32 GB 内存 | 更适合共享环境、多 agent 任务、报告生成或更重的 sandbox 负载。 |
|
||||
|
||||
- 上面的配置只覆盖 DeerFlow 本身;如果你还要本机部署本地大模型,请单独为模型服务预留资源。
|
||||
- 持续运行的服务更推荐使用 Linux + Docker。macOS 和 Windows 更适合作为开发机或体验环境。
|
||||
- 如果 CPU 或内存长期打满,先降低并发会话或重任务数量,再考虑升级到更高一档配置。
|
||||
|
||||
#### 方式一:Docker(推荐)
|
||||
|
||||
**开发模式**(支持热更新,挂载源码):
|
||||
|
||||
+13
-27
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
|
||||
|
||||
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||
@@ -216,9 +216,6 @@ FastAPI application on port 8001 with health check at `GET /health`.
|
||||
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
||||
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
||||
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||
|
||||
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||
|
||||
@@ -232,7 +229,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
|
||||
**Virtual Path System**:
|
||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||
- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||
|
||||
@@ -272,7 +269,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
||||
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
||||
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||
- `image_search/` - Image search via DuckDuckGo
|
||||
|
||||
### MCP System (`packages/harness/deerflow/mcp/`)
|
||||
@@ -341,27 +338,18 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
|
||||
|
||||
**Components**:
|
||||
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
|
||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
|
||||
- `prompt.py` - Prompt templates for memory updates
|
||||
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
|
||||
|
||||
**Per-User Isolation**:
|
||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||
- Absolute `storage_path` in config opts out of per-user isolation
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||
|
||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||
**Data Structure** (stored in `backend/.deer-flow/memory.json`):
|
||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
||||
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
||||
|
||||
**Workflow**:
|
||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
|
||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
|
||||
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
||||
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
|
||||
3. Background thread invokes LLM to extract context updates and facts
|
||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||
|
||||
@@ -369,7 +357,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
||||
|
||||
**Configuration** (`config.yaml` → `memory`):
|
||||
- `enabled` / `injection_enabled` - Master switches
|
||||
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
|
||||
- `storage_path` - Path to memory.json
|
||||
- `debounce_seconds` - Wait time before processing (default: 30)
|
||||
- `model_name` - LLM for updates (null = default model)
|
||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||
@@ -407,16 +395,14 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
|
||||
**Architecture**: Imports the same `deerflow` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
|
||||
|
||||
**Agent Conversation** (replaces LangGraph Server):
|
||||
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
|
||||
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
|
||||
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
|
||||
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
|
||||
- `"custom"` — forwarded from `StreamWriter`
|
||||
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
|
||||
- `chat(message, thread_id)` — synchronous, returns final text
|
||||
- `stream(message, thread_id)` — yields `StreamEvent` aligned with LangGraph SSE protocol:
|
||||
- `"values"` — full state snapshot (title, messages, artifacts)
|
||||
- `"messages-tuple"` — per-message update (AI text, tool calls, tool results)
|
||||
- `"end"` — stream finished
|
||||
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
|
||||
- Supports `checkpointer` parameter for state persistence across turns
|
||||
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
||||
- See [docs/STREAMING.md](docs/STREAMING.md) for the full design: why Gateway and DeerFlowClient are parallel paths, LangGraph's `stream_mode` semantics, the per-id dedup invariants, and regression testing strategy
|
||||
|
||||
**Gateway Equivalent Methods** (replaces Gateway API):
|
||||
|
||||
|
||||
+1
-1
@@ -88,4 +88,4 @@ COPY --from=builder /app/backend ./backend
|
||||
EXPOSE 8001 2024
|
||||
|
||||
# Default command (can be overridden in docker-compose)
|
||||
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run --no-sync uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
||||
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ gateway:
|
||||
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||
|
||||
test:
|
||||
PYTHONPATH=. uv run pytest tests/unittest -v
|
||||
PYTHONPATH=. uv run pytest tests/ -v
|
||||
|
||||
lint:
|
||||
uvx ruff check .
|
||||
|
||||
@@ -9,12 +9,10 @@ import re
|
||||
import threading
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.plugins.auth.security.actor_context import bind_user_actor_context
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -299,35 +297,15 @@ class FeishuChannel(Channel):
|
||||
text = msg.text
|
||||
for file in files:
|
||||
if file.get("image_key"):
|
||||
virtual_path = await self._receive_single_file(
|
||||
msg.thread_ts,
|
||||
file["image_key"],
|
||||
"image",
|
||||
thread_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
virtual_path = await self._receive_single_file(msg.thread_ts, file["image_key"], "image", thread_id)
|
||||
text = text.replace("[image]", virtual_path, 1)
|
||||
elif file.get("file_key"):
|
||||
virtual_path = await self._receive_single_file(
|
||||
msg.thread_ts,
|
||||
file["file_key"],
|
||||
"file",
|
||||
thread_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
virtual_path = await self._receive_single_file(msg.thread_ts, file["file_key"], "file", thread_id)
|
||||
text = text.replace("[file]", virtual_path, 1)
|
||||
msg.text = text
|
||||
return msg
|
||||
|
||||
async def _receive_single_file(
|
||||
self,
|
||||
message_id: str,
|
||||
file_key: str,
|
||||
type: Literal["image", "file"],
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> str:
|
||||
async def _receive_single_file(self, message_id: str, file_key: str, type: Literal["image", "file"], thread_id: str) -> str:
|
||||
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
|
||||
|
||||
def inner():
|
||||
@@ -366,51 +344,49 @@ class FeishuChannel(Channel):
|
||||
return f"Failed to obtain the [{type}]"
|
||||
|
||||
paths = get_paths()
|
||||
with bind_user_actor_context(user_id):
|
||||
effective_user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=effective_user_id)
|
||||
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=effective_user_id).resolve()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
|
||||
|
||||
ext = "png" if type == "image" else "bin"
|
||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||
ext = "png" if type == "image" else "bin"
|
||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||
|
||||
# Sanitize filename: preserve extension, replace path chars in name part
|
||||
if "." in raw_filename:
|
||||
name_part, ext = raw_filename.rsplit(".", 1)
|
||||
name_part = re.sub(r"[./\\]", "_", name_part)
|
||||
filename = f"{name_part}.{ext}"
|
||||
else:
|
||||
filename = re.sub(r"[./\\]", "_", raw_filename)
|
||||
resolved_target = uploads_dir / filename
|
||||
# Sanitize filename: preserve extension, replace path chars in name part
|
||||
if "." in raw_filename:
|
||||
name_part, ext = raw_filename.rsplit(".", 1)
|
||||
name_part = re.sub(r"[./\\]", "_", name_part)
|
||||
filename = f"{name_part}.{ext}"
|
||||
else:
|
||||
filename = re.sub(r"[./\\]", "_", raw_filename)
|
||||
resolved_target = uploads_dir / filename
|
||||
|
||||
def down_load():
|
||||
# use thread_lock to avoid filename conflicts when writing
|
||||
with self._thread_lock:
|
||||
resolved_target.write_bytes(content)
|
||||
def down_load():
|
||||
# use thread_lock to avoid filename conflicts when writing
|
||||
with self._thread_lock:
|
||||
resolved_target.write_bytes(content)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(down_load)
|
||||
except Exception:
|
||||
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
|
||||
return f"Failed to obtain the [{type}]"
|
||||
try:
|
||||
await asyncio.to_thread(down_load)
|
||||
except Exception:
|
||||
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
|
||||
return f"Failed to obtain the [{type}]"
|
||||
|
||||
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
|
||||
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
|
||||
|
||||
try:
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
if sandbox_id != "local":
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
if sandbox is None:
|
||||
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
|
||||
return f"Failed to obtain the [{type}]"
|
||||
sandbox.update_file(virtual_path, content)
|
||||
except Exception:
|
||||
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
|
||||
return f"Failed to obtain the [{type}]"
|
||||
try:
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
if sandbox_id != "local":
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
if sandbox is None:
|
||||
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
|
||||
return f"Failed to obtain the [{type}]"
|
||||
sandbox.update_file(virtual_path, content)
|
||||
except Exception:
|
||||
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
|
||||
return f"Failed to obtain the [{type}]"
|
||||
|
||||
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
|
||||
return virtual_path
|
||||
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
|
||||
return virtual_path
|
||||
|
||||
# -- message formatting ------------------------------------------------
|
||||
|
||||
|
||||
@@ -8,17 +8,14 @@ import mimetypes
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langgraph_sdk.errors import ConflictError
|
||||
|
||||
from app.plugins.auth.security.actor_context import bind_user_actor_context
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.store import ChannelStore
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,7 +37,6 @@ CHANNEL_CAPABILITIES = {
|
||||
"feishu": {"supports_streaming": True},
|
||||
"slack": {"supports_streaming": False},
|
||||
"telegram": {"supports_streaming": False},
|
||||
"wechat": {"supports_streaming": False},
|
||||
"wecom": {"supports_streaming": True},
|
||||
}
|
||||
|
||||
@@ -82,24 +78,7 @@ async def _read_wecom_inbound_file(file_info: dict[str, Any], client: httpx.Asyn
|
||||
return decrypt_file(data, aeskey)
|
||||
|
||||
|
||||
async def _read_wechat_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
|
||||
raw_path = file_info.get("path")
|
||||
if isinstance(raw_path, str) and raw_path.strip():
|
||||
try:
|
||||
return await asyncio.to_thread(Path(raw_path).read_bytes)
|
||||
except OSError:
|
||||
logger.exception("[Manager] failed to read WeChat inbound file from local path: %s", raw_path)
|
||||
return None
|
||||
|
||||
full_url = file_info.get("full_url")
|
||||
if isinstance(full_url, str) and full_url.strip():
|
||||
return await _read_http_inbound_file({"url": full_url}, client)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
register_inbound_file_reader("wecom", _read_wecom_inbound_file)
|
||||
register_inbound_file_reader("wechat", _read_wechat_inbound_file)
|
||||
|
||||
|
||||
class InvalidChannelSessionConfigError(ValueError):
|
||||
@@ -329,7 +308,7 @@ def _format_artifact_text(artifacts: list[str]) -> str:
|
||||
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
||||
|
||||
|
||||
def _resolve_attachments(thread_id: str, artifacts: list[str], *, user_id: str | None = None) -> list[ResolvedAttachment]:
|
||||
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
||||
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
||||
|
||||
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
|
||||
@@ -343,40 +322,38 @@ def _resolve_attachments(thread_id: str, artifacts: list[str], *, user_id: str |
|
||||
|
||||
attachments: list[ResolvedAttachment] = []
|
||||
paths = get_paths()
|
||||
with bind_user_actor_context(user_id):
|
||||
effective_user_id = get_effective_user_id()
|
||||
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=effective_user_id).resolve()
|
||||
for virtual_path in artifacts:
|
||||
# Security: only allow files from the agent outputs directory
|
||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||
continue
|
||||
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
|
||||
for virtual_path in artifacts:
|
||||
# Security: only allow files from the agent outputs directory
|
||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||
continue
|
||||
try:
|
||||
actual = paths.resolve_virtual_path(thread_id, virtual_path)
|
||||
# Verify the resolved path is actually under the outputs directory
|
||||
# (guards against path-traversal even after prefix check)
|
||||
try:
|
||||
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=effective_user_id)
|
||||
# Verify the resolved path is actually under the outputs directory
|
||||
# (guards against path-traversal even after prefix check)
|
||||
try:
|
||||
actual.resolve().relative_to(outputs_dir)
|
||||
except ValueError:
|
||||
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
|
||||
continue
|
||||
if not actual.is_file():
|
||||
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
|
||||
continue
|
||||
mime, _ = mimetypes.guess_type(str(actual))
|
||||
mime = mime or "application/octet-stream"
|
||||
attachments.append(
|
||||
ResolvedAttachment(
|
||||
virtual_path=virtual_path,
|
||||
actual_path=actual,
|
||||
filename=actual.name,
|
||||
mime_type=mime,
|
||||
size=actual.stat().st_size,
|
||||
is_image=mime.startswith("image/"),
|
||||
)
|
||||
actual.resolve().relative_to(outputs_dir)
|
||||
except ValueError:
|
||||
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
|
||||
continue
|
||||
if not actual.is_file():
|
||||
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
|
||||
continue
|
||||
mime, _ = mimetypes.guess_type(str(actual))
|
||||
mime = mime or "application/octet-stream"
|
||||
attachments.append(
|
||||
ResolvedAttachment(
|
||||
virtual_path=virtual_path,
|
||||
actual_path=actual,
|
||||
filename=actual.name,
|
||||
mime_type=mime,
|
||||
size=actual.stat().st_size,
|
||||
is_image=mime.startswith("image/"),
|
||||
)
|
||||
except (ValueError, OSError) as exc:
|
||||
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
|
||||
)
|
||||
except (ValueError, OSError) as exc:
|
||||
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
|
||||
return attachments
|
||||
|
||||
|
||||
@@ -384,15 +361,13 @@ def _prepare_artifact_delivery(
|
||||
thread_id: str,
|
||||
response_text: str,
|
||||
artifacts: list[str],
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, list[ResolvedAttachment]]:
|
||||
"""Resolve attachments and append filename fallbacks to the text response."""
|
||||
attachments: list[ResolvedAttachment] = []
|
||||
if not artifacts:
|
||||
return response_text, attachments
|
||||
|
||||
attachments = _resolve_attachments(thread_id, artifacts, user_id=user_id)
|
||||
attachments = _resolve_attachments(thread_id, artifacts)
|
||||
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
|
||||
unresolved = [path for path in artifacts if path not in resolved_virtuals]
|
||||
|
||||
@@ -415,8 +390,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
|
||||
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||
|
||||
with bind_user_actor_context(msg.user_id):
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||
|
||||
created: list[dict[str, Any]] = []
|
||||
@@ -750,12 +724,7 @@ class ChannelManager:
|
||||
len(artifacts),
|
||||
)
|
||||
|
||||
response_text, attachments = _prepare_artifact_delivery(
|
||||
thread_id,
|
||||
response_text,
|
||||
artifacts,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
||||
|
||||
if not response_text:
|
||||
if attachments:
|
||||
@@ -846,12 +815,7 @@ class ChannelManager:
|
||||
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
||||
response_text = _extract_response_text(result)
|
||||
artifacts = _extract_artifacts(result)
|
||||
response_text, attachments = _prepare_artifact_delivery(
|
||||
thread_id,
|
||||
response_text,
|
||||
artifacts,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
||||
|
||||
if not response_text:
|
||||
if attachments:
|
||||
|
||||
@@ -18,7 +18,6 @@ _CHANNEL_REGISTRY: dict[str, str] = {
|
||||
"feishu": "app.channels.feishu:FeishuChannel",
|
||||
"slack": "app.channels.slack:SlackChannel",
|
||||
"telegram": "app.channels.telegram:TelegramChannel",
|
||||
"wechat": "app.channels.wechat:WechatChannel",
|
||||
"wecom": "app.channels.wecom:WeComChannel",
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,4 @@
|
||||
from __future__ import annotations
|
||||
from .app import app, create_app
|
||||
from .config import GatewayConfig, get_gateway_config
|
||||
|
||||
__all__ = ["GatewayConfig", "app", "get_gateway_config", "register_app"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name == "app":
|
||||
from .app import app
|
||||
|
||||
return app
|
||||
if name == "GatewayConfig":
|
||||
from .config import GatewayConfig
|
||||
|
||||
return GatewayConfig
|
||||
if name == "get_gateway_config":
|
||||
from .config import get_gateway_config
|
||||
|
||||
return get_gateway_config
|
||||
if name == "register_app":
|
||||
from .registrar import register_app
|
||||
|
||||
return register_app
|
||||
raise AttributeError(name)
|
||||
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]
|
||||
|
||||
+385
-4
@@ -1,8 +1,389 @@
|
||||
from app.gateway.registrar import register_app
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
from app.gateway.routers import (
|
||||
agents,
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
memory,
|
||||
models,
|
||||
runs,
|
||||
skills,
|
||||
suggestions,
|
||||
thread_runs,
|
||||
threads,
|
||||
uploads,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app():
|
||||
return register_app()
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
"""Auto-create the admin user on first boot if no users exist.
|
||||
|
||||
After admin creation, migrate orphan threads from the LangGraph
|
||||
store (metadata.owner_id unset) to the admin account. This is the
|
||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||
authentication have existing LangGraph thread data that needs an
|
||||
owner assigned.
|
||||
|
||||
No SQL persistence migration is needed: the four owner_id columns
|
||||
(threads_meta, runs, run_events, feedback) only come into existence
|
||||
alongside the auth module via create_all, so freshly created tables
|
||||
never contain NULL-owner rows. "Existing persistence DB + new auth"
|
||||
is not a supported upgrade path — fresh install or wipe-and-retry.
|
||||
|
||||
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve
|
||||
races during admin creation. Only the worker that successfully
|
||||
creates/updates the admin prints the password; losers silently skip.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
provider = get_local_provider()
|
||||
user_count = await provider.count_users()
|
||||
|
||||
admin = None
|
||||
fresh_admin_created = False
|
||||
|
||||
if user_count == 0:
|
||||
password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
||||
fresh_admin_created = True
|
||||
except ValueError:
|
||||
return # Another worker already created the admin.
|
||||
else:
|
||||
# Admin exists but setup never completed — reset password so operator
|
||||
# can always find it in the console without needing the CLI.
|
||||
# Multi-worker guard: if admin was created less than 30s ago, another
|
||||
# worker just created it and will print the password — skip reset.
|
||||
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
||||
if admin and admin.needs_setup:
|
||||
import time
|
||||
|
||||
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
||||
if age >= 30:
|
||||
from app.gateway.auth.credential_file import write_initial_credentials
|
||||
from app.gateway.auth.password import hash_password_async
|
||||
|
||||
password = secrets.token_urlsafe(16)
|
||||
admin.password_hash = await hash_password_async(password)
|
||||
admin.token_version += 1
|
||||
await provider.update_user(admin)
|
||||
|
||||
cred_path = write_initial_credentials(admin.email, password, label="reset")
|
||||
logger.info("=" * 60)
|
||||
logger.info(" Admin account setup incomplete — password reset")
|
||||
logger.info(" Credentials written to: %s (mode 0600)", cred_path)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if admin is None:
|
||||
return # Nothing to bind orphans to.
|
||||
|
||||
admin_id = str(admin.id)
|
||||
|
||||
# LangGraph store orphan migration — non-fatal.
|
||||
# This covers the "no-auth → with-auth" upgrade path for users
|
||||
# whose existing LangGraph thread metadata has no owner_id set.
|
||||
store = getattr(app.state, "store", None)
|
||||
if store is not None:
|
||||
try:
|
||||
migrated = await _migrate_orphaned_threads(store, admin_id)
|
||||
if migrated:
|
||||
logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
|
||||
except Exception:
|
||||
logger.exception("LangGraph thread migration failed (non-fatal)")
|
||||
|
||||
if fresh_admin_created:
|
||||
from app.gateway.auth.credential_file import write_initial_credentials
|
||||
|
||||
cred_path = write_initial_credentials(admin.email, password, label="initial") # noqa: F821 — defined in the fresh_admin branch
|
||||
logger.info("=" * 60)
|
||||
logger.info(" Admin account created on first boot")
|
||||
logger.info(" Credentials written to: %s (mode 0600)", cred_path)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
app = register_app()
|
||||
async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
||||
"""Paginated async iterator over a LangGraph store namespace.
|
||||
|
||||
Replaces the old hardcoded ``limit=1000`` call with a cursor-style
|
||||
loop so that environments with more than one page of orphans do
|
||||
not silently lose data. Terminates when a page is empty OR when a
|
||||
short page arrives (indicating the last page).
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
batch = await store.asearch(namespace, limit=page_size, offset=offset)
|
||||
if not batch:
|
||||
return
|
||||
for item in batch:
|
||||
yield item
|
||||
if len(batch) < page_size:
|
||||
return
|
||||
offset += page_size
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
"""Migrate LangGraph store threads with no owner_id to the given admin.
|
||||
|
||||
Uses cursor pagination so all orphans are migrated regardless of
|
||||
count. Returns the number of rows migrated.
|
||||
"""
|
||||
migrated = 0
|
||||
async for item in _iter_store_items(store, ("threads",)):
|
||||
metadata = item.value.get("metadata", {})
|
||||
if not metadata.get("owner_id"):
|
||||
metadata["owner_id"] = admin_user_id
|
||||
item.value["metadata"] = metadata
|
||||
await store.aput(("threads",), item.key, item.value)
|
||||
migrated += 1
|
||||
return migrated
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan handler."""
|
||||
|
||||
# Load config and check necessary environment variables at startup
|
||||
try:
|
||||
get_app_config()
|
||||
logger.info("Configuration loaded successfully")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||
logger.exception(error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
config = get_gateway_config()
|
||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||
|
||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||
async with langgraph_runtime(app):
|
||||
logger.info("LangGraph runtime initialised")
|
||||
|
||||
# Ensure admin user exists (auto-create on first boot)
|
||||
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
|
||||
await _ensure_admin_user(app)
|
||||
|
||||
# Start IM channel service if any channels are configured
|
||||
try:
|
||||
from app.channels.service import start_channel_service
|
||||
|
||||
channel_service = await start_channel_service()
|
||||
logger.info("Channel service started: %s", channel_service.get_status())
|
||||
except Exception:
|
||||
logger.exception("No IM channels configured or channel service failed to start")
|
||||
|
||||
yield
|
||||
|
||||
# Stop channel service on shutdown
|
||||
try:
|
||||
from app.channels.service import stop_channel_service
|
||||
|
||||
await stop_channel_service()
|
||||
except Exception:
|
||||
logger.exception("Failed to stop channel service")
|
||||
|
||||
logger.info("Shutting down API Gateway")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application instance.
|
||||
"""
|
||||
|
||||
app = FastAPI(
|
||||
title="DeerFlow API Gateway",
|
||||
description="""
|
||||
## DeerFlow API Gateway
|
||||
|
||||
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
|
||||
|
||||
### Features
|
||||
|
||||
- **Models Management**: Query and retrieve available AI models
|
||||
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
|
||||
- **Memory Management**: Access and manage global memory data for personalized conversations
|
||||
- **Skills Management**: Query and manage skills and their enabled status
|
||||
- **Artifacts**: Access thread artifacts and generated files
|
||||
- **Health Monitoring**: System health check endpoints
|
||||
|
||||
### Architecture
|
||||
|
||||
LangGraph requests are handled by nginx reverse proxy.
|
||||
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
|
||||
""",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "models",
|
||||
"description": "Operations for querying available AI models and their configurations",
|
||||
},
|
||||
{
|
||||
"name": "mcp",
|
||||
"description": "Manage Model Context Protocol (MCP) server configurations",
|
||||
},
|
||||
{
|
||||
"name": "memory",
|
||||
"description": "Access and manage global memory data for personalized conversations",
|
||||
},
|
||||
{
|
||||
"name": "skills",
|
||||
"description": "Manage skills and their configurations",
|
||||
},
|
||||
{
|
||||
"name": "artifacts",
|
||||
"description": "Access and download thread artifacts and generated files",
|
||||
},
|
||||
{
|
||||
"name": "uploads",
|
||||
"description": "Upload and manage user files for threads",
|
||||
},
|
||||
{
|
||||
"name": "threads",
|
||||
"description": "Manage DeerFlow thread-local filesystem data",
|
||||
},
|
||||
{
|
||||
"name": "agents",
|
||||
"description": "Create and manage custom agents with per-agent config and prompts",
|
||||
},
|
||||
{
|
||||
"name": "suggestions",
|
||||
"description": "Generate follow-up question suggestions for conversations",
|
||||
},
|
||||
{
|
||||
"name": "channels",
|
||||
"description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
|
||||
},
|
||||
{
|
||||
"name": "assistants-compat",
|
||||
"description": "LangGraph Platform-compatible assistants API (stub)",
|
||||
},
|
||||
{
|
||||
"name": "runs",
|
||||
"description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
|
||||
},
|
||||
{
|
||||
"name": "health",
|
||||
"description": "Health check and system status endpoints",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# CSRF: Double Submit Cookie pattern for state-changing requests
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
|
||||
# In production, nginx handles CORS and no middleware is needed.
|
||||
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
|
||||
if cors_origins_env:
|
||||
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
|
||||
# Validate: wildcard origin with credentials is a security misconfiguration
|
||||
for origin in cors_origins:
|
||||
if origin == "*":
|
||||
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
|
||||
cors_origins = [o for o in cors_origins if o != "*"]
|
||||
break
|
||||
if cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
# Models API is mounted at /api/models
|
||||
app.include_router(models.router)
|
||||
|
||||
# MCP API is mounted at /api/mcp
|
||||
app.include_router(mcp.router)
|
||||
|
||||
# Memory API is mounted at /api/memory
|
||||
app.include_router(memory.router)
|
||||
|
||||
# Skills API is mounted at /api/skills
|
||||
app.include_router(skills.router)
|
||||
|
||||
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
|
||||
app.include_router(artifacts.router)
|
||||
|
||||
# Uploads API is mounted at /api/threads/{thread_id}/uploads
|
||||
app.include_router(uploads.router)
|
||||
|
||||
# Thread cleanup API is mounted at /api/threads/{thread_id}
|
||||
app.include_router(threads.router)
|
||||
|
||||
# Agents API is mounted at /api/agents
|
||||
app.include_router(agents.router)
|
||||
|
||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||
app.include_router(suggestions.router)
|
||||
|
||||
# Channels API is mounted at /api/channels
|
||||
app.include_router(channels.router)
|
||||
|
||||
# Assistants compatibility API (LangGraph Platform stub)
|
||||
app.include_router(assistants_compat.router)
|
||||
|
||||
# Auth API is mounted at /api/v1/auth
|
||||
app.include_router(auth.router)
|
||||
|
||||
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||
app.include_router(feedback.router)
|
||||
|
||||
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
|
||||
app.include_router(thread_runs.router)
|
||||
|
||||
# Stateless Runs API (stream/wait without a pre-existing thread)
|
||||
app.include_router(runs.router)
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health_check() -> dict:
|
||||
"""Health check endpoint.
|
||||
|
||||
Returns:
|
||||
Service health status information.
|
||||
"""
|
||||
return {"status": "healthy", "service": "deer-flow-gateway"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Create app instance for uvicorn
|
||||
app = create_app()
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Authentication module for DeerFlow.
|
||||
|
||||
This module provides:
|
||||
- JWT-based authentication
|
||||
- Provider Factory pattern for extensible auth methods
|
||||
- UserRepository interface for storage backends (SQLite)
|
||||
"""
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.models import User, UserResponse
|
||||
from app.gateway.auth.password import hash_password, verify_password
|
||||
from app.gateway.auth.providers import AuthProvider
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"AuthConfig",
|
||||
"get_auth_config",
|
||||
"set_auth_config",
|
||||
# Errors
|
||||
"AuthErrorCode",
|
||||
"AuthErrorResponse",
|
||||
"TokenError",
|
||||
# JWT
|
||||
"TokenPayload",
|
||||
"create_access_token",
|
||||
"decode_token",
|
||||
# Password
|
||||
"hash_password",
|
||||
"verify_password",
|
||||
# Models
|
||||
"User",
|
||||
"UserResponse",
|
||||
# Providers
|
||||
"AuthProvider",
|
||||
"LocalAuthProvider",
|
||||
# Repository
|
||||
"UserRepository",
|
||||
]
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Authentication configuration for DeerFlow."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration. Parsed once at startup.
|
||||
|
||||
Note: the ``users`` table now lives in the shared persistence
|
||||
database managed by ``deerflow.persistence.engine``. The old
|
||||
``users_db_path`` config key has been removed — user storage is
|
||||
configured through ``config.database`` like every other table.
|
||||
"""
|
||||
|
||||
jwt_secret: str = Field(
|
||||
...,
|
||||
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
|
||||
)
|
||||
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
||||
oauth_github_client_id: str | None = Field(default=None)
|
||||
oauth_github_client_secret: str | None = Field(default=None)
|
||||
|
||||
|
||||
_auth_config: AuthConfig | None = None
|
||||
|
||||
|
||||
def get_auth_config() -> AuthConfig:
|
||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||
global _auth_config
|
||||
if _auth_config is None:
|
||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||
if not jwt_secret:
|
||||
jwt_secret = secrets.token_urlsafe(32)
|
||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||
logger.warning(
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||
"Sessions will be invalidated on restart. "
|
||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
_auth_config = AuthConfig(jwt_secret=jwt_secret)
|
||||
return _auth_config
|
||||
|
||||
|
||||
def set_auth_config(config: AuthConfig) -> None:
|
||||
"""Set the global AuthConfig instance (for testing)."""
|
||||
global _auth_config
|
||||
_auth_config = config
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Write initial admin credentials to a restricted file instead of logs.
|
||||
|
||||
Logging secrets to stdout/stderr is a well-known CodeQL finding
|
||||
(py/clear-text-logging-sensitive-data) — in production those logs
|
||||
get collected into ELK/Splunk/etc and become a secret sprawl
|
||||
source. This helper writes the credential to a 0600 file that only
|
||||
the process user can read, and returns the path so the caller can
|
||||
log **the path** (not the password) for the operator to pick up.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
_CREDENTIAL_FILE = Path(".deer-flow") / "admin_initial_credentials.txt"
|
||||
|
||||
|
||||
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
|
||||
"""Write the admin email + password to ``.deer-flow/admin_initial_credentials.txt``.
|
||||
|
||||
Creates the parent directory if it does not exist. Sets the file
|
||||
mode to 0600 so only the owning process user can read it.
|
||||
|
||||
``label`` distinguishes "initial" (fresh creation) from "reset"
|
||||
(password reset) in the file header, so an operator picking up
|
||||
the file after a restart can tell which event produced it.
|
||||
|
||||
Returns the absolute :class:`Path` to the file.
|
||||
"""
|
||||
_CREDENTIAL_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
content = (
|
||||
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
|
||||
)
|
||||
_CREDENTIAL_FILE.write_text(content)
|
||||
os.chmod(_CREDENTIAL_FILE, 0o600)
|
||||
return _CREDENTIAL_FILE.resolve()
|
||||
@@ -1,4 +1,9 @@
|
||||
"""Typed error definitions for auth plugin."""
|
||||
"""Typed error definitions for auth module.
|
||||
|
||||
AuthErrorCode: exhaustive enum of all auth failure conditions.
|
||||
TokenError: exhaustive enum of JWT decode failures.
|
||||
AuthErrorResponse: structured error payload for HTTP responses.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
@@ -6,6 +11,8 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthErrorCode(StrEnum):
|
||||
"""Exhaustive list of auth error conditions."""
|
||||
|
||||
INVALID_CREDENTIALS = "invalid_credentials"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
TOKEN_INVALID = "token_invalid"
|
||||
@@ -13,21 +20,25 @@ class AuthErrorCode(StrEnum):
|
||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||
NOT_AUTHENTICATED = "not_authenticated"
|
||||
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
||||
|
||||
|
||||
class TokenError(StrEnum):
|
||||
"""Exhaustive list of JWT decode failure reasons."""
|
||||
|
||||
EXPIRED = "expired"
|
||||
INVALID_SIGNATURE = "invalid_signature"
|
||||
MALFORMED = "malformed"
|
||||
|
||||
|
||||
class AuthErrorResponse(BaseModel):
|
||||
"""Structured error response — replaces bare `detail` strings."""
|
||||
|
||||
code: AuthErrorCode
|
||||
message: str
|
||||
|
||||
|
||||
def token_error_to_code(err: TokenError) -> AuthErrorCode:
|
||||
"""Map TokenError to AuthErrorCode — single source of truth."""
|
||||
if err == TokenError.EXPIRED:
|
||||
return AuthErrorCode.TOKEN_EXPIRED
|
||||
return AuthErrorCode.TOKEN_INVALID
|
||||
@@ -5,26 +5,44 @@ from datetime import UTC, datetime, timedelta
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
from app.plugins.auth.runtime.config_state import get_auth_config
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str
|
||||
"""JWT token payload."""
|
||||
|
||||
sub: str # user_id
|
||||
exp: datetime
|
||||
iat: datetime | None = None
|
||||
ver: int = 0
|
||||
ver: int = 0 # token_version — must match User.token_version
|
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID as string
|
||||
expires_delta: Optional custom expiry, defaults to 7 days
|
||||
token_version: User's current token_version for invalidation
|
||||
|
||||
Returns:
|
||||
Encoded JWT string
|
||||
"""
|
||||
config = get_auth_config()
|
||||
expiry = expires_delta or timedelta(days=config.token_expiry_days)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
|
||||
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload | TokenError:
|
||||
"""Decode and validate a JWT token.
|
||||
|
||||
Returns:
|
||||
TokenPayload if valid, or a specific TokenError variant.
|
||||
"""
|
||||
config = get_auth_config()
|
||||
try:
|
||||
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Local email/password authentication provider."""
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
from app.gateway.auth.providers import AuthProvider
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
|
||||
class LocalAuthProvider(AuthProvider):
|
||||
"""Email/password authentication provider using local database."""
|
||||
|
||||
def __init__(self, repository: UserRepository):
|
||||
"""Initialize with a UserRepository.
|
||||
|
||||
Args:
|
||||
repository: UserRepository implementation (SQLite)
|
||||
"""
|
||||
self._repo = repository
|
||||
|
||||
async def authenticate(self, credentials: dict) -> User | None:
|
||||
"""Authenticate with email and password.
|
||||
|
||||
Args:
|
||||
credentials: dict with 'email' and 'password' keys
|
||||
|
||||
Returns:
|
||||
User if authentication succeeds, None otherwise
|
||||
"""
|
||||
email = credentials.get("email")
|
||||
password = credentials.get("password")
|
||||
|
||||
if not email or not password:
|
||||
return None
|
||||
|
||||
user = await self._repo.get_user_by_email(email)
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
if user.password_hash is None:
|
||||
# OAuth user without local password
|
||||
return None
|
||||
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
async def get_user(self, user_id: str) -> User | None:
|
||||
"""Get user by ID."""
|
||||
return await self._repo.get_user_by_id(user_id)
|
||||
|
||||
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
|
||||
"""Create a new local user.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
password: Plain text password (will be hashed)
|
||||
system_role: Role to assign ("admin" or "user")
|
||||
needs_setup: If True, user must complete setup on first login
|
||||
|
||||
Returns:
|
||||
Created User instance
|
||||
"""
|
||||
password_hash = await hash_password_async(password) if password else None
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
system_role=system_role,
|
||||
needs_setup=needs_setup,
|
||||
)
|
||||
return await self._repo.create_user(user)
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID."""
|
||||
return await self._repo.get_user_by_oauth(provider, oauth_id)
|
||||
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
return await self._repo.count_users()
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update_user(user)
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email."""
|
||||
return await self._repo.get_user_by_email(email)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""User Pydantic models for the auth plugin."""
|
||||
"""User Pydantic models for authentication."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
@@ -8,10 +8,13 @@ from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""Return current UTC time (timezone-aware)."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""Internal user representation."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, description="Primary key")
|
||||
@@ -19,13 +22,19 @@ class User(BaseModel):
|
||||
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
|
||||
system_role: Literal["admin", "user"] = Field(default="user")
|
||||
created_at: datetime = Field(default_factory=_utc_now)
|
||||
|
||||
# OAuth linkage (optional)
|
||||
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
|
||||
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
||||
|
||||
# Auth lifecycle
|
||||
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
||||
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for user info endpoint."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
system_role: Literal["admin", "user"]
|
||||
@@ -6,16 +6,28 @@ import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
|
||||
async def hash_password_async(password: str) -> str:
|
||||
"""Hash a password using bcrypt (non-blocking).
|
||||
|
||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop during password hashing.
|
||||
"""
|
||||
return await asyncio.to_thread(hash_password, password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash (non-blocking).
|
||||
|
||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop during password verification.
|
||||
"""
|
||||
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Auth provider abstraction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, credentials: dict) -> "User | None":
|
||||
"""Authenticate user with given credentials.
|
||||
|
||||
Returns User if authentication succeeds, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(self, user_id: str) -> "User | None":
|
||||
"""Retrieve user by ID."""
|
||||
...
|
||||
|
||||
|
||||
# Import User at runtime to avoid circular imports
|
||||
from app.gateway.auth.models import User # noqa: E402
|
||||
@@ -0,0 +1,82 @@
|
||||
"""User repository interface for abstracting database operations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
|
||||
class UserRepository(ABC):
|
||||
"""Abstract interface for user data storage.
|
||||
|
||||
Implement this interface to support different storage backends
|
||||
(SQLite)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
user: User object to create
|
||||
|
||||
Returns:
|
||||
Created User with ID assigned
|
||||
|
||||
Raises:
|
||||
ValueError: If email already exists
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID.
|
||||
|
||||
Args:
|
||||
user_id: User UUID as string
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user.
|
||||
|
||||
Args:
|
||||
user: User object with updated fields
|
||||
|
||||
Returns:
|
||||
Updated User
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (e.g. 'github', 'google')
|
||||
oauth_id: User ID from the OAuth provider
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,116 @@
|
||||
"""SQLAlchemy-backed UserRepository implementation.
|
||||
|
||||
Uses the shared async session factory from
|
||||
``deerflow.persistence.engine`` — the ``users`` table lives in the
|
||||
same database as ``threads_meta``, ``runs``, ``run_events``, and
|
||||
``feedback``.
|
||||
|
||||
Constructor takes the session factory directly (same pattern as the
|
||||
other four repositories in ``deerflow.persistence.*``). Callers
|
||||
construct this after ``init_engine_from_config()`` has run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
|
||||
class SQLiteUserRepository(UserRepository):
|
||||
"""Async user repository backed by the shared SQLAlchemy engine."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
# ── Converters ────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: UserRow) -> User:
|
||||
return User(
|
||||
id=UUID(row.id),
|
||||
email=row.email,
|
||||
password_hash=row.password_hash,
|
||||
system_role=row.system_role, # type: ignore[arg-type]
|
||||
# SQLite loses tzinfo on read; reattach UTC so downstream
|
||||
# code can compare timestamps reliably.
|
||||
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
|
||||
oauth_provider=row.oauth_provider,
|
||||
oauth_id=row.oauth_id,
|
||||
needs_setup=row.needs_setup,
|
||||
token_version=row.token_version,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _user_to_row(user: User) -> UserRow:
|
||||
return UserRow(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
system_role=user.system_role,
|
||||
created_at=user.created_at,
|
||||
oauth_provider=user.oauth_provider,
|
||||
oauth_id=user.oauth_id,
|
||||
needs_setup=user.needs_setup,
|
||||
token_version=user.token_version,
|
||||
)
|
||||
|
||||
# ── CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
|
||||
row = self._user_to_row(user)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as exc:
|
||||
await session.rollback()
|
||||
raise ValueError(f"Email already registered: {user.email}") from exc
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(UserRow, user_id)
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
stmt = select(UserRow).where(UserRow.email == email)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(UserRow, str(user.id))
|
||||
if row is None:
|
||||
return user
|
||||
row.email = user.email
|
||||
row.password_hash = user.password_hash
|
||||
row.system_role = user.system_role
|
||||
row.oauth_provider = user.oauth_provider
|
||||
row.oauth_id = user.oauth_id
|
||||
row.needs_setup = user.needs_setup
|
||||
row.token_version = user.token_version
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
async def count_users(self) -> int:
|
||||
stmt = select(func.count()).select_from(UserRow)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
@@ -0,0 +1,91 @@
|
||||
"""CLI tool to reset an admin password.
|
||||
|
||||
Usage:
|
||||
python -m app.gateway.auth.reset_admin
|
||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
||||
|
||||
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
|
||||
(mode 0600) instead of printing it, so CI / log aggregators never see
|
||||
the cleartext secret.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.gateway.auth.credential_file import write_initial_credentials
|
||||
from app.gateway.auth.password import hash_password
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
|
||||
async def _run(email: str | None) -> int:
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import (
|
||||
close_engine,
|
||||
get_session_factory,
|
||||
init_engine_from_config,
|
||||
)
|
||||
|
||||
config = get_app_config()
|
||||
await init_engine_from_config(config.database)
|
||||
try:
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
repo = SQLiteUserRepository(sf)
|
||||
|
||||
if email:
|
||||
user = await repo.get_user_by_email(email)
|
||||
else:
|
||||
# Find first admin via direct SELECT — repository does not
|
||||
# expose a "first admin" helper and we do not want to add
|
||||
# one just for this CLI.
|
||||
async with sf() as session:
|
||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
user = None
|
||||
else:
|
||||
user = await repo.get_user_by_id(row.id)
|
||||
|
||||
if user is None:
|
||||
if email:
|
||||
print(f"Error: user '{email}' not found.", file=sys.stderr)
|
||||
else:
|
||||
print("Error: no admin user found.", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
user.password_hash = hash_password(new_password)
|
||||
user.token_version += 1
|
||||
user.needs_setup = True
|
||||
await repo.update_user(user)
|
||||
|
||||
cred_path = write_initial_credentials(user.email, new_password, label="reset")
|
||||
print(f"Password reset for: {user.email}")
|
||||
print(f"Credentials written to: {cred_path} (mode 0600)")
|
||||
print("Next login will require setup (new email + password).")
|
||||
return 0
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Reset admin password")
|
||||
parser.add_argument("--email", help="Admin email (default: first admin found)")
|
||||
args = parser.parse_args()
|
||||
|
||||
exit_code = asyncio.run(_run(args.email))
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,117 @@
|
||||
"""Global authentication middleware — fail-closed safety net.
|
||||
|
||||
Rejects unauthenticated requests to non-public paths with 401. When a
|
||||
request passes the cookie check, resolves the JWT payload to a real
|
||||
``User`` object and stamps it into both ``request.state.user`` and the
|
||||
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
||||
owner filtering works automatically via the sentinel pattern.
|
||||
|
||||
Fine-grained permission checks remain in authz.py decorators.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.errors import AuthErrorCode
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
# Paths that never require authentication.
|
||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
)
|
||||
|
||||
# Exact auth paths that are public (login/register/status check).
|
||||
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
|
||||
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_public(path: str) -> bool:
|
||||
stripped = path.rstrip("/")
|
||||
if stripped in _PUBLIC_EXACT_PATHS:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Strict auth gate: reject requests without a valid session.
|
||||
|
||||
Two-stage check for non-public paths:
|
||||
|
||||
1. Cookie presence — return 401 NOT_AUTHENTICATED if missing
|
||||
2. JWT validation via ``get_optional_user_from_request`` — return 401
|
||||
TOKEN_INVALID if the token is absent, malformed, expired, or the
|
||||
signed user does not exist / is stale
|
||||
|
||||
On success, stamps ``request.state.user`` and the
|
||||
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
||||
owner filters work downstream without every route needing a
|
||||
``@require_auth`` decorator. Routes that need per-resource
|
||||
authorization (e.g. "user A cannot read user B's thread by guessing
|
||||
the URL") should additionally use ``@require_permission(...,
|
||||
owner_check=True)`` for explicit enforcement — but authentication
|
||||
itself is fully handled here.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
if _is_public(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Non-public path: require session cookie
|
||||
if not request.cookies.get("access_token"):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"detail": {
|
||||
"code": AuthErrorCode.NOT_AUTHENTICATED,
|
||||
"message": "Authentication required",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Strict JWT validation: reject junk/expired tokens with 401
|
||||
# right here instead of silently passing through. This closes
|
||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||
# without this, non-isolation routes like /api/models would
|
||||
# accept any cookie-shaped string as authentication.
|
||||
#
|
||||
# We call the *strict* resolver so that fine-grained error
|
||||
# codes (token_expired, token_invalid, user_not_found, …)
|
||||
# propagate from AuthErrorCode, not get flattened into one
|
||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
||||
# bubble up, so we catch and render it as JSONResponse here.
|
||||
#
|
||||
# On success we stamp request.state.user and the contextvar
|
||||
# so repository-layer owner filters work downstream without
|
||||
# every route needing a decorator.
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
try:
|
||||
user = await get_current_user_from_request(request)
|
||||
except HTTPException as exc:
|
||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
|
||||
request.state.user = user
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
return await call_next(request)
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
@@ -0,0 +1,269 @@
|
||||
"""Authorization decorators and context for DeerFlow.
|
||||
|
||||
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
|
||||
|
||||
**Usage:**
|
||||
|
||||
1. Use ``@require_auth`` on routes that need authentication
|
||||
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
|
||||
3. The decorator chain processes from bottom to top
|
||||
|
||||
**Example:**
|
||||
|
||||
@router.get("/{thread_id}")
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
# User is authenticated and has threads:read permission
|
||||
...
|
||||
|
||||
**Permission Model:**
|
||||
|
||||
- threads:read - View thread
|
||||
- threads:write - Create/update thread
|
||||
- threads:delete - Delete thread
|
||||
- runs:create - Run agent
|
||||
- runs:read - View run
|
||||
- runs:cancel - Cancel run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Permission constants
|
||||
class Permissions:
|
||||
"""Permission constants for resource:action format."""
|
||||
|
||||
# Threads
|
||||
THREADS_READ = "threads:read"
|
||||
THREADS_WRITE = "threads:write"
|
||||
THREADS_DELETE = "threads:delete"
|
||||
|
||||
# Runs
|
||||
RUNS_CREATE = "runs:create"
|
||||
RUNS_READ = "runs:read"
|
||||
RUNS_CANCEL = "runs:cancel"
|
||||
|
||||
|
||||
class AuthContext:
|
||||
"""Authentication context for the current request.
|
||||
|
||||
Stored in request.state.auth after require_auth decoration.
|
||||
|
||||
Attributes:
|
||||
user: The authenticated user, or None if anonymous
|
||||
permissions: List of permission strings (e.g., "threads:read")
|
||||
"""
|
||||
|
||||
__slots__ = ("user", "permissions")
|
||||
|
||||
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
||||
self.user = user
|
||||
self.permissions = permissions or []
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if user is authenticated."""
|
||||
return self.user is not None
|
||||
|
||||
def has_permission(self, resource: str, action: str) -> bool:
|
||||
"""Check if context has permission for resource:action.
|
||||
|
||||
Args:
|
||||
resource: Resource name (e.g., "threads")
|
||||
action: Action name (e.g., "read")
|
||||
|
||||
Returns:
|
||||
True if user has permission
|
||||
"""
|
||||
permission = f"{resource}:{action}"
|
||||
return permission in self.permissions
|
||||
|
||||
def require_user(self) -> User:
|
||||
"""Get user or raise 401.
|
||||
|
||||
Raises:
|
||||
HTTPException 401 if not authenticated
|
||||
"""
|
||||
if not self.user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return self.user
|
||||
|
||||
|
||||
def get_auth_context(request: Request) -> AuthContext | None:
|
||||
"""Get AuthContext from request state."""
|
||||
return getattr(request.state, "auth", None)
|
||||
|
||||
|
||||
_ALL_PERMISSIONS: list[str] = [
|
||||
Permissions.THREADS_READ,
|
||||
Permissions.THREADS_WRITE,
|
||||
Permissions.THREADS_DELETE,
|
||||
Permissions.RUNS_CREATE,
|
||||
Permissions.RUNS_READ,
|
||||
Permissions.RUNS_CANCEL,
|
||||
]
|
||||
|
||||
|
||||
async def _authenticate(request: Request) -> AuthContext:
|
||||
"""Authenticate request and return AuthContext.
|
||||
|
||||
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
|
||||
Returns AuthContext with user=None for anonymous requests.
|
||||
"""
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
if user is None:
|
||||
return AuthContext(user=None, permissions=[])
|
||||
|
||||
# In future, permissions could be stored in user record
|
||||
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||
|
||||
|
||||
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||
"""Decorator that authenticates the request and sets AuthContext.
|
||||
|
||||
Must be placed ABOVE other decorators (executes after them).
|
||||
|
||||
Usage:
|
||||
@router.get("/{thread_id}")
|
||||
@require_auth # Bottom decorator (executes first after permission check)
|
||||
@require_permission("threads", "read")
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
auth: AuthContext = request.state.auth
|
||||
...
|
||||
|
||||
Raises:
|
||||
ValueError: If 'request' parameter is missing
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = kwargs.get("request")
|
||||
if request is None:
|
||||
raise ValueError("require_auth decorator requires 'request' parameter")
|
||||
|
||||
# Authenticate and set context
|
||||
auth_context = await _authenticate(request)
|
||||
request.state.auth = auth_context
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_permission(
|
||||
resource: str,
|
||||
action: str,
|
||||
owner_check: bool = False,
|
||||
owner_filter_key: str = "owner_id",
|
||||
inject_record: bool = False,
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
"""Decorator that checks permission for resource:action.
|
||||
|
||||
Must be used AFTER @require_auth.
|
||||
|
||||
Args:
|
||||
resource: Resource name (e.g., "threads", "runs")
|
||||
action: Action name (e.g., "read", "write", "delete")
|
||||
owner_check: If True, validates that the current user owns the resource.
|
||||
Requires 'thread_id' path parameter and performs ownership check.
|
||||
owner_filter_key: Field name for ownership filter (default: "owner_id")
|
||||
inject_record: If True and owner_check is True, injects the thread record
|
||||
into kwargs['thread_record'] for use in the handler.
|
||||
|
||||
Usage:
|
||||
# Simple permission check
|
||||
@require_permission("threads", "read")
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
...
|
||||
|
||||
# With ownership check (for /threads/{thread_id} endpoints)
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_thread(thread_id: str, request: Request):
|
||||
...
|
||||
|
||||
# With ownership check and record injection
|
||||
@require_permission("threads", "delete", owner_check=True, inject_record=True)
|
||||
async def delete_thread(thread_id: str, request: Request, thread_record: dict = None):
|
||||
# thread_record is injected if found
|
||||
...
|
||||
|
||||
Raises:
|
||||
HTTPException 401: If authentication required but user is anonymous
|
||||
HTTPException 403: If user lacks permission
|
||||
HTTPException 404: If owner_check=True but user doesn't own the thread
|
||||
ValueError: If owner_check=True but 'thread_id' parameter is missing
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = kwargs.get("request")
|
||||
if request is None:
|
||||
raise ValueError("require_permission decorator requires 'request' parameter")
|
||||
|
||||
auth: AuthContext = getattr(request.state, "auth", None)
|
||||
if auth is None:
|
||||
auth = await _authenticate(request)
|
||||
request.state.auth = auth
|
||||
|
||||
if not auth.is_authenticated:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Check permission
|
||||
if not auth.has_permission(resource, action):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Permission denied: {resource}:{action}",
|
||||
)
|
||||
|
||||
# Owner check for thread-specific resources.
|
||||
#
|
||||
# 2.0-rc moved thread metadata into the SQL persistence layer
|
||||
# (``threads_meta`` table). We verify ownership via
|
||||
# ``ThreadMetaStore.check_access`` instead of the LangGraph
|
||||
# store path that the original PR #1728 used. ``check_access``
|
||||
# returns True for missing rows (untracked legacy thread) and
|
||||
# for rows whose ``owner_id`` is NULL (shared / pre-auth data),
|
||||
# so this is a strict-deny check rather than strict-allow:
|
||||
# only an *existing* row with a *different* owner_id triggers
|
||||
# 404.
|
||||
#
|
||||
# ``inject_record`` is no longer supported — it was a
|
||||
# convenience for handlers that wanted the LangGraph store
|
||||
# blob; the SQL repo would need a different shape and no
|
||||
# caller in 2.0 needs it.
|
||||
if owner_check:
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
allowed = await thread_meta_repo.check_access(thread_id, str(auth.user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Thread {thread_id} not found",
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,3 +0,0 @@
|
||||
from .lifespan import lifespan_manager
|
||||
|
||||
__all__ = ["lifespan_manager"]
|
||||
@@ -1,52 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]]
|
||||
|
||||
|
||||
class LifespanManager:
|
||||
"""FastAPI lifespan manager"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lifespans: list[LifespanFunc] = []
|
||||
|
||||
def register(self, func: LifespanFunc) -> LifespanFunc:
|
||||
"""
|
||||
Register a lifespan hook.
|
||||
|
||||
:param func: lifespan hook
|
||||
:return:
|
||||
"""
|
||||
if func not in self._lifespans:
|
||||
self._lifespans.append(func)
|
||||
return func
|
||||
|
||||
def build(self) -> LifespanFunc:
|
||||
"""
|
||||
Build the combined lifespan hook.
|
||||
|
||||
:return:
|
||||
"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def combined_lifespan(app: FastAPI): # noqa: ANN202
|
||||
state: dict[str, Any] = {}
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
for lifespan_fn in self._lifespans:
|
||||
result = await exit_stack.enter_async_context(lifespan_fn(app))
|
||||
if isinstance(result, dict):
|
||||
state.update(result)
|
||||
|
||||
for key, value in state.items():
|
||||
setattr(app.state, key, value)
|
||||
|
||||
yield state or None
|
||||
|
||||
return combined_lifespan
|
||||
|
||||
|
||||
# Singleton lifespan_manager instance
|
||||
lifespan_manager = LifespanManager()
|
||||
+26
-20
@@ -1,4 +1,8 @@
|
||||
"""CSRF protection middleware and helpers for cookie-based auth flows."""
|
||||
"""CSRF protection middleware for FastAPI.
|
||||
|
||||
Per RFC-001:
|
||||
State-changing operations require CSRF protection.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
@@ -24,11 +28,16 @@ def generate_csrf_token() -> str:
|
||||
|
||||
|
||||
def should_check_csrf(request: Request) -> bool:
|
||||
"""Determine if a request needs CSRF validation."""
|
||||
"""Determine if a request needs CSRF validation.
|
||||
|
||||
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
|
||||
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
|
||||
"""
|
||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||
return False
|
||||
|
||||
path = request.url.path.rstrip("/")
|
||||
# Exempt /api/v1/auth/me endpoint
|
||||
if path == "/api/v1/auth/me":
|
||||
return False
|
||||
return True
|
||||
@@ -39,18 +48,20 @@ _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/initialize",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_auth_endpoint(request: Request) -> bool:
|
||||
"""Check if the request is to an auth endpoint."""
|
||||
"""Check if the request is to an auth endpoint.
|
||||
|
||||
Auth endpoints don't need CSRF validation on first call (no token).
|
||||
"""
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Implement CSRF protection using the double-submit cookie pattern."""
|
||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
@@ -76,13 +87,16 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# For auth endpoints that set up session, also set CSRF cookie
|
||||
if _is_auth and request.method == "POST":
|
||||
# Generate a new CSRF token for the session
|
||||
csrf_token = generate_csrf_token()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key=CSRF_COOKIE_NAME,
|
||||
value=csrf_token,
|
||||
httponly=False,
|
||||
secure=is_secure_request(request),
|
||||
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
|
||||
secure=is_https,
|
||||
samesite="strict",
|
||||
)
|
||||
|
||||
@@ -90,17 +104,9 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
|
||||
def get_csrf_token(request: Request) -> str | None:
|
||||
"""Get the CSRF token from the current request's cookies."""
|
||||
"""Get the CSRF token from the current request's cookies.
|
||||
|
||||
This is useful for server-side rendering where you need to embed
|
||||
token in forms or headers.
|
||||
"""
|
||||
return request.cookies.get(CSRF_COOKIE_NAME)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CSRF_COOKIE_NAME",
|
||||
"CSRF_HEADER_NAME",
|
||||
"CSRFMiddleware",
|
||||
"generate_csrf_token",
|
||||
"get_csrf_token",
|
||||
"is_auth_endpoint",
|
||||
"is_secure_request",
|
||||
"should_check_csrf",
|
||||
]
|
||||
@@ -1,59 +0,0 @@
|
||||
from app.gateway.dependencies.checkpointer import (
|
||||
CurrentCheckpointer,
|
||||
get_checkpointer,
|
||||
)
|
||||
from app.plugins.auth.security.dependencies import (
|
||||
CurrentAuthService,
|
||||
CurrentUserRepository,
|
||||
get_auth_service,
|
||||
get_current_user_from_request,
|
||||
get_current_user_id,
|
||||
get_optional_user_from_request,
|
||||
get_user_repository,
|
||||
)
|
||||
from app.gateway.dependencies.db import (
|
||||
CurrentSession,
|
||||
CurrentSessionTransaction,
|
||||
get_db_session,
|
||||
get_db_session_transaction,
|
||||
)
|
||||
from app.gateway.dependencies.repositories import (
|
||||
CurrentFeedbackRepository,
|
||||
CurrentRunRepository,
|
||||
CurrentThreadMetaRepository,
|
||||
CurrentThreadMetaStorage,
|
||||
get_feedback_repository,
|
||||
get_run_repository,
|
||||
get_thread_meta_repository,
|
||||
get_thread_meta_storage,
|
||||
)
|
||||
from app.gateway.dependencies.stream_bridge import (
|
||||
CurrentStreamBridge,
|
||||
get_stream_bridge,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CurrentCheckpointer",
|
||||
"CurrentAuthService",
|
||||
"CurrentFeedbackRepository",
|
||||
"CurrentRunRepository",
|
||||
"CurrentSession",
|
||||
"CurrentSessionTransaction",
|
||||
"CurrentStreamBridge",
|
||||
"CurrentThreadMetaRepository",
|
||||
"CurrentThreadMetaStorage",
|
||||
"CurrentUserRepository",
|
||||
"get_auth_service",
|
||||
"get_checkpointer",
|
||||
"get_current_user_from_request",
|
||||
"get_current_user_id",
|
||||
"get_db_session",
|
||||
"get_db_session_transaction",
|
||||
"get_feedback_repository",
|
||||
"get_optional_user_from_request",
|
||||
"get_run_repository",
|
||||
"get_stream_bridge",
|
||||
"get_thread_meta_repository",
|
||||
"get_thread_meta_storage",
|
||||
"get_user_repository",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
|
||||
def get_checkpointer(request: Request) -> Checkpointer:
|
||||
"""Get checkpointer from app.state.persistence."""
|
||||
persistence = getattr(request.app.state, "persistence", None)
|
||||
if persistence is None:
|
||||
raise HTTPException(status_code=503, detail="Persistence not available")
|
||||
checkpointer = getattr(persistence, "checkpointer", None)
|
||||
if checkpointer is None:
|
||||
raise HTTPException(status_code=503, detail="Checkpointer not available")
|
||||
return checkpointer
|
||||
|
||||
|
||||
CurrentCheckpointer = Annotated[Checkpointer, Depends(get_checkpointer)]
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession]:
|
||||
factory = getattr(request.app.state.persistence, "session_factory", None)
|
||||
if factory is None:
|
||||
raise HTTPException(status_code=503, detail="Database session factory not available")
|
||||
return factory
|
||||
|
||||
|
||||
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
"""Open a session without auto-commit. Use for read-only endpoints."""
|
||||
session_factory = _get_session_factory(request)
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_db_session_transaction(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
"""Open a session and commit on success, rollback on error."""
|
||||
session_factory = _get_session_factory(request)
|
||||
async with session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
CurrentSession = Annotated[AsyncSession, Depends(get_db_session)]
|
||||
CurrentSessionTransaction = Annotated[AsyncSession, Depends(get_db_session_transaction)]
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
|
||||
from app.infra.storage import ThreadMetaStorage
|
||||
from store.repositories.contracts import (
|
||||
FeedbackRepositoryProtocol,
|
||||
RunRepositoryProtocol,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
|
||||
|
||||
def _require_state(request: Request, attr: str, label: str):
|
||||
value = getattr(request.app.state, attr, None)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
return value
|
||||
|
||||
|
||||
def get_run_repository(request: Request) -> RunRepositoryProtocol:
|
||||
return _require_state(request, "run_store", "Run store")
|
||||
|
||||
|
||||
def get_thread_meta_repository(request: Request) -> ThreadMetaRepositoryProtocol:
|
||||
return _require_state(request, "thread_meta_repo", "Thread metadata store")
|
||||
|
||||
|
||||
def get_thread_meta_storage(request: Request) -> ThreadMetaStorage:
|
||||
return _require_state(request, "thread_meta_storage", "Thread metadata storage")
|
||||
|
||||
|
||||
def get_feedback_repository(request: Request) -> FeedbackRepositoryProtocol:
|
||||
return _require_state(request, "feedback_repo", "Feedback")
|
||||
|
||||
|
||||
CurrentRunRepository = Annotated[RunRepositoryProtocol, Depends(get_run_repository)]
|
||||
CurrentThreadMetaRepository = Annotated[ThreadMetaRepositoryProtocol, Depends(get_thread_meta_repository)]
|
||||
CurrentThreadMetaStorage = Annotated[ThreadMetaStorage, Depends(get_thread_meta_storage)]
|
||||
CurrentFeedbackRepository = Annotated[FeedbackRepositoryProtocol, Depends(get_feedback_repository)]
|
||||
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import StreamBridge
|
||||
|
||||
|
||||
def get_stream_bridge(request: Request) -> StreamBridge:
|
||||
"""Get stream bridge from app.state."""
|
||||
bridge = getattr(request.app.state, "stream_bridge", None)
|
||||
if bridge is None:
|
||||
raise HTTPException(status_code=503, detail="Stream bridge not available")
|
||||
return bridge
|
||||
|
||||
|
||||
CurrentStreamBridge = Annotated[StreamBridge, Depends(get_stream_bridge)]
|
||||
@@ -0,0 +1,225 @@
|
||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
||||
``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import RunContext, RunManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
|
||||
Usage in ``app.py``::
|
||||
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
|
||||
# Initialize persistence engine BEFORE checkpointer so that
|
||||
# auto-create-database logic runs first (postgres backend).
|
||||
config = get_app_config()
|
||||
await init_engine_from_config(config.database)
|
||||
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
|
||||
# Initialize repositories — one get_session_factory() call for all.
|
||||
sf = get_session_factory()
|
||||
if sf is not None:
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
app.state.run_store = RunRepository(sf)
|
||||
app.state.feedback_repo = FeedbackRepository(sf)
|
||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
||||
else:
|
||||
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
app.state.run_store = MemoryRunStore()
|
||||
app.state.feedback_repo = None
|
||||
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
app.state.run_event_store = make_run_event_store(run_events_config)
|
||||
|
||||
# RunManager with store backing for persistence
|
||||
app.state.run_manager = RunManager(store=app.state.run_store)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters -- called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _require(attr: str, label: str):
|
||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
||||
|
||||
def dep(request: Request):
|
||||
val = getattr(request.app.state, attr, None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
return val
|
||||
|
||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
||||
return dep
|
||||
|
||||
|
||||
get_stream_bridge = _require("stream_bridge", "Stream bridge")
|
||||
get_run_manager = _require("run_manager", "Run manager")
|
||||
get_checkpointer = _require("checkpointer", "Checkpointer")
|
||||
get_run_event_store = _require("run_event_store", "Run event store")
|
||||
get_feedback_repo = _require("feedback_repo", "Feedback")
|
||||
get_run_store = _require("run_store", "Run store")
|
||||
|
||||
|
||||
def get_store(request: Request):
|
||||
"""Return the global store (may be ``None`` if not configured)."""
|
||||
return getattr(request.app.state, "store", None)
|
||||
|
||||
|
||||
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
||||
|
||||
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||
|
||||
Returns a *base* context with infrastructure dependencies. Callers that
|
||||
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
||||
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||
to :func:`run_agent`.
|
||||
"""
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
return RunContext(
|
||||
checkpointer=get_checkpointer(request),
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||
thread_meta_repo=get_thread_meta_repo(request),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py and auth middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Cached singletons to avoid repeated instantiation per request
|
||||
_cached_local_provider: LocalAuthProvider | None = None
|
||||
_cached_repo: SQLiteUserRepository | None = None
|
||||
|
||||
|
||||
def get_local_provider() -> LocalAuthProvider:
|
||||
"""Get or create the cached LocalAuthProvider singleton.
|
||||
|
||||
Must be called after ``init_engine_from_config()`` — the shared
|
||||
session factory is required to construct the user repository.
|
||||
"""
|
||||
global _cached_local_provider, _cached_repo
|
||||
if _cached_repo is None:
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
|
||||
_cached_repo = SQLiteUserRepository(sf)
|
||||
if _cached_local_provider is None:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
|
||||
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
|
||||
return _cached_local_provider
|
||||
|
||||
|
||||
async def get_current_user_from_request(request: Request):
|
||||
"""Get the current authenticated user from the request cookie.
|
||||
|
||||
Raises HTTPException 401 if not authenticated.
|
||||
"""
|
||||
from app.gateway.auth import decode_token
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||
|
||||
access_token = request.cookies.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
||||
)
|
||||
|
||||
payload = decode_token(access_token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
|
||||
)
|
||||
|
||||
provider = get_local_provider()
|
||||
user = await provider.get_user(payload.sub)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
||||
)
|
||||
|
||||
# Token version mismatch → password was changed, token is stale
|
||||
if user.token_version != payload.ver:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_optional_user_from_request(request: Request):
|
||||
"""Get optional authenticated user from request.
|
||||
|
||||
Returns None if not authenticated.
|
||||
"""
|
||||
try:
|
||||
return await get_current_user_from_request(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> str | None:
|
||||
"""Extract user_id from request cookie, or None if not authenticated.
|
||||
|
||||
Thin adapter that returns the string id for callers that only need
|
||||
identification (e.g., ``feedback.py``). Full-user callers should use
|
||||
``get_current_user_from_request`` or ``get_optional_user_from_request``.
|
||||
"""
|
||||
user = await get_optional_user_from_request(request)
|
||||
return str(user.id) if user else None
|
||||
@@ -0,0 +1,106 @@
|
||||
"""LangGraph Server auth handler — shares JWT logic with Gateway.
|
||||
|
||||
Loaded by LangGraph Server via langgraph.json ``auth.path``.
|
||||
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
|
||||
so both modes validate tokens with the same secret and rules.
|
||||
|
||||
Two layers:
|
||||
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
||||
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
|
||||
2. @auth.on — returns metadata filter so each user only sees own threads
|
||||
"""
|
||||
|
||||
import secrets
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
auth = Auth()
|
||||
|
||||
# Methods that require CSRF validation (state-changing per RFC 7231).
|
||||
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
|
||||
def _check_csrf(request) -> None:
|
||||
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
|
||||
|
||||
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
|
||||
proxied directly by nginx have the same CSRF protection.
|
||||
"""
|
||||
method = getattr(request, "method", "") or ""
|
||||
if method.upper() not in _CSRF_METHODS:
|
||||
return
|
||||
|
||||
cookie_token = request.cookies.get("csrf_token")
|
||||
header_token = request.headers.get("x-csrf-token")
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token missing. Include X-CSRF-Token header.",
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token mismatch.",
|
||||
)
|
||||
|
||||
|
||||
@auth.authenticate
|
||||
async def authenticate(request):
|
||||
"""Validate the session cookie, decode JWT, and check token_version.
|
||||
|
||||
Same validation chain as Gateway's get_current_user_from_request:
|
||||
cookie → decode JWT → DB lookup → token_version match
|
||||
Also enforces CSRF on state-changing methods.
|
||||
"""
|
||||
# CSRF check before authentication so forged cross-site requests
|
||||
# are rejected early, even if the cookie carries a valid JWT.
|
||||
_check_csrf(request)
|
||||
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="Not authenticated",
|
||||
)
|
||||
|
||||
payload = decode_token(token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Token error: {payload.value}",
|
||||
)
|
||||
|
||||
user = await get_local_provider().get_user(payload.sub)
|
||||
if user is None:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="User not found",
|
||||
)
|
||||
if user.token_version != payload.ver:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="Token revoked (password changed)",
|
||||
)
|
||||
|
||||
return payload.sub
|
||||
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||
|
||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||
This handler ensures LangGraph Server enforces the same isolation.
|
||||
"""
|
||||
# On create/update: stamp owner_id into metadata
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["owner_id"] = ctx.user.identity
|
||||
|
||||
# Return filter dict — LangGraph applies it to search/read/delete
|
||||
return {"owner_id": ctx.user.identity}
|
||||
@@ -5,17 +5,15 @@ from pathlib import Path
|
||||
from fastapi import HTTPException
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
virtual_path: The virtual path as seen inside the sandbox
|
||||
(e.g., /mnt/user-data/outputs/file.txt).
|
||||
user_id: Explicit user id override. Falls back to the current actor context.
|
||||
|
||||
Returns:
|
||||
The resolved filesystem path.
|
||||
@@ -24,8 +22,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: s
|
||||
HTTPException: If the path is invalid or outside allowed directories.
|
||||
"""
|
||||
try:
|
||||
resolved_user_id = get_effective_user_id() if user_id is None else user_id
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=resolved_user_id)
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
||||
except ValueError as e:
|
||||
status = 403 if "traversal" in str(e) else 400
|
||||
raise HTTPException(status_code=status, detail=str(e))
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from scalar_fastapi import AgentScalarConfig, get_scalar_api_reference
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from store.persistence import create_persistence
|
||||
|
||||
from app.gateway.common import lifespan_manager
|
||||
from app.gateway.router import router as gateway_router
|
||||
from app.infra.run_events import build_run_event_store
|
||||
from app.infra.storage import FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||
from app.plugins.auth.injection import install_route_guards, load_route_policy_registry, validate_route_policy_registry
|
||||
from app.plugins.auth.security import AuthMiddleware, CSRFMiddleware
|
||||
|
||||
STATIC_DIR = Path(__file__).resolve().parents[1] / "static"
|
||||
STATIC_MOUNT = "/api/static"
|
||||
SCALAR_JS_URL = f"{STATIC_MOUNT}/scalar.js"
|
||||
|
||||
|
||||
@lifespan_manager.register
|
||||
@asynccontextmanager
|
||||
async def init_persistence(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""Initialize persistence layer (DB, checkpointer, store)."""
|
||||
app_persistence = await create_persistence()
|
||||
|
||||
await app_persistence.setup()
|
||||
run_store = RunStoreAdapter(app_persistence.session_factory)
|
||||
thread_meta_store = ThreadMetaStoreAdapter(app_persistence.session_factory)
|
||||
feedback_store = FeedbackStoreAdapter(app_persistence.session_factory)
|
||||
|
||||
try:
|
||||
yield {
|
||||
"persistence": app_persistence,
|
||||
"checkpointer": app_persistence.checkpointer,
|
||||
"store": None,
|
||||
"session_factory": app_persistence.session_factory,
|
||||
"run_store": run_store,
|
||||
"run_read_repo": run_store,
|
||||
"run_write_repo": run_store,
|
||||
"run_delete_repo": run_store,
|
||||
"feedback_repo": feedback_store,
|
||||
"thread_meta_repo": thread_meta_store,
|
||||
"thread_meta_storage": ThreadMetaStorage(thread_meta_store),
|
||||
"run_event_store": build_run_event_store(app_persistence.session_factory),
|
||||
}
|
||||
finally:
|
||||
await app_persistence.aclose()
|
||||
|
||||
|
||||
@lifespan_manager.register
|
||||
@asynccontextmanager
|
||||
async def init_runtime(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""Initialize StreamBridge for LangGraph-compatible runtime endpoints."""
|
||||
from app.infra.stream_bridge import build_stream_bridge
|
||||
|
||||
async with build_stream_bridge() as stream_bridge:
|
||||
yield {
|
||||
"stream_bridge": stream_bridge,
|
||||
}
|
||||
|
||||
|
||||
def register_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="DeerFlow API Gateway",
|
||||
version="0.1.0",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan_manager.build(),
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "threads",
|
||||
"description": "Endpoints for managing threads, which are conversations between a human and an assistant. A thread can have multiple runs as the conversation progresses."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
app.state.authz_hooks = build_authz_hooks()
|
||||
|
||||
_register_static(app)
|
||||
_register_routes(app)
|
||||
_register_scalar(app)
|
||||
_register_auth_route_policies(app)
|
||||
_register_middlewares(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _register_static(app: FastAPI) -> None:
|
||||
app.mount(STATIC_MOUNT, StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
|
||||
def _register_routes(app: FastAPI) -> None:
|
||||
app.include_router(gateway_router)
|
||||
|
||||
|
||||
def _register_auth_route_policies(app: FastAPI) -> None:
|
||||
registry = load_route_policy_registry()
|
||||
validate_route_policy_registry(app, registry)
|
||||
app.state.auth_route_policy_registry = registry
|
||||
install_route_guards(app)
|
||||
|
||||
|
||||
def _register_middlewares(app: FastAPI) -> None:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
)
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
|
||||
def _register_scalar(app: FastAPI) -> None:
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def scalar_docs() -> HTMLResponse:
|
||||
return get_scalar_api_reference(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title,
|
||||
scalar_js_url=SCALAR_JS_URL,
|
||||
agent=AgentScalarConfig(disabled=True),
|
||||
hide_client_button=True,
|
||||
overrides={"mcp": {"disabled": True}},
|
||||
)
|
||||
@@ -1,22 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.plugins.auth.api.router import router as auth_router
|
||||
|
||||
from .routers import artifacts, channels, mcp, models, skills, uploads
|
||||
from .routers.agents import router as agents_router
|
||||
from .routers.langgraph import feedback_router, runs_router, suggestion_router, threads_router
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(auth_router)
|
||||
router.include_router(threads_router, prefix="/api/threads")
|
||||
router.include_router(runs_router, prefix="/api/threads")
|
||||
router.include_router(feedback_router, prefix="/api/threads")
|
||||
router.include_router(suggestion_router)
|
||||
router.include_router(agents_router)
|
||||
router.include_router(channels.router)
|
||||
router.include_router(artifacts.router)
|
||||
router.include_router(mcp.router)
|
||||
router.include_router(models.router)
|
||||
router.include_router(skills.router)
|
||||
router.include_router(uploads.router)
|
||||
@@ -1,3 +1,3 @@
|
||||
from . import artifacts, mcp, models, skills, suggestions, uploads
|
||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
||||
|
||||
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]
|
||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
||||
|
||||
@@ -7,6 +7,7 @@ from urllib.parse import quote
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import FileResponse, PlainTextResponse, Response
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
||||
summary="Get Artifact File",
|
||||
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
|
||||
)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
|
||||
"""Get an artifact file by its path.
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Assistants compatibility endpoints.
|
||||
|
||||
Provides LangGraph Platform-compatible assistants API backed by the
|
||||
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
|
||||
|
||||
This is a minimal stub that satisfies the ``useStream`` React hook's
|
||||
initialization requirements (``assistants.search()`` and ``assistants.get()``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
|
||||
|
||||
|
||||
class AssistantResponse(BaseModel):
|
||||
assistant_id: str
|
||||
graph_id: str
|
||||
name: str
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
description: str | None = None
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
version: int = 1
|
||||
|
||||
|
||||
class AssistantSearchRequest(BaseModel):
|
||||
graph_id: str | None = None
|
||||
name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
limit: int = 10
|
||||
offset: int = 0
|
||||
|
||||
|
||||
def _get_default_assistant() -> AssistantResponse:
|
||||
"""Return the default lead_agent assistant."""
|
||||
now = datetime.now(UTC).isoformat()
|
||||
return AssistantResponse(
|
||||
assistant_id="lead_agent",
|
||||
graph_id="lead_agent",
|
||||
name="lead_agent",
|
||||
config={},
|
||||
metadata={"created_by": "system"},
|
||||
description="DeerFlow lead agent",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
version=1,
|
||||
)
|
||||
|
||||
|
||||
def _list_assistants() -> list[AssistantResponse]:
|
||||
"""List all available assistants from config."""
|
||||
assistants = [_get_default_assistant()]
|
||||
|
||||
# Also include custom agents from config.yaml agents directory
|
||||
try:
|
||||
from deerflow.config.agents_config import list_custom_agents
|
||||
|
||||
for agent_cfg in list_custom_agents():
|
||||
now = datetime.now(UTC).isoformat()
|
||||
assistants.append(
|
||||
AssistantResponse(
|
||||
assistant_id=agent_cfg.name,
|
||||
graph_id="lead_agent", # All agents use the same graph
|
||||
name=agent_cfg.name,
|
||||
config={},
|
||||
metadata={"created_by": "user"},
|
||||
description=agent_cfg.description or "",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
version=1,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not load custom agents for assistants list")
|
||||
|
||||
return assistants
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[AssistantResponse])
|
||||
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
|
||||
"""Search assistants.
|
||||
|
||||
Returns all registered assistants (lead_agent + custom agents from config).
|
||||
"""
|
||||
assistants = _list_assistants()
|
||||
|
||||
if body and body.graph_id:
|
||||
assistants = [a for a in assistants if a.graph_id == body.graph_id]
|
||||
if body and body.name:
|
||||
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
|
||||
|
||||
offset = body.offset if body else 0
|
||||
limit = body.limit if body else 10
|
||||
return assistants[offset : offset + limit]
|
||||
|
||||
|
||||
@router.get("/{assistant_id}", response_model=AssistantResponse)
|
||||
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
|
||||
"""Get an assistant by ID."""
|
||||
for a in _list_assistants():
|
||||
if a.assistant_id == assistant_id:
|
||||
return a
|
||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
||||
|
||||
|
||||
@router.get("/{assistant_id}/graph")
|
||||
async def get_assistant_graph(assistant_id: str) -> dict:
|
||||
"""Get the graph structure for an assistant.
|
||||
|
||||
Returns a minimal graph description. Full graph introspection is
|
||||
not supported in the Gateway — this stub satisfies SDK validation.
|
||||
"""
|
||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
||||
|
||||
return {
|
||||
"graph_id": "lead_agent",
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{assistant_id}/schemas")
|
||||
async def get_assistant_schemas(assistant_id: str) -> dict:
|
||||
"""Get JSON schemas for an assistant's input/output/state.
|
||||
|
||||
Returns empty schemas — full introspection not supported in Gateway.
|
||||
"""
|
||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
||||
|
||||
return {
|
||||
"graph_id": "lead_agent",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"state_schema": {},
|
||||
"config_schema": {},
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
from app.gateway.auth import (
|
||||
UserResponse,
|
||||
create_access_token,
|
||||
)
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||
from app.gateway.csrf_middleware import is_secure_request
|
||||
from app.gateway.deps import get_current_user_from_request, get_local_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request/Response Models ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""Response model for login — token only lives in HttpOnly cookie."""
|
||||
|
||||
expires_in: int # seconds
|
||||
needs_setup: bool = False
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""Request model for user registration."""
|
||||
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""Request model for password change (also handles setup flow)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str = Field(..., min_length=8)
|
||||
new_email: EmailStr | None = None
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Generic message response."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
||||
"""Set the access_token HttpOnly cookie on the response."""
|
||||
config = get_auth_config()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=is_https,
|
||||
samesite="lax",
|
||||
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
||||
)
|
||||
|
||||
|
||||
# ── Rate Limiting ────────────────────────────────────────────────────────
|
||||
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
|
||||
|
||||
_MAX_LOGIN_ATTEMPTS = 5
|
||||
_LOCKOUT_SECONDS = 300 # 5 minutes
|
||||
|
||||
# ip → (fail_count, lock_until_timestamp)
|
||||
_login_attempts: dict[str, tuple[int, float]] = {}
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract the real client IP for rate limiting.
|
||||
|
||||
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
|
||||
$remote_addr``). Nginx unconditionally overwrites any client-supplied
|
||||
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
|
||||
that nginx observed — it cannot be spoofed by the client.
|
||||
|
||||
``request.client.host`` is NOT reliable because uvicorn's default
|
||||
``proxy_headers=True`` replaces it with the *first* entry from
|
||||
``X-Forwarded-For``, which IS client-spoofable.
|
||||
|
||||
``X-Forwarded-For`` is intentionally NOT used for the same reason.
|
||||
"""
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fallback: direct connection without nginx (e.g. unit tests, dev).
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _check_rate_limit(ip: str) -> None:
|
||||
"""Raise 429 if the IP is currently locked out."""
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
return
|
||||
fail_count, lock_until = record
|
||||
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
||||
if time.time() < lock_until:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many login attempts. Try again later.",
|
||||
)
|
||||
del _login_attempts[ip]
|
||||
|
||||
|
||||
_MAX_TRACKED_IPS = 10000
|
||||
|
||||
|
||||
def _record_login_failure(ip: str) -> None:
|
||||
"""Record a failed login attempt for the given IP."""
|
||||
# Evict expired lockouts when dict grows too large
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
now = time.time()
|
||||
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
||||
for k in expired:
|
||||
del _login_attempts[k]
|
||||
# If still too large, evict cheapest-to-lose half: below-threshold
|
||||
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
||||
for k, _ in by_time[: len(by_time) // 2]:
|
||||
del _login_attempts[k]
|
||||
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
_login_attempts[ip] = (1, 0.0)
|
||||
else:
|
||||
new_count = record[0] + 1
|
||||
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
||||
_login_attempts[ip] = (new_count, lock_until)
|
||||
|
||||
|
||||
def _record_login_success(ip: str) -> None:
|
||||
"""Clear failure counter for the given IP on successful login."""
|
||||
_login_attempts.pop(ip, None)
|
||||
|
||||
|
||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/login/local", response_model=LoginResponse)
|
||||
async def login_local(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
"""Local email/password login."""
|
||||
client_ip = _get_client_ip(request)
|
||||
_check_rate_limit(client_ip)
|
||||
|
||||
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
||||
|
||||
if user is None:
|
||||
_record_login_failure(client_ip)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
||||
)
|
||||
|
||||
_record_login_success(client_ip)
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return LoginResponse(
|
||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||
needs_setup=user.needs_setup,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
||||
"""Register a new user account (always 'user' role).
|
||||
|
||||
Admin is auto-created on first boot. This endpoint creates regular users.
|
||||
Auto-login by setting the session cookie.
|
||||
"""
|
||||
try:
|
||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
||||
)
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: Request, response: Response):
|
||||
"""Logout current user by clearing the cookie."""
|
||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/change-password", response_model=MessageResponse)
|
||||
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
||||
"""Change password for the currently authenticated user.
|
||||
|
||||
Also handles the first-boot setup flow:
|
||||
- If new_email is provided, updates email (checks uniqueness)
|
||||
- If user.needs_setup is True and new_email is given, clears needs_setup
|
||||
- Always increments token_version to invalidate old sessions
|
||||
- Re-issues session cookie with new token_version
|
||||
"""
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||
|
||||
if not await verify_password_async(body.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
||||
|
||||
provider = get_local_provider()
|
||||
|
||||
# Update email if provided
|
||||
if body.new_email is not None:
|
||||
existing = await provider.get_user_by_email(body.new_email)
|
||||
if existing and str(existing.id) != str(user.id):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
||||
user.email = body.new_email
|
||||
|
||||
# Update password + bump version
|
||||
user.password_hash = await hash_password_async(body.new_password)
|
||||
user.token_version += 1
|
||||
|
||||
# Clear setup flag if this is the setup flow
|
||||
if user.needs_setup and body.new_email is not None:
|
||||
user.needs_setup = False
|
||||
|
||||
await provider.update_user(user)
|
||||
|
||||
# Re-issue cookie with new token_version
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return MessageResponse(message="Password changed successfully")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(request: Request):
|
||||
"""Get current authenticated user info."""
|
||||
user = await get_current_user_from_request(request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||
|
||||
|
||||
@router.get("/setup-status")
|
||||
async def setup_status():
|
||||
"""Check if admin account exists. Always False after first boot."""
|
||||
user_count = await get_local_provider().count_users()
|
||||
return {"needs_setup": user_count == 0}
|
||||
|
||||
|
||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}")
|
||||
async def oauth_login(provider: str):
|
||||
"""Initiate OAuth login flow.
|
||||
|
||||
Redirects to the OAuth provider's authorization URL.
|
||||
Currently a placeholder - requires OAuth provider implementation.
|
||||
"""
|
||||
if provider not in ["github", "google"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported OAuth provider: {provider}",
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth login not yet implemented",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/callback/{provider}")
|
||||
async def oauth_callback(provider: str, code: str, state: str):
|
||||
"""OAuth callback endpoint.
|
||||
|
||||
Handles the OAuth provider's callback after user authorization.
|
||||
Currently a placeholder.
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth callback not yet implemented",
|
||||
)
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Feedback endpoints — create, list, stats, delete.
|
||||
|
||||
Allows users to submit thumbs-up/down feedback on runs,
|
||||
optionally scoped to a specific message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["feedback"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FeedbackCreateRequest(BaseModel):
|
||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||
|
||||
|
||||
class FeedbackResponse(BaseModel):
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
owner_id: str | None = None
|
||||
message_id: str | None = None
|
||||
rating: int
|
||||
comment: str | None = None
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
class FeedbackStatsResponse(BaseModel):
|
||||
run_id: str
|
||||
total: int = 0
|
||||
positive: int = 0
|
||||
negative: int = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
@require_permission("threads", "write", owner_check=True)
|
||||
async def create_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Submit feedback (thumbs-up/down) for a run."""
|
||||
if body.rating not in (1, -1):
|
||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||
|
||||
user_id = await get_current_user(request)
|
||||
|
||||
# Validate run exists and belongs to thread
|
||||
run_store = get_run_store(request)
|
||||
run = await run_store.get(run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
if run.get("thread_id") != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.create(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
owner_id=user_id,
|
||||
message_id=body.message_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def list_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List all feedback for a run."""
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.list_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def feedback_stats(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Get aggregated feedback stats (positive/negative counts) for a run."""
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
feedback_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a feedback record."""
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
# Verify feedback belongs to the specified thread/run before deleting
|
||||
existing = await feedback_repo.get(feedback_id)
|
||||
if existing is None:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
|
||||
deleted = await feedback_repo.delete(feedback_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||
return {"success": True}
|
||||
@@ -1,6 +0,0 @@
|
||||
from .feedback import router as feedback_router
|
||||
from .runs import router as runs_router
|
||||
from .suggestions import router as suggestion_router
|
||||
from .threads import router as threads_router
|
||||
|
||||
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]
|
||||
@@ -1,179 +0,0 @@
|
||||
"""LangGraph-compatible run feedback endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.dependencies import get_feedback_repository, get_run_repository
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||
from app.plugins.auth.security.dependencies import get_current_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["feedback"])
|
||||
|
||||
|
||||
class FeedbackCreateRequest(BaseModel):
|
||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||
|
||||
|
||||
class FeedbackResponse(BaseModel):
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
owner_id: str | None = None
|
||||
message_id: str | None = None
|
||||
rating: int
|
||||
comment: str | None = None
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
class FeedbackStatsResponse(BaseModel):
|
||||
run_id: str
|
||||
total: int = 0
|
||||
positive: int = 0
|
||||
negative: int = 0
|
||||
|
||||
|
||||
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
|
||||
run_store = get_run_repository(request)
|
||||
if resolve_request_user_id(request) is None:
|
||||
run = await run_store.get(run_id, user_id=None)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
run = await run_store.get(run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
if run.get("thread_id") != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||
|
||||
|
||||
async def _get_current_user(request: Request) -> str | None:
|
||||
"""Extract current user id from auth dependencies when available."""
|
||||
return await get_current_user_id(request)
|
||||
|
||||
|
||||
async def _create_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
if body.rating not in (1, -1):
|
||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||
|
||||
await _validate_run_scope(thread_id, run_id, request)
|
||||
user_id = await _get_current_user(request)
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
return await feedback_repo.create(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
message_id=body.message_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
async def upsert_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Create or replace the run-level feedback record."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
user_id = await _get_current_user(request)
|
||||
if user_id is not None:
|
||||
return await feedback_repo.upsert(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||
for item in existing:
|
||||
feedback_id = item.get("feedback_id")
|
||||
if isinstance(feedback_id, str):
|
||||
await feedback_repo.delete(feedback_id)
|
||||
return await _create_feedback(thread_id, run_id, body, request)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
async def create_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Submit feedback for a run."""
|
||||
return await _create_feedback(thread_id, run_id, body, request)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
||||
async def list_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List all feedback for a run."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
user_id = await _get_current_user(request)
|
||||
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
||||
async def feedback_stats(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Get aggregated feedback stats for a run."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||
async def delete_run_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete all feedback records for a run."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
user_id = await _get_current_user(request)
|
||||
if user_id is not None:
|
||||
return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)}
|
||||
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||
for item in existing:
|
||||
feedback_id = item.get("feedback_id")
|
||||
if isinstance(feedback_id, str):
|
||||
await feedback_repo.delete(feedback_id)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
||||
async def delete_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
feedback_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a single feedback record."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
existing = await feedback_repo.get(feedback_id)
|
||||
if existing is None:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
|
||||
deleted = await feedback_repo.delete(feedback_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||
return {"success": True}
|
||||
@@ -1,501 +0,0 @@
|
||||
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
|
||||
from app.gateway.services.runs.input import (
|
||||
AdaptedRunRequest,
|
||||
RunSpecBuilder,
|
||||
UnsupportedRunFeatureError,
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from deerflow.runtime.runs.types import RunRecord, RunSpec
|
||||
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
|
||||
|
||||
router = APIRouter(tags=["runs"])
|
||||
|
||||
|
||||
class RunCreateRequest(BaseModel):
|
||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
||||
follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
|
||||
input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
||||
command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
|
||||
metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
|
||||
config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
|
||||
context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||
checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
|
||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None = None
|
||||
status: str
|
||||
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||
multitask_strategy: str = "reject"
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class RunDeleteResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class RunMessageResponse(BaseModel):
|
||||
run_id: str
|
||||
content: JSONValue
|
||||
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
seq: int
|
||||
|
||||
|
||||
class RunMessagesResponse(BaseModel):
|
||||
data: list[RunMessageResponse]
|
||||
hasMore: bool = False
|
||||
|
||||
|
||||
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
|
||||
"""Format a single SSE frame."""
|
||||
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||||
parts = [f"event: {event}", f"data: {payload}"]
|
||||
if event_id:
|
||||
parts.append(f"id: {event_id}")
|
||||
parts.append("")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
return RunResponse(
|
||||
run_id=record.run_id,
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
status=record.status,
|
||||
metadata=record.metadata,
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _trim_paginated_rows(
|
||||
rows: list[dict],
|
||||
*,
|
||||
limit: int,
|
||||
after_seq: int | None,
|
||||
) -> tuple[list[dict], bool]:
|
||||
has_more = len(rows) > limit
|
||||
if not has_more:
|
||||
return rows, False
|
||||
if after_seq is not None:
|
||||
return rows[:limit], True
|
||||
return rows[-limit:], True
|
||||
|
||||
|
||||
def _event_to_run_message(event: dict) -> RunMessageResponse:
|
||||
return RunMessageResponse(
|
||||
run_id=str(event["run_id"]),
|
||||
content=event.get("content"),
|
||||
metadata=dict(event.get("metadata") or {}),
|
||||
created_at=str(event.get("created_at") or ""),
|
||||
seq=int(event["seq"]),
|
||||
)
|
||||
|
||||
|
||||
async def _sse_consumer(
|
||||
stream: AsyncIterator[StreamEvent],
|
||||
request: Request,
|
||||
*,
|
||||
cancel_on_disconnect: bool,
|
||||
cancel_run,
|
||||
run_id: str,
|
||||
) -> AsyncIterator[str]:
|
||||
try:
|
||||
async for event in stream:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
if event.event == "__heartbeat__":
|
||||
yield ": heartbeat\n\n"
|
||||
continue
|
||||
|
||||
if event.event == "__end__":
|
||||
yield format_sse("end", None, event_id=event.id or None)
|
||||
return
|
||||
|
||||
if event.event == "__cancelled__":
|
||||
yield format_sse("cancel", None, event_id=event.id or None)
|
||||
return
|
||||
|
||||
yield format_sse(event.event, event.data, event_id=event.id or None)
|
||||
finally:
|
||||
if cancel_on_disconnect:
|
||||
await cancel_run(run_id)
|
||||
|
||||
|
||||
def _get_run_event_store(request: Request):
|
||||
event_store = getattr(request.app.state, "run_event_store", None)
|
||||
if event_store is None:
|
||||
raise HTTPException(status_code=503, detail="Run event store not available")
|
||||
return event_store
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||
async def list_runs(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
status: str | None = None,
|
||||
) -> list[RunResponse]:
|
||||
# Accepted for API compatibility; field projection is not implemented yet.
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
records = await facade.list_runs(thread_id)
|
||||
if status is not None:
|
||||
records = [record for record in records if record.status == status]
|
||||
records = records[offset : offset + limit]
|
||||
return [_record_to_response(record) for record in records]
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
|
||||
async def run_messages(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> RunMessagesResponse:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
event_store = _get_run_event_store(request)
|
||||
with bind_request_actor_context(request):
|
||||
rows = await event_store.list_messages_by_run(
|
||||
thread_id,
|
||||
run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq)
|
||||
return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more)
|
||||
|
||||
|
||||
def _build_spec(
|
||||
*,
|
||||
adapted: AdaptedRunRequest,
|
||||
) -> RunSpec:
|
||||
try:
|
||||
return RunSpecBuilder().build(adapted)
|
||||
except UnsupportedRunFeatureError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||
async def create_run(
|
||||
thread_id: str,
|
||||
body: RunCreateRequest,
|
||||
request: Request,
|
||||
) -> Response:
|
||||
adapted = adapt_create_run_request(
|
||||
thread_id=thread_id,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.create_background(spec)
|
||||
return Response(
|
||||
content=_record_to_response(record).model_dump_json(),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/stream")
|
||||
async def stream_run(
|
||||
thread_id: str,
|
||||
body: RunCreateRequest,
|
||||
request: Request,
|
||||
) -> StreamingResponse:
|
||||
adapted = adapt_create_stream_request(
|
||||
thread_id=thread_id,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
|
||||
spec = _build_spec(adapted=adapted)
|
||||
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, stream = await facade.create_and_stream(spec)
|
||||
|
||||
return StreamingResponse(
|
||||
_sse_consumer(
|
||||
stream,
|
||||
request,
|
||||
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||
cancel_run=facade.cancel,
|
||||
run_id=record.run_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/wait")
|
||||
async def wait_run(
|
||||
thread_id: str,
|
||||
body: RunCreateRequest,
|
||||
request: Request,
|
||||
) -> Response:
|
||||
adapted = adapt_create_wait_request(
|
||||
thread_id=thread_id,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, result = await facade.create_and_wait(spec)
|
||||
return Response(
|
||||
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs", response_model=RunResponse)
|
||||
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||
adapted = adapt_create_run_request(
|
||||
thread_id=None,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.create_background(spec)
|
||||
return Response(
|
||||
content=_record_to_response(record).model_dump_json(),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs/stream")
|
||||
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
adapted = adapt_create_stream_request(
|
||||
thread_id=None,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, stream = await facade.create_and_stream(spec)
|
||||
|
||||
return StreamingResponse(
|
||||
_sse_consumer(
|
||||
stream,
|
||||
request,
|
||||
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||
cancel_run=facade.cancel,
|
||||
run_id=record.run_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs/wait")
|
||||
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||
adapted = adapt_create_wait_request(
|
||||
thread_id=None,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, result = await facade.create_and_wait(spec)
|
||||
return Response(
|
||||
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||
async def stream_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
action: Literal["interrupt", "rollback"] | None = None,
|
||||
wait: bool = False,
|
||||
cancel_on_disconnect: bool = False,
|
||||
stream_mode: str | None = None,
|
||||
) -> StreamingResponse | Response:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
if action is not None:
|
||||
with bind_request_actor_context(request):
|
||||
cancelled = await facade.cancel(run_id, action=action)
|
||||
if not cancelled:
|
||||
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||
if wait:
|
||||
with bind_request_actor_context(request):
|
||||
await facade.join_wait(run_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
adapted = adapt_join_stream_request(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
with bind_request_actor_context(request):
|
||||
stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)
|
||||
|
||||
return StreamingResponse(
|
||||
_sse_consumer(
|
||||
stream,
|
||||
request,
|
||||
cancel_on_disconnect=cancel_on_disconnect,
|
||||
cancel_run=facade.cancel,
|
||||
run_id=run_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||
async def join_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
cancel_on_disconnect: bool = False,
|
||||
) -> JSONValue:
|
||||
# Accepted for API compatibility; current join_wait path does not change
|
||||
# behavior based on client disconnect.
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
adapted = adapt_join_wait_request(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
with bind_request_actor_context(request):
|
||||
return await facade.join_wait(run_id, last_event_id=adapted.last_event_id)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||
async def cancel_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
wait: bool = False,
|
||||
action: Literal["interrupt", "rollback"] = "interrupt",
|
||||
) -> JSONValue:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
with bind_request_actor_context(request):
|
||||
cancelled = await facade.cancel(run_id, action=action)
|
||||
if not cancelled:
|
||||
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||
if wait:
|
||||
with bind_request_actor_context(request):
|
||||
return await facade.join_wait(run_id)
|
||||
return {}
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
|
||||
async def delete_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> RunDeleteResponse:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
with bind_request_actor_context(request):
|
||||
deleted = await facade.delete_run(run_id)
|
||||
return RunDeleteResponse(deleted=deleted)
|
||||
@@ -1,132 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["suggestions"])
|
||||
|
||||
|
||||
class SuggestionMessage(BaseModel):
|
||||
role: str = Field(..., description="Message role: user|assistant")
|
||||
content: str = Field(..., description="Message content as plain text")
|
||||
|
||||
|
||||
class SuggestionsRequest(BaseModel):
|
||||
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
|
||||
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
|
||||
model_name: str | None = Field(default=None, description="Optional model override")
|
||||
|
||||
|
||||
class SuggestionsResponse(BaseModel):
|
||||
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||
|
||||
|
||||
def _strip_markdown_code_fence(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
if not stripped.startswith("```"):
|
||||
return stripped
|
||||
lines = stripped.splitlines()
|
||||
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||
return "\n".join(lines[1:-1]).strip()
|
||||
return stripped
|
||||
|
||||
|
||||
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||
candidate = _strip_markdown_code_fence(text)
|
||||
start = candidate.find("[")
|
||||
end = candidate.rfind("]")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
return None
|
||||
candidate = candidate[start : end + 1]
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(data, list):
|
||||
return None
|
||||
out: list[str] = []
|
||||
for item in data:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
s = item.strip()
|
||||
if not s:
|
||||
continue
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
def _extract_response_text(content: object) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else ""
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
||||
parts: list[str] = []
|
||||
for m in messages:
|
||||
role = m.role.strip().lower()
|
||||
if role in ("user", "human"):
|
||||
parts.append(f"User: {m.content.strip()}")
|
||||
elif role in ("assistant", "ai"):
|
||||
parts.append(f"Assistant: {m.content.strip()}")
|
||||
else:
|
||||
parts.append(f"{m.role}: {m.content.strip()}")
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/threads/{thread_id}/suggestions",
|
||||
response_model=SuggestionsResponse,
|
||||
summary="Generate Follow-up Questions",
|
||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||
)
|
||||
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
|
||||
if not request.messages:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
n = request.n
|
||||
conversation = _format_conversation(request.messages)
|
||||
if not conversation:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
system_instruction = (
|
||||
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||
"Requirements:\n"
|
||||
"- Questions must be relevant to the preceding conversation.\n"
|
||||
"- Questions must be written in the same language as the user.\n"
|
||||
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||
"- Output MUST be a JSON array of strings only.\n"
|
||||
)
|
||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||
|
||||
try:
|
||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||
raw = _extract_response_text(response.content)
|
||||
suggestions = _parse_json_string_list(raw) or []
|
||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||
cleaned = cleaned[:n]
|
||||
return SuggestionsResponse(suggestions=cleaned)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
@@ -1,455 +0,0 @@
|
||||
"""Thread management endpoints.
|
||||
|
||||
Provides CRUD operations for threads and checkpoint state management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
|
||||
from app.infra.storage import ThreadMetaStorage
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["threads"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ThreadCreateRequest(BaseModel):
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
|
||||
class ThreadSearchRequest(BaseModel):
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
||||
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||
status: str | None = Field(default=None, description="Filter by thread status")
|
||||
user_id: str | None = Field(default=None, description="Filter by user ID")
|
||||
assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
|
||||
|
||||
|
||||
class ThreadResponse(BaseModel):
|
||||
thread_id: str = Field(description="Unique thread identifier")
|
||||
status: str = Field(default="idle", description="Thread status")
|
||||
created_at: str = Field(default="", description="ISO timestamp")
|
||||
updated_at: str = Field(default="", description="ISO timestamp")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
|
||||
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||
|
||||
|
||||
class ThreadDeleteResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ThreadStateUpdateRequest(BaseModel):
|
||||
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
||||
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||
|
||||
|
||||
class ThreadStateResponse(BaseModel):
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
||||
next: list[str] = Field(default_factory=list, description="Next nodes to execute")
|
||||
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
||||
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
||||
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
||||
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
||||
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||
|
||||
|
||||
class ThreadHistoryRequest(BaseModel):
|
||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
||||
before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
|
||||
|
||||
|
||||
class HistoryEntry(BaseModel):
|
||||
checkpoint_id: str
|
||||
parent_checkpoint_id: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
values: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
next: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def sanitize_log_param(value: str) -> str:
|
||||
"""Strip control characters to prevent log injection."""
|
||||
|
||||
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
|
||||
|
||||
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||
"""Delete local filesystem data for a thread."""
|
||||
path_manager = paths or get_paths()
|
||||
try:
|
||||
path_manager.delete_thread_dir(thread_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except FileNotFoundError:
|
||||
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
|
||||
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
||||
|
||||
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
|
||||
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||
|
||||
|
||||
async def _thread_or_run_exists(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
thread_meta_storage: ThreadMetaStorage,
|
||||
run_repo,
|
||||
) -> bool:
|
||||
request_user_id = resolve_request_user_id(request)
|
||||
|
||||
if request_user_id is None:
|
||||
thread = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||
if thread is not None:
|
||||
return True
|
||||
runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
|
||||
return bool(runs)
|
||||
|
||||
with bind_request_actor_context(request):
|
||||
thread = await thread_meta_storage.get_thread(thread_id)
|
||||
if thread is not None:
|
||||
return True
|
||||
runs = await run_repo.list_by_thread(thread_id, limit=1)
|
||||
return bool(runs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("", response_model=ThreadResponse)
|
||||
async def create_thread(
|
||||
body: ThreadCreateRequest,
|
||||
request: Request,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> ThreadResponse:
|
||||
"""Create a new thread."""
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
|
||||
request_user_id = resolve_request_user_id(request)
|
||||
if request_user_id is None:
|
||||
existing = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
existing = await thread_meta_storage.get_thread(thread_id)
|
||||
if existing is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=existing.status,
|
||||
created_at=existing.created_time.isoformat() if existing.created_time else "",
|
||||
updated_at=existing.updated_time.isoformat() if existing.updated_time else "",
|
||||
metadata=existing.metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
if request_user_id is None:
|
||||
created = await thread_meta_storage.ensure_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
user_id=None,
|
||||
)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
created = await thread_meta_storage.ensure_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
logger.info("Thread created: %s", sanitize_log_param(thread_id))
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=created.status,
|
||||
created_at=created.created_time.isoformat() if created.created_time else "",
|
||||
updated_at=created.updated_time.isoformat() if created.updated_time else "",
|
||||
metadata=created.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[ThreadResponse])
|
||||
async def search_threads(
|
||||
body: ThreadSearchRequest,
|
||||
request: Request,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> list[ThreadResponse]:
|
||||
"""Search threads with filters."""
|
||||
try:
|
||||
request_user_id = resolve_request_user_id(request)
|
||||
if request_user_id is None:
|
||||
threads = await thread_meta_storage.search_threads(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
user_id=body.user_id,
|
||||
assistant_id=body.assistant_id,
|
||||
limit=body.limit,
|
||||
offset=body.offset,
|
||||
)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
threads = await thread_meta_storage.search_threads(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
assistant_id=body.assistant_id,
|
||||
limit=body.limit,
|
||||
offset=body.offset,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to search threads")
|
||||
raise HTTPException(status_code=500, detail="Failed to search threads")
|
||||
|
||||
return [
|
||||
ThreadResponse(
|
||||
thread_id=t.thread_id,
|
||||
status=t.status,
|
||||
created_at=t.created_time.isoformat() if t.created_time else "",
|
||||
updated_at=t.updated_time.isoformat() if t.updated_time else "",
|
||||
metadata=t.metadata,
|
||||
values={"title": t.display_name} if t.display_name else {},
|
||||
interrupts={},
|
||||
)
|
||||
for t in threads
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||
async def delete_thread(
|
||||
thread_id: str,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> ThreadDeleteResponse:
|
||||
"""Delete a thread and all associated data."""
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
try:
|
||||
if hasattr(checkpointer, "adelete_thread"):
|
||||
await checkpointer.adelete_thread(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id))
|
||||
|
||||
# Remove thread_meta (best-effort)
|
||||
try:
|
||||
await thread_meta_storage.delete_thread(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete thread_meta for %s", sanitize_log_param(thread_id))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def get_thread_state(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
run_repo: CurrentRunRepository,
|
||||
) -> ThreadStateResponse:
|
||||
"""Get the latest state snapshot for a thread."""
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
if await _thread_or_run_exists(
|
||||
request=request,
|
||||
thread_id=thread_id,
|
||||
thread_meta_storage=thread_meta_storage,
|
||||
run_repo=run_repo,
|
||||
):
|
||||
return ThreadStateResponse()
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=next_nodes,
|
||||
tasks=tasks,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
metadata=metadata,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def update_thread_state(
|
||||
thread_id: str,
|
||||
body: ThreadStateUpdateRequest,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> ThreadStateResponse:
|
||||
"""Update thread state (human-in-the-loop or title rename)."""
|
||||
read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
if body.checkpoint_id:
|
||||
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
||||
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
||||
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
||||
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
||||
|
||||
if body.values:
|
||||
channel_values.update(body.values)
|
||||
|
||||
checkpoint["channel_values"] = channel_values
|
||||
metadata["updated_at"] = time.time()
|
||||
|
||||
if body.as_node:
|
||||
metadata["source"] = "update"
|
||||
metadata["step"] = metadata.get("step", 0) + 1
|
||||
metadata["writes"] = {body.as_node: body.values}
|
||||
|
||||
write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
||||
except Exception:
|
||||
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
||||
|
||||
new_checkpoint_id: str | None = None
|
||||
if isinstance(new_config, dict):
|
||||
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
# Sync title to thread_meta
|
||||
if body.values and "title" in body.values:
|
||||
new_title = body.values["title"]
|
||||
if new_title:
|
||||
try:
|
||||
await thread_meta_storage.sync_thread_title(
|
||||
thread_id=thread_id,
|
||||
title=new_title,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||
async def get_thread_history(
|
||||
thread_id: str,
|
||||
body: ThreadHistoryRequest,
|
||||
request: Request,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
run_repo: CurrentRunRepository,
|
||||
) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread."""
|
||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
if body.before:
|
||||
config["configurable"]["checkpoint_id"] = body.before
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
is_first = True
|
||||
|
||||
try:
|
||||
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
||||
parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
if title := channel_values.get("title"):
|
||||
values["title"] = title
|
||||
if is_first and (messages := channel_values.get("messages")):
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_first = False
|
||||
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
|
||||
entries.append(
|
||||
HistoryEntry(
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
metadata=metadata,
|
||||
values=values,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
next=next_nodes,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||
|
||||
if not entries and await _thread_or_run_exists(
|
||||
request=request,
|
||||
thread_id=thread_id,
|
||||
thread_meta_storage=thread_meta_storage,
|
||||
run_repo=run_repo,
|
||||
):
|
||||
return []
|
||||
|
||||
return entries
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Memory API router for retrieving and managing global memory data."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||
from deerflow.agents.memory.updater import (
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
@@ -14,7 +13,6 @@ from deerflow.agents.memory.updater import (
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["memory"])
|
||||
|
||||
@@ -115,7 +113,7 @@ class MemoryStatusResponse(BaseModel):
|
||||
summary="Get Memory Data",
|
||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||
)
|
||||
async def get_memory(request: Request) -> MemoryResponse:
|
||||
async def get_memory() -> MemoryResponse:
|
||||
"""Get the current global memory data.
|
||||
|
||||
Returns:
|
||||
@@ -149,9 +147,8 @@ async def get_memory(request: Request) -> MemoryResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
with bind_request_actor_context(request):
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
return MemoryResponse(**memory_data)
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -161,7 +158,7 @@ async def get_memory(request: Request) -> MemoryResponse:
|
||||
summary="Reload Memory Data",
|
||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||
)
|
||||
async def reload_memory(request: Request) -> MemoryResponse:
|
||||
async def reload_memory() -> MemoryResponse:
|
||||
"""Reload memory data from file.
|
||||
|
||||
This forces a reload of the memory data from the storage file,
|
||||
@@ -170,9 +167,8 @@ async def reload_memory(request: Request) -> MemoryResponse:
|
||||
Returns:
|
||||
The reloaded memory data.
|
||||
"""
|
||||
with bind_request_actor_context(request):
|
||||
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||
return MemoryResponse(**memory_data)
|
||||
memory_data = reload_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -182,15 +178,14 @@ async def reload_memory(request: Request) -> MemoryResponse:
|
||||
summary="Clear All Memory Data",
|
||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||
)
|
||||
async def clear_memory(request: Request) -> MemoryResponse:
|
||||
async def clear_memory() -> MemoryResponse:
|
||||
"""Clear all persisted memory data."""
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||
try:
|
||||
memory_data = clear_memory_data()
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -200,22 +195,20 @@ async def clear_memory(request: Request) -> MemoryResponse:
|
||||
summary="Create Memory Fact",
|
||||
description="Create a single saved memory fact manually.",
|
||||
)
|
||||
async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse:
|
||||
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
|
||||
"""Create a single fact manually."""
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
memory_data = create_memory_fact(
|
||||
content=payload.content,
|
||||
category=payload.category,
|
||||
confidence=payload.confidence,
|
||||
user_id=get_effective_user_id(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
||||
try:
|
||||
memory_data = create_memory_fact(
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -225,17 +218,16 @@ async def create_memory_fact_endpoint(request: Request, payload: FactCreateReque
|
||||
summary="Delete Memory Fact",
|
||||
description="Delete a single saved memory fact by its fact id.",
|
||||
)
|
||||
async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse:
|
||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
"""Delete a single fact from memory by fact id."""
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
||||
try:
|
||||
memory_data = delete_memory_fact(fact_id)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -245,25 +237,23 @@ async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryR
|
||||
summary="Patch Memory Fact",
|
||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||
)
|
||||
async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse:
|
||||
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
|
||||
"""Partially update a single fact manually."""
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
memory_data = update_memory_fact(
|
||||
fact_id=fact_id,
|
||||
content=payload.content,
|
||||
category=payload.category,
|
||||
confidence=payload.confidence,
|
||||
user_id=get_effective_user_id(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
||||
try:
|
||||
memory_data = update_memory_fact(
|
||||
fact_id=fact_id,
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -273,11 +263,10 @@ async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: F
|
||||
summary="Export Memory Data",
|
||||
description="Export the current global memory data as JSON for backup or transfer.",
|
||||
)
|
||||
async def export_memory(request: Request) -> MemoryResponse:
|
||||
async def export_memory() -> MemoryResponse:
|
||||
"""Export the current memory data."""
|
||||
with bind_request_actor_context(request):
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
return MemoryResponse(**memory_data)
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -287,15 +276,14 @@ async def export_memory(request: Request) -> MemoryResponse:
|
||||
summary="Import Memory Data",
|
||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||
)
|
||||
async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse:
|
||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||
"""Import and persist memory data."""
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||
try:
|
||||
memory_data = import_memory_data(request.model_dump())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -342,25 +330,24 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
summary="Get Memory Status",
|
||||
description="Retrieve both memory configuration and current data in a single request.",
|
||||
)
|
||||
async def get_memory_status(request: Request) -> MemoryStatusResponse:
|
||||
async def get_memory_status() -> MemoryStatusResponse:
|
||||
"""Get the memory system status including configuration and data.
|
||||
|
||||
Returns:
|
||||
Combined memory configuration and current data.
|
||||
"""
|
||||
with bind_request_actor_context(request):
|
||||
config = get_memory_config()
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
config = get_memory_config()
|
||||
memory_data = get_memory_data()
|
||||
|
||||
return MemoryStatusResponse(
|
||||
config=MemoryConfigResponse(
|
||||
enabled=config.enabled,
|
||||
storage_path=config.storage_path,
|
||||
debounce_seconds=config.debounce_seconds,
|
||||
max_facts=config.max_facts,
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
),
|
||||
data=MemoryResponse(**memory_data),
|
||||
)
|
||||
return MemoryStatusResponse(
|
||||
config=MemoryConfigResponse(
|
||||
enabled=config.enabled,
|
||||
storage_path=config.storage_path,
|
||||
debounce_seconds=config.debounce_seconds,
|
||||
max_facts=config.max_facts,
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
),
|
||||
data=MemoryResponse(**memory_data),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
|
||||
|
||||
These endpoints auto-create a temporary thread when no ``thread_id`` is
|
||||
supplied in the request body. When a ``thread_id`` **is** provided, it
|
||||
is reused so that conversation history is preserved across calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
||||
|
||||
|
||||
def _resolve_thread_id(body: RunCreateRequest) -> str:
|
||||
"""Return the thread_id from the request body, or generate a new one."""
|
||||
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
|
||||
if thread_id:
|
||||
return str(thread_id)
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
"""Create a run and stream events via SSE.
|
||||
|
||||
If ``config.configurable.thread_id`` is provided, the run is created
|
||||
on the given thread so that conversation history is preserved.
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/wait", response_model=dict)
|
||||
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until completion.
|
||||
|
||||
If ``config.configurable.thread_id`` is provided, the run is created
|
||||
on the given thread so that conversation history is preserved.
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
@@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.agents.lead_agent.prompt import clear_skills_system_prompt_cache
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from deerflow.skills import Skill, load_skills
|
||||
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
|
||||
@@ -119,7 +119,6 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
|
||||
try:
|
||||
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
|
||||
result = install_skill_from_archive(skill_file_path)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
return SkillInstallResponse(**result)
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -182,7 +181,7 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
|
||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||
},
|
||||
)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
clear_skills_system_prompt_cache()
|
||||
return await get_custom_skill(skill_name)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -214,7 +213,7 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
|
||||
},
|
||||
)
|
||||
shutil.rmtree(skill_dir)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
clear_skills_system_prompt_cache()
|
||||
return {"success": True}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -269,7 +268,7 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
|
||||
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
|
||||
atomic_write(skill_file, target_content)
|
||||
append_history(skill_name, history_entry)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
clear_skills_system_prompt_cache()
|
||||
return await get_custom_skill(skill_name)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -338,7 +337,6 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
|
||||
|
||||
logger.info(f"Skills configuration updated and saved to: {config_path}")
|
||||
reload_extensions_config()
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
|
||||
skills = load_skills(enabled_only=False)
|
||||
updated_skill = next((s for s in skills if s.name == skill_name), None)
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import APIRouter, Request
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -98,6 +99,7 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
||||
summary="Generate Follow-up Questions",
|
||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||
)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
|
||||
if not body.messages:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
"""Runs endpoints — create, stream, wait, cancel.
|
||||
|
||||
Implements the LangGraph Platform runs API on top of
|
||||
:class:`deerflow.agents.runs.RunManager` and
|
||||
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
||||
|
||||
SSE format is aligned with the LangGraph Platform protocol so that
|
||||
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
||||
works without modification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RunCreateRequest(BaseModel):
|
||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
||||
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None = None
|
||||
status: str
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
multitask_strategy: str = "reject"
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
return RunResponse(
|
||||
run_id=record.run_id,
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
status=record.status.value,
|
||||
metadata=record.metadata,
|
||||
kwargs=record.kwargs,
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||
@require_permission("runs", "create", owner_check=True)
|
||||
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
|
||||
"""Create a background run (returns immediately)."""
|
||||
record = await start_run(body, thread_id, request)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/stream")
|
||||
@require_permission("runs", "create", owner_check=True)
|
||||
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
"""Create a run and stream events via SSE.
|
||||
|
||||
The response includes a ``Content-Location`` header with the run's
|
||||
resource URL, matching the LangGraph Platform protocol. The
|
||||
``useStream`` React hook uses this to extract run metadata.
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
# LangGraph Platform includes run metadata in this header.
|
||||
# The SDK uses a greedy regex to extract the run id from this path,
|
||||
# so it must point at the canonical run resource without extra suffixes.
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/wait", response_model=dict)
|
||||
@require_permission("runs", "create", owner_check=True)
|
||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until it completes, returning the final state."""
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
||||
"""List all runs for a thread."""
|
||||
run_mgr = get_run_manager(request)
|
||||
records = await run_mgr.list_by_thread(thread_id)
|
||||
return [_record_to_response(r) for r in records]
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||
"""Get details of a specific run."""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||
@require_permission("runs", "cancel", owner_check=True)
|
||||
async def cancel_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
wait: bool = Query(default=False, description="Block until run completes after cancel"),
|
||||
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
|
||||
) -> Response:
|
||||
"""Cancel a running or pending run.
|
||||
|
||||
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
|
||||
- action=rollback: Stop execution, revert to pre-run checkpoint state
|
||||
- wait=true: Block until the run fully stops, return 204
|
||||
- wait=false: Return immediately with 202
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
||||
if not cancelled:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
|
||||
)
|
||||
|
||||
if wait and record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return Response(status_code=204)
|
||||
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
||||
"""Join an existing run's SSE stream."""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def stream_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
|
||||
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
|
||||
):
|
||||
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
|
||||
|
||||
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
|
||||
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
|
||||
is present the run is cancelled first; the response then streams any
|
||||
remaining buffered events so the client observes a clean shutdown.
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
# Cancel if an action was requested (stop-button / interrupt flow)
|
||||
if action is not None:
|
||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
||||
if cancelled and wait and record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
return Response(status_code=204)
|
||||
|
||||
bridge = get_stream_bridge(request)
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Messages / Events / Token usage endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages")
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_thread_messages(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(default=50, le=200),
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages for a thread (across all runs)."""
|
||||
event_store = get_run_event_store(request)
|
||||
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages")
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
|
||||
"""Return displayable messages for a specific run."""
|
||||
event_store = get_run_event_store(request)
|
||||
return await event_store.list_messages_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/events")
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_run_events(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
event_types: str | None = Query(default=None),
|
||||
limit: int = Query(default=500, le=2000),
|
||||
) -> list[dict]:
|
||||
"""Return the full event stream for a run (debug/audit)."""
|
||||
event_store = get_run_event_store(request)
|
||||
types = event_types.split(",") if event_types else None
|
||||
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/token-usage")
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> dict:
|
||||
"""Thread-level token usage aggregation."""
|
||||
run_store = get_run_store(request)
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
return {"thread_id": thread_id, **agg}
|
||||
@@ -0,0 +1,593 @@
|
||||
"""Thread CRUD, state, and history endpoints.
|
||||
|
||||
Combines the existing thread-local filesystem cleanup with LangGraph
|
||||
Platform-compatible thread management backed by the checkpointer.
|
||||
|
||||
Channel values returned in state responses are serialized through
|
||||
:func:`deerflow.runtime.serialization.serialize_channel_values` to
|
||||
ensure LangChain message objects are converted to JSON-safe dicts
|
||||
matching the LangGraph Platform wire format expected by the
|
||||
``useStream`` React hook.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response / request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ThreadDeleteResponse(BaseModel):
|
||||
"""Response model for thread cleanup."""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ThreadResponse(BaseModel):
|
||||
"""Response model for a single thread."""
|
||||
|
||||
thread_id: str = Field(description="Unique thread identifier")
|
||||
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
|
||||
created_at: str = Field(default="", description="ISO timestamp")
|
||||
updated_at: str = Field(default="", description="ISO timestamp")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
|
||||
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||
|
||||
|
||||
class ThreadCreateRequest(BaseModel):
|
||||
"""Request body for creating a thread."""
|
||||
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
|
||||
class ThreadSearchRequest(BaseModel):
|
||||
"""Request body for searching threads."""
|
||||
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
||||
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||
status: str | None = Field(default=None, description="Filter by thread status")
|
||||
|
||||
|
||||
class ThreadStateResponse(BaseModel):
|
||||
"""Response model for thread state."""
|
||||
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
||||
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
||||
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
||||
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
||||
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
||||
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
||||
|
||||
|
||||
class ThreadPatchRequest(BaseModel):
|
||||
"""Request body for patching thread metadata."""
|
||||
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
|
||||
|
||||
|
||||
class ThreadStateUpdateRequest(BaseModel):
|
||||
"""Request body for updating thread state (human-in-the-loop resume)."""
|
||||
|
||||
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
||||
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||
|
||||
|
||||
class HistoryEntry(BaseModel):
|
||||
"""Single checkpoint history entry."""
|
||||
|
||||
checkpoint_id: str
|
||||
parent_checkpoint_id: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
values: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
next: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ThreadHistoryRequest(BaseModel):
|
||||
"""Request body for checkpoint history."""
|
||||
|
||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
||||
before: str | None = Field(default=None, description="Cursor for pagination")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread."""
|
||||
path_manager = paths or get_paths()
|
||||
try:
|
||||
path_manager.delete_thread_dir(thread_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except FileNotFoundError:
|
||||
# Not critical — thread data may not exist on disk
|
||||
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
|
||||
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
||||
|
||||
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
|
||||
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||
|
||||
|
||||
def _derive_thread_status(checkpoint_tuple) -> str:
|
||||
"""Derive thread status from checkpoint metadata."""
|
||||
if checkpoint_tuple is None:
|
||||
return "idle"
|
||||
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
|
||||
|
||||
# Check for error in pending writes
|
||||
for pw in pending_writes:
|
||||
if len(pw) >= 2 and pw[1] == "__error__":
|
||||
return "error"
|
||||
|
||||
# Check for pending next tasks (indicates interrupt)
|
||||
tasks = getattr(checkpoint_tuple, "tasks", None)
|
||||
if tasks:
|
||||
return "interrupted"
|
||||
|
||||
return "idle"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread.
|
||||
|
||||
Cleans DeerFlow-managed thread directories, removes checkpoint data,
|
||||
and removes the thread_meta row from the configured ThreadMetaStore
|
||||
(sqlite or memory).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
if hasattr(checkpointer, "adelete_thread"):
|
||||
await checkpointer.adelete_thread(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))
|
||||
|
||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||
# so the deleted thread no longer appears in /threads/search.
|
||||
try:
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
await thread_meta_repo.delete(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post("", response_model=ThreadResponse)
|
||||
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
|
||||
"""Create a new thread.
|
||||
|
||||
Writes a thread_meta record (so the thread appears in /threads/search)
|
||||
and an empty checkpoint (so state endpoints work immediately).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_meta_repo.get(thread_id)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=existing_record.get("status", "idle"),
|
||||
created_at=str(existing_record.get("created_at", "")),
|
||||
updated_at=str(existing_record.get("updated_at", "")),
|
||||
metadata=existing_record.get("metadata", {}),
|
||||
)
|
||||
|
||||
# Write thread_meta so the thread appears in /threads/search immediately
|
||||
try:
|
||||
await thread_meta_repo.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
metadata=body.metadata,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
# Write an empty checkpoint so state endpoints work immediately
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
ckpt_metadata = {
|
||||
"step": -1,
|
||||
"source": "input",
|
||||
"writes": None,
|
||||
"parents": {},
|
||||
**body.metadata,
|
||||
"created_at": now,
|
||||
}
|
||||
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
|
||||
except Exception:
|
||||
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
logger.info("Thread created: %s", sanitize_log_param(thread_id))
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status="idle",
|
||||
created_at=str(now),
|
||||
updated_at=str(now),
|
||||
metadata=body.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[ThreadResponse])
|
||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
||||
"""Search and list threads.
|
||||
|
||||
Delegates to the configured ThreadMetaStore implementation
|
||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
repo = get_thread_meta_repo(request)
|
||||
rows = await repo.search(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
limit=body.limit,
|
||||
offset=body.offset,
|
||||
)
|
||||
return [
|
||||
ThreadResponse(
|
||||
thread_id=r["thread_id"],
|
||||
status=r.get("status", "idle"),
|
||||
created_at=r.get("created_at", ""),
|
||||
updated_at=r.get("updated_at", ""),
|
||||
metadata=r.get("metadata", {}),
|
||||
values={"title": r["display_name"]} if r.get("display_name") else {},
|
||||
interrupts={},
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
||||
@require_permission("threads", "write", owner_check=True)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
record = await thread_meta_repo.get(thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
try:
|
||||
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
||||
except Exception:
|
||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||
|
||||
# Re-read to get the merged metadata + refreshed updated_at
|
||||
record = await thread_meta_repo.get(thread_id) or record
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=record.get("status", "idle"),
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=ThreadResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
"""Get thread info.
|
||||
|
||||
Reads metadata from the ThreadMetaStore and derives the accurate
|
||||
execution status from the checkpointer. Falls back to the checkpointer
|
||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
record: dict | None = await thread_meta_repo.get(thread_id)
|
||||
|
||||
# Derive accurate status from the checkpointer
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread")
|
||||
|
||||
if record is None and checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# If the thread exists in the checkpointer but not in thread_meta (e.g.
|
||||
# legacy data created before thread_meta adoption), synthesize a minimal
|
||||
# record from the checkpoint metadata.
|
||||
if record is None and checkpoint_tuple is not None:
|
||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": ckpt_meta.get("created_at", ""),
|
||||
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||
"""Get the latest state snapshot for a thread.
|
||||
|
||||
Channel values are serialized to ensure LangChain message objects
|
||||
are converted to JSON-safe dicts.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
checkpoint_id = None
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
||||
if ckpt_config:
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
parent_checkpoint_id = None
|
||||
if parent_config:
|
||||
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=next_tasks,
|
||||
metadata=metadata,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
@require_permission("threads", "write", owner_check=True)
|
||||
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
||||
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
|
||||
|
||||
Writes a new checkpoint that merges *body.values* into the latest
|
||||
channel values, then syncs any updated ``title`` field through the
|
||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||
change immediately in both sqlite and memory backends.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
|
||||
# checkpoint_ns must be present in the config for aput — default to ""
|
||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||
# fetches the latest checkpoint for the thread.
|
||||
read_config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": "",
|
||||
}
|
||||
}
|
||||
if body.checkpoint_id:
|
||||
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
||||
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# Work on mutable copies so we don't accidentally mutate cached objects.
|
||||
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
||||
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
||||
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
||||
|
||||
if body.values:
|
||||
channel_values.update(body.values)
|
||||
|
||||
checkpoint["channel_values"] = channel_values
|
||||
metadata["updated_at"] = time.time()
|
||||
|
||||
if body.as_node:
|
||||
metadata["source"] = "update"
|
||||
metadata["step"] = metadata.get("step", 0) + 1
|
||||
metadata["writes"] = {body.as_node: body.values}
|
||||
|
||||
# aput requires checkpoint_ns in the config — use the same config used for the
|
||||
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
|
||||
# so that aput generates a fresh checkpoint ID for the new snapshot.
|
||||
write_config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": "",
|
||||
}
|
||||
}
|
||||
try:
|
||||
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
||||
except Exception:
|
||||
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
||||
|
||||
new_checkpoint_id: str | None = None
|
||||
if isinstance(new_config, dict):
|
||||
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
|
||||
# reflects them immediately in both sqlite and memory backends.
|
||||
if body.values and "title" in body.values:
|
||||
new_title = body.values["title"]
|
||||
if new_title: # Skip empty strings and None
|
||||
try:
|
||||
await thread_meta_repo.update_display_name(thread_id, new_title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread.
|
||||
|
||||
Messages are read from the checkpointer's channel values (the
|
||||
authoritative source) and serialized via
|
||||
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
|
||||
Only the latest (first) checkpoint carries the ``messages`` key to
|
||||
avoid duplicating them across every entry.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
|
||||
if body.before:
|
||||
config["configurable"]["checkpoint_id"] = body.before
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
is_latest_checkpoint = True
|
||||
try:
|
||||
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
||||
parent_id = None
|
||||
if parent_config:
|
||||
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
# Build values from checkpoint channel_values
|
||||
values: dict[str, Any] = {}
|
||||
if title := channel_values.get("title"):
|
||||
values["title"] = title
|
||||
if thread_data := channel_values.get("thread_data"):
|
||||
values["thread_data"] = thread_data
|
||||
|
||||
# Attach messages from checkpointer only for the latest checkpoint
|
||||
if is_latest_checkpoint:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_latest_checkpoint = False
|
||||
|
||||
# Derive next tasks
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
|
||||
# Strip LangGraph internal keys from metadata
|
||||
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
|
||||
# Keep step for ordering context
|
||||
if "step" in metadata:
|
||||
user_meta["step"] = metadata["step"]
|
||||
|
||||
entries.append(
|
||||
HistoryEntry(
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
metadata=user_meta,
|
||||
values=values,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
next=next_tasks,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||
|
||||
return entries
|
||||
@@ -7,10 +7,9 @@ import stat
|
||||
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from app.gateway.authz import require_permission
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
delete_file_safe,
|
||||
@@ -56,6 +55,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||
|
||||
|
||||
@router.post("", response_model=UploadResponse)
|
||||
@require_permission("threads", "write", owner_check=True)
|
||||
async def upload_files(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
@@ -65,69 +65,68 @@ async def upload_files(
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
uploaded_files = []
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
continue
|
||||
|
||||
try:
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
uploaded_files = []
|
||||
safe_filename = normalize_filename(file.filename)
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||
continue
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
try:
|
||||
content = await file.read()
|
||||
file_path = uploads_dir / safe_filename
|
||||
file_path.write_bytes(content)
|
||||
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
continue
|
||||
virtual_path = upload_virtual_path(safe_filename)
|
||||
|
||||
try:
|
||||
safe_filename = normalize_filename(file.filename)
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||
continue
|
||||
if sandbox_id != "local":
|
||||
_make_file_sandbox_writable(file_path)
|
||||
sandbox.update_file(virtual_path, content)
|
||||
|
||||
try:
|
||||
content = await file.read()
|
||||
file_path = uploads_dir / safe_filename
|
||||
file_path.write_bytes(content)
|
||||
file_info = {
|
||||
"filename": safe_filename,
|
||||
"size": str(len(content)),
|
||||
"path": str(sandbox_uploads / safe_filename),
|
||||
"virtual_path": virtual_path,
|
||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||
}
|
||||
|
||||
virtual_path = upload_virtual_path(safe_filename)
|
||||
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||
|
||||
if sandbox_id != "local":
|
||||
_make_file_sandbox_writable(file_path)
|
||||
sandbox.update_file(virtual_path, content)
|
||||
file_ext = file_path.suffix.lower()
|
||||
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||
md_path = await convert_file_to_markdown(file_path)
|
||||
if md_path:
|
||||
md_virtual_path = upload_virtual_path(md_path.name)
|
||||
|
||||
file_info = {
|
||||
"filename": safe_filename,
|
||||
"size": str(len(content)),
|
||||
"path": str(sandbox_uploads / safe_filename),
|
||||
"virtual_path": virtual_path,
|
||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||
}
|
||||
if sandbox_id != "local":
|
||||
_make_file_sandbox_writable(md_path)
|
||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||
|
||||
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||
file_info["markdown_file"] = md_path.name
|
||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||
file_info["markdown_virtual_path"] = md_virtual_path
|
||||
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
||||
|
||||
file_ext = file_path.suffix.lower()
|
||||
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||
md_path = await convert_file_to_markdown(file_path)
|
||||
if md_path:
|
||||
md_virtual_path = upload_virtual_path(md_path.name)
|
||||
uploaded_files.append(file_info)
|
||||
|
||||
if sandbox_id != "local":
|
||||
_make_file_sandbox_writable(md_path)
|
||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||
|
||||
file_info["markdown_file"] = md_path.name
|
||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||
file_info["markdown_virtual_path"] = md_virtual_path
|
||||
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
||||
|
||||
uploaded_files.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
@@ -137,25 +136,26 @@ async def upload_files(
|
||||
|
||||
|
||||
@router.get("/list", response_model=dict)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||
"""List all files in a thread's uploads directory."""
|
||||
with bind_request_actor_context(request):
|
||||
try:
|
||||
uploads_dir = get_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
result = list_files_in_dir(uploads_dir)
|
||||
enrich_file_listing(result, thread_id)
|
||||
try:
|
||||
uploads_dir = get_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
result = list_files_in_dir(uploads_dir)
|
||||
enrich_file_listing(result, thread_id)
|
||||
|
||||
# Gateway additionally includes the sandbox-relative path.
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
for f in result["files"]:
|
||||
f["path"] = str(sandbox_uploads / f["filename"])
|
||||
# Gateway additionally includes the sandbox-relative path.
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
for f in result["files"]:
|
||||
f["path"] = str(sandbox_uploads / f["filename"])
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{filename}")
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
|
||||
"""Delete a file from a thread's uploads directory."""
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,325 @@
|
||||
"""Run lifecycle service layer.
|
||||
|
||||
Centralizes the business logic for creating runs, formatting SSE
|
||||
frames, and consuming stream bridge events. Router modules
|
||||
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.runtime import (
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
ConflictError,
|
||||
DisconnectMode,
|
||||
RunManager,
|
||||
RunRecord,
|
||||
RunStatus,
|
||||
StreamBridge,
|
||||
UnsupportedStrategyError,
|
||||
run_agent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE formatting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
|
||||
"""Format a single SSE frame.
|
||||
|
||||
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
|
||||
This matches the LangGraph Platform wire format consumed by the
|
||||
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
|
||||
"""
|
||||
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||||
parts = [f"event: {event}", f"data: {payload}"]
|
||||
if event_id:
|
||||
parts.append(f"id: {event_id}")
|
||||
parts.append("")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input / config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
|
||||
"""Normalize the stream_mode parameter to a list.
|
||||
|
||||
Default matches what ``useStream`` expects: values + messages-tuple.
|
||||
"""
|
||||
if raw is None:
|
||||
return ["values"]
|
||||
if isinstance(raw, str):
|
||||
return [raw]
|
||||
return raw if raw else ["values"]
|
||||
|
||||
|
||||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Convert LangGraph Platform input format to LangChain state dict."""
|
||||
if raw_input is None:
|
||||
return {}
|
||||
messages = raw_input.get("messages")
|
||||
if messages and isinstance(messages, list):
|
||||
converted = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "user"))
|
||||
content = msg.get("content", "")
|
||||
if role in ("user", "human"):
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
# TODO: handle other message types (system, ai, tool)
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
converted.append(msg)
|
||||
return {**raw_input, "messages": converted}
|
||||
return raw_input
|
||||
|
||||
|
||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||
|
||||
|
||||
def resolve_agent_factory(assistant_id: str | None):
|
||||
"""Resolve the agent factory callable from config.
|
||||
|
||||
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
||||
injected into ``configurable`` — see :func:`build_run_config`. All
|
||||
``assistant_id`` values therefore map to the same factory; the routing
|
||||
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
|
||||
"""
|
||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||
|
||||
return make_lead_agent
|
||||
|
||||
|
||||
def build_run_config(
|
||||
thread_id: str,
|
||||
request_config: dict[str, Any] | None,
|
||||
metadata: dict[str, Any] | None,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a RunnableConfig dict for the agent.
|
||||
|
||||
When *assistant_id* refers to a custom agent (anything other than
|
||||
``"lead_agent"`` / ``None``), the name is forwarded as
|
||||
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
|
||||
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
||||
without it the agent silently runs as the default lead agent.
|
||||
|
||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||
identically.
|
||||
"""
|
||||
config: dict[str, Any] = {"recursion_limit": 100}
|
||||
if request_config:
|
||||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
||||
# pass thread-level data and rejects requests that include both
|
||||
# ``configurable`` and ``context``. If the caller already sends
|
||||
# ``context``, honour it and skip our own ``configurable`` dict.
|
||||
if "context" in request_config:
|
||||
if "configurable" in request_config:
|
||||
logger.warning(
|
||||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
||||
thread_id,
|
||||
list(request_config.get("configurable", {}).keys()),
|
||||
)
|
||||
config["context"] = request_config["context"]
|
||||
else:
|
||||
configurable = {"thread_id": thread_id}
|
||||
configurable.update(request_config.get("configurable", {}))
|
||||
config["configurable"] = configurable
|
||||
for k, v in request_config.items():
|
||||
if k not in ("configurable", "context"):
|
||||
config[k] = v
|
||||
else:
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
|
||||
# Inject custom agent name when the caller specified a non-default assistant.
|
||||
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
||||
if "agent_name" not in config["configurable"]:
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
config["configurable"]["agent_name"] = normalized
|
||||
if metadata:
|
||||
config.setdefault("metadata", {}).update(metadata)
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def start_run(
|
||||
body: Any,
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
) -> RunRecord:
|
||||
"""Create a RunRecord and launch the background agent task.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
body : RunCreateRequest
|
||||
The validated request body (typed as Any to avoid circular import
|
||||
with the router module that defines the Pydantic model).
|
||||
thread_id : str
|
||||
Target thread.
|
||||
request : Request
|
||||
FastAPI request — used to retrieve singletons from ``app.state``.
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
run_ctx = get_run_context(request)
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
||||
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
||||
if follow_up_to_run_id is None:
|
||||
run_store = get_run_store(request)
|
||||
try:
|
||||
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
||||
if recent_runs and recent_runs[0].get("status") == "success":
|
||||
follow_up_to_run_id = recent_runs[0]["run_id"]
|
||||
except Exception:
|
||||
pass # Don't block run creation
|
||||
|
||||
# Enrich base context with per-run field
|
||||
if follow_up_to_run_id:
|
||||
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_meta_repo.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
# Merge DeerFlow-specific context overrides into configurable.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
context = getattr(body, "context", None)
|
||||
if context:
|
||||
_CONTEXT_CONFIGURABLE_KEYS = {
|
||||
"model_name",
|
||||
"mode",
|
||||
"thinking_enabled",
|
||||
"reasoning_effort",
|
||||
"is_plan_mode",
|
||||
"subagent_enabled",
|
||||
"max_concurrent_subagents",
|
||||
}
|
||||
configurable = config.setdefault("configurable", {})
|
||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||
if key in context:
|
||||
configurable.setdefault(key, context[key])
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
|
||||
|
||||
async def sse_consumer(
|
||||
bridge: StreamBridge,
|
||||
record: RunRecord,
|
||||
request: Request,
|
||||
run_mgr: RunManager,
|
||||
):
|
||||
"""Async generator that yields SSE frames from the bridge.
|
||||
|
||||
The ``finally`` block implements ``on_disconnect`` semantics:
|
||||
- ``cancel``: abort the background task on client disconnect.
|
||||
- ``continue``: let the task run; events are discarded.
|
||||
"""
|
||||
last_event_id = request.headers.get("Last-Event-ID")
|
||||
try:
|
||||
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
if entry is HEARTBEAT_SENTINEL:
|
||||
yield ": heartbeat\n\n"
|
||||
continue
|
||||
|
||||
if entry is END_SENTINEL:
|
||||
yield format_sse("end", None, event_id=entry.id or None)
|
||||
return
|
||||
|
||||
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
|
||||
|
||||
finally:
|
||||
if record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Gateway service layer."""
|
||||
|
||||
"""Compatibility package for app service submodules."""
|
||||
|
||||
__all__: list[str] = []
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Runs app layer services."""
|
||||
|
||||
from app.infra.storage import StorageRunObserver
|
||||
from .input import (
|
||||
AdaptedRunRequest,
|
||||
RunSpecBuilder,
|
||||
UnsupportedRunFeatureError,
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||
|
||||
__all__ = [
|
||||
"AdaptedRunRequest",
|
||||
"AppRunCreateStore",
|
||||
"AppRunDeleteStore",
|
||||
"AppRunQueryStore",
|
||||
"RunSpecBuilder",
|
||||
"StorageRunObserver",
|
||||
"UnsupportedRunFeatureError",
|
||||
"adapt_create_run_request",
|
||||
"adapt_create_stream_request",
|
||||
"adapt_create_wait_request",
|
||||
"adapt_join_stream_request",
|
||||
"adapt_join_wait_request",
|
||||
]
|
||||
@@ -1,150 +0,0 @@
|
||||
"""Facade factory - assembles RunsFacade with dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.gateway.dependencies import get_checkpointer, get_stream_bridge
|
||||
from deerflow.runtime.runs.facade import RunsFacade
|
||||
from deerflow.runtime.runs.facade import RunsRuntime
|
||||
from deerflow.runtime.runs.internal.execution.supervisor import RunSupervisor
|
||||
from deerflow.runtime.runs.internal.planner import ExecutionPlanner
|
||||
from deerflow.runtime.runs.internal.registry import RunRegistry
|
||||
from deerflow.runtime.runs.internal.streams import RunStreamService
|
||||
from deerflow.runtime.runs.internal.wait import RunWaitService
|
||||
|
||||
from app.infra.storage import StorageRunObserver, ThreadMetaStorage
|
||||
from app.infra.storage.runs import RunDeleteRepository, RunReadRepository, RunWriteRepository
|
||||
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
|
||||
type AgentFactory = Callable[..., object]
|
||||
|
||||
|
||||
# Module-level singleton registry (shared across requests)
|
||||
_registry: RunRegistry | None = None
|
||||
_supervisor: RunSupervisor | None = None
|
||||
|
||||
|
||||
def _get_state(request: Request, attr: str, label: str):
|
||||
value = getattr(request.app.state, attr, None)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
return value
|
||||
|
||||
|
||||
def get_registry() -> RunRegistry:
|
||||
"""Get or create singleton registry."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = RunRegistry()
|
||||
return _registry
|
||||
|
||||
|
||||
def get_supervisor() -> RunSupervisor:
|
||||
"""Get or create singleton run supervisor."""
|
||||
global _supervisor
|
||||
if _supervisor is None:
|
||||
_supervisor = RunSupervisor()
|
||||
return _supervisor
|
||||
|
||||
|
||||
def resolve_agent_factory(assistant_id: str | None) -> AgentFactory:
|
||||
"""Resolve the agent factory callable from config."""
|
||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||
|
||||
return make_lead_agent
|
||||
|
||||
|
||||
def build_runs_facade(
|
||||
*,
|
||||
stream_bridge: "StreamBridge",
|
||||
checkpointer: object,
|
||||
store: object | None = None,
|
||||
run_read_repo: RunReadRepository | None = None,
|
||||
run_write_repo: RunWriteRepository | None = None,
|
||||
run_delete_repo: RunDeleteRepository | None = None,
|
||||
thread_meta_storage: ThreadMetaStorage | None = None,
|
||||
run_event_store: object | None = None,
|
||||
) -> RunsFacade:
|
||||
"""
|
||||
Build RunsFacade with all dependencies.
|
||||
|
||||
Args:
|
||||
stream_bridge: StreamBridge instance
|
||||
checkpointer: LangGraph checkpointer
|
||||
store: Optional LangGraph runtime store
|
||||
run_read_repo: Optional run repository for durable reads
|
||||
run_write_repo: Optional run repository for durable writes
|
||||
run_delete_repo: Optional run repository for durable deletes
|
||||
thread_meta_storage: Optional thread metadata storage adapter
|
||||
|
||||
Returns:
|
||||
Configured RunsFacade instance
|
||||
"""
|
||||
registry = get_registry()
|
||||
planner = ExecutionPlanner()
|
||||
supervisor = get_supervisor()
|
||||
|
||||
stream_service = RunStreamService(stream_bridge)
|
||||
wait_service = RunWaitService(stream_service)
|
||||
query_store = AppRunQueryStore(run_read_repo) if run_read_repo else None
|
||||
create_store = (
|
||||
AppRunCreateStore(run_write_repo, thread_meta_storage=thread_meta_storage)
|
||||
if run_write_repo
|
||||
else None
|
||||
)
|
||||
delete_store = AppRunDeleteStore(run_delete_repo) if run_delete_repo else None
|
||||
|
||||
# Build storage observer if repositories provided
|
||||
storage_observer = None
|
||||
if run_write_repo or thread_meta_storage:
|
||||
storage_observer = StorageRunObserver(
|
||||
run_write_repo=run_write_repo,
|
||||
thread_meta_storage=thread_meta_storage,
|
||||
)
|
||||
|
||||
return RunsFacade(
|
||||
registry=registry,
|
||||
planner=planner,
|
||||
supervisor=supervisor,
|
||||
stream_service=stream_service,
|
||||
wait_service=wait_service,
|
||||
runtime=RunsRuntime(
|
||||
bridge=stream_bridge,
|
||||
checkpointer=checkpointer,
|
||||
store=store,
|
||||
event_store=run_event_store,
|
||||
agent_factory_resolver=resolve_agent_factory,
|
||||
),
|
||||
observer=storage_observer,
|
||||
query_store=query_store,
|
||||
create_store=create_store,
|
||||
delete_store=delete_store,
|
||||
)
|
||||
|
||||
|
||||
def build_runs_facade_from_request(request: "Request") -> RunsFacade:
|
||||
"""
|
||||
Build RunsFacade from FastAPI request context.
|
||||
|
||||
Extracts dependencies from request.app.state.
|
||||
"""
|
||||
app_state = request.app.state
|
||||
|
||||
return build_runs_facade(
|
||||
stream_bridge=get_stream_bridge(request),
|
||||
checkpointer=get_checkpointer(request),
|
||||
store=getattr(request.app.state, "store", None),
|
||||
run_read_repo=getattr(app_state, "run_read_repo", None),
|
||||
run_write_repo=getattr(app_state, "run_write_repo", None),
|
||||
run_delete_repo=getattr(app_state, "run_delete_repo", None),
|
||||
thread_meta_storage=getattr(app_state, "thread_meta_storage", None),
|
||||
run_event_store=getattr(app_state, "run_event_store", None),
|
||||
)
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Input adapters for app-owned runs entrypoints."""
|
||||
|
||||
from .request_adapter import (
|
||||
AdaptedRunRequest,
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from .spec_builder import RunSpecBuilder, UnsupportedRunFeatureError
|
||||
|
||||
__all__ = [
|
||||
"AdaptedRunRequest",
|
||||
"RunSpecBuilder",
|
||||
"UnsupportedRunFeatureError",
|
||||
"adapt_create_run_request",
|
||||
"adapt_create_stream_request",
|
||||
"adapt_create_wait_request",
|
||||
"adapt_join_stream_request",
|
||||
"adapt_join_wait_request",
|
||||
]
|
||||
@@ -1,127 +0,0 @@
|
||||
"""App-owned request adapter for runs entrypoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deerflow.runtime.stream_bridge import JSONValue
|
||||
from deerflow.runtime.runs.types import RunIntent
|
||||
|
||||
type RequestBody = dict[str, JSONValue]
|
||||
type RequestQuery = dict[str, str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AdaptedRunRequest:
|
||||
"""
|
||||
统一的内部请求 DTO.
|
||||
|
||||
路由层只负责提取 path/query/body,适配器负责转成稳定内部结构。
|
||||
"""
|
||||
|
||||
intent: RunIntent
|
||||
thread_id: str | None
|
||||
run_id: str | None
|
||||
body: RequestBody
|
||||
headers: dict[str, str]
|
||||
query: RequestQuery
|
||||
|
||||
@property
|
||||
def last_event_id(self) -> str | None:
|
||||
"""Extract Last-Event-ID from headers."""
|
||||
return self.headers.get("last-event-id") or self.headers.get("Last-Event-ID")
|
||||
|
||||
@property
|
||||
def is_stateless(self) -> bool:
|
||||
"""Check if this is a stateless request."""
|
||||
return self.thread_id is None
|
||||
|
||||
|
||||
def adapt_create_run_request(
|
||||
*,
|
||||
thread_id: str | None,
|
||||
body: RequestBody,
|
||||
headers: dict[str, str] | None = None,
|
||||
query: RequestQuery | None = None,
|
||||
) -> AdaptedRunRequest:
|
||||
"""Adapt POST /threads/{thread_id}/runs or POST /runs."""
|
||||
return AdaptedRunRequest(
|
||||
intent="create_background",
|
||||
thread_id=thread_id,
|
||||
run_id=None,
|
||||
body=body,
|
||||
headers=headers or {},
|
||||
query=query or {},
|
||||
)
|
||||
|
||||
|
||||
def adapt_create_stream_request(
|
||||
*,
|
||||
thread_id: str | None,
|
||||
body: RequestBody,
|
||||
headers: dict[str, str] | None = None,
|
||||
query: RequestQuery | None = None,
|
||||
) -> AdaptedRunRequest:
|
||||
"""Adapt POST /threads/{thread_id}/runs/stream or POST /runs/stream."""
|
||||
return AdaptedRunRequest(
|
||||
intent="create_and_stream",
|
||||
thread_id=thread_id,
|
||||
run_id=None,
|
||||
body=body,
|
||||
headers=headers or {},
|
||||
query=query or {},
|
||||
)
|
||||
|
||||
|
||||
def adapt_create_wait_request(
|
||||
*,
|
||||
thread_id: str | None,
|
||||
body: RequestBody,
|
||||
headers: dict[str, str] | None = None,
|
||||
query: RequestQuery | None = None,
|
||||
) -> AdaptedRunRequest:
|
||||
"""Adapt POST /threads/{thread_id}/runs/wait or POST /runs/wait."""
|
||||
return AdaptedRunRequest(
|
||||
intent="create_and_wait",
|
||||
thread_id=thread_id,
|
||||
run_id=None,
|
||||
body=body,
|
||||
headers=headers or {},
|
||||
query=query or {},
|
||||
)
|
||||
|
||||
|
||||
def adapt_join_stream_request(
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
query: RequestQuery | None = None,
|
||||
) -> AdaptedRunRequest:
|
||||
"""Adapt GET /threads/{thread_id}/runs/{run_id}/stream."""
|
||||
return AdaptedRunRequest(
|
||||
intent="join_stream",
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
body={},
|
||||
headers=headers or {},
|
||||
query=query or {},
|
||||
)
|
||||
|
||||
|
||||
def adapt_join_wait_request(
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
query: RequestQuery | None = None,
|
||||
) -> AdaptedRunRequest:
|
||||
"""Adapt GET /threads/{thread_id}/runs/{run_id}/join."""
|
||||
return AdaptedRunRequest(
|
||||
intent="join_wait",
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
body={},
|
||||
headers=headers or {},
|
||||
query=query or {},
|
||||
)
|
||||
@@ -1,254 +0,0 @@
|
||||
"""App-owned RunSpec builder."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from deerflow.runtime.runs.types import CheckpointRequest, RunScope, RunSpec
|
||||
from deerflow.runtime.stream_bridge import JSONValue
|
||||
|
||||
from .request_adapter import AdaptedRunRequest
|
||||
|
||||
type JSONMapping = dict[str, JSONValue]
|
||||
type GraphInput = dict[str, object]
|
||||
type RunnableConfigDict = dict[str, object]
|
||||
|
||||
|
||||
class UnsupportedRunFeatureError(ValueError):
|
||||
"""Raised when a phase1-unsupported feature is requested."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunSpecBuilder:
|
||||
"""
|
||||
Build RunSpec from AdaptedRunRequest.
|
||||
|
||||
Phase 1 rules:
|
||||
1. messages-tuple normalized to messages
|
||||
2. enqueue not supported
|
||||
3. rollback not supported
|
||||
4. after_seconds not supported
|
||||
5. stream_resumable accepted
|
||||
6. stateless auto-generates temporary thread
|
||||
"""
|
||||
|
||||
# Phase 1 unsupported features
|
||||
UNSUPPORTED_MULTITASK_STRATEGIES = {"enqueue"}
|
||||
UNSUPPORTED_ACTIONS = {"rollback"}
|
||||
|
||||
# Default stream modes
|
||||
DEFAULT_STREAM_MODES = ["values", "messages"]
|
||||
CONTEXT_CONFIGURABLE_KEYS = frozenset({
|
||||
"model_name",
|
||||
"mode",
|
||||
"thinking_enabled",
|
||||
"reasoning_effort",
|
||||
"is_plan_mode",
|
||||
"subagent_enabled",
|
||||
"max_concurrent_subagents",
|
||||
})
|
||||
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||
|
||||
@staticmethod
|
||||
def _as_json_mapping(value: JSONValue | None) -> JSONMapping | None:
|
||||
return value if isinstance(value, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _as_string_list(value: JSONValue | None) -> list[str] | None:
|
||||
if not isinstance(value, list):
|
||||
return None
|
||||
return [item for item in value if isinstance(item, str)]
|
||||
|
||||
def build(self, request: AdaptedRunRequest) -> RunSpec:
|
||||
"""Build RunSpec from adapted request."""
|
||||
body = request.body
|
||||
|
||||
# Validate phase1 constraints
|
||||
self._validate_constraints(body)
|
||||
|
||||
# Build scope
|
||||
scope = self._build_scope(request)
|
||||
|
||||
# Normalize stream modes
|
||||
stream_modes = self._normalize_stream_modes(body.get("stream_mode"))
|
||||
|
||||
# Build checkpoint request
|
||||
checkpoint_request = self._build_checkpoint_request(body)
|
||||
|
||||
config = self._build_runnable_config(
|
||||
thread_id=scope.thread_id,
|
||||
request_config=self._as_json_mapping(body.get("config")),
|
||||
metadata=self._as_json_mapping(body.get("metadata")),
|
||||
assistant_id=body.get("assistant_id"),
|
||||
context=self._as_json_mapping(body.get("context")),
|
||||
)
|
||||
|
||||
return RunSpec(
|
||||
intent=request.intent,
|
||||
scope=scope,
|
||||
assistant_id=body.get("assistant_id") if isinstance(body.get("assistant_id"), str) else None,
|
||||
input=self._normalize_input(self._as_json_mapping(body.get("input"))),
|
||||
command=self._as_json_mapping(body.get("command")),
|
||||
runnable_config=config,
|
||||
context=self._as_json_mapping(body.get("context")),
|
||||
metadata=self._as_json_mapping(body.get("metadata")) or {},
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=bool(body.get("stream_subgraphs", False)),
|
||||
stream_resumable=bool(body.get("stream_resumable", False)),
|
||||
on_disconnect=body.get("on_disconnect", "cancel") if body.get("on_disconnect") in {"cancel", "continue"} else "cancel",
|
||||
on_completion=body.get("on_completion", "keep") if body.get("on_completion") in {"delete", "keep"} else "keep",
|
||||
multitask_strategy=body.get("multitask_strategy", "reject") if body.get("multitask_strategy") in {"reject", "interrupt"} else "reject",
|
||||
interrupt_before="*" if body.get("interrupt_before") == "*" else self._as_string_list(body.get("interrupt_before")),
|
||||
interrupt_after="*" if body.get("interrupt_after") == "*" else self._as_string_list(body.get("interrupt_after")),
|
||||
checkpoint_request=checkpoint_request,
|
||||
follow_up_to_run_id=body.get("follow_up_to_run_id") if isinstance(body.get("follow_up_to_run_id"), str) else None,
|
||||
webhook=body.get("webhook") if isinstance(body.get("webhook"), str) else None,
|
||||
feedback_keys=self._as_string_list(body.get("feedback_keys")),
|
||||
)
|
||||
|
||||
def _validate_constraints(self, body: JSONMapping) -> None:
|
||||
"""Validate phase1 constraints, raise UnsupportedRunFeatureError if violated."""
|
||||
# Check multitask_strategy
|
||||
strategy = body.get("multitask_strategy", "reject")
|
||||
if strategy in self.UNSUPPORTED_MULTITASK_STRATEGIES:
|
||||
raise UnsupportedRunFeatureError(
|
||||
f"multitask_strategy '{strategy}' is not supported in phase1. "
|
||||
f"Supported: reject, interrupt"
|
||||
)
|
||||
|
||||
# Check for rollback action
|
||||
command = self._as_json_mapping(body.get("command")) or {}
|
||||
if command.get("action") in self.UNSUPPORTED_ACTIONS:
|
||||
raise UnsupportedRunFeatureError(
|
||||
f"action '{command.get('action')}' is not supported in phase1"
|
||||
)
|
||||
|
||||
# Check for after_seconds
|
||||
if body.get("after_seconds") is not None:
|
||||
raise UnsupportedRunFeatureError("after_seconds is not supported in phase1")
|
||||
|
||||
def _build_scope(self, request: AdaptedRunRequest) -> RunScope:
|
||||
"""Build RunScope from request."""
|
||||
if request.is_stateless:
|
||||
# Stateless: generate temporary thread
|
||||
return RunScope(
|
||||
kind="stateless",
|
||||
thread_id=str(uuid.uuid4()),
|
||||
temporary=True,
|
||||
)
|
||||
else:
|
||||
assert request.thread_id is not None
|
||||
return RunScope(
|
||||
kind="stateful",
|
||||
thread_id=request.thread_id,
|
||||
temporary=False,
|
||||
)
|
||||
|
||||
def _normalize_stream_modes(self, stream_mode: JSONValue | None) -> list[str]:
|
||||
"""Normalize stream_mode to list, convert messages-tuple to messages."""
|
||||
if stream_mode is None:
|
||||
return self.DEFAULT_STREAM_MODES.copy()
|
||||
|
||||
if isinstance(stream_mode, str):
|
||||
modes = [stream_mode]
|
||||
elif isinstance(stream_mode, list):
|
||||
modes = [mode for mode in stream_mode if isinstance(mode, str)]
|
||||
else:
|
||||
return self.DEFAULT_STREAM_MODES.copy()
|
||||
|
||||
return ["messages" if m == "messages-tuple" else m for m in modes]
|
||||
|
||||
def _build_checkpoint_request(self, body: JSONMapping) -> CheckpointRequest | None:
|
||||
"""Build CheckpointRequest if checkpoint data is provided."""
|
||||
checkpoint_id = body.get("checkpoint_id")
|
||||
checkpoint = self._as_json_mapping(body.get("checkpoint"))
|
||||
|
||||
if not isinstance(checkpoint_id, str) and checkpoint is None:
|
||||
return None
|
||||
|
||||
return CheckpointRequest(
|
||||
checkpoint_id=checkpoint_id if isinstance(checkpoint_id, str) else None,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _normalize_input(self, raw_input: JSONMapping | None) -> GraphInput | None:
|
||||
"""Convert HTTP-friendly message dicts into LangChain message objects."""
|
||||
if raw_input is None:
|
||||
return None
|
||||
|
||||
messages = raw_input.get("messages")
|
||||
if not messages or not isinstance(messages, list):
|
||||
return raw_input
|
||||
|
||||
converted: list[object] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "user"))
|
||||
content = msg.get("content", "")
|
||||
if role in ("user", "human"):
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
converted.append(msg)
|
||||
return {**raw_input, "messages": converted}
|
||||
|
||||
def _build_runnable_config(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
request_config: JSONMapping | None,
|
||||
metadata: JSONMapping | None,
|
||||
assistant_id: str | None,
|
||||
context: JSONMapping | None,
|
||||
) -> RunnableConfigDict:
|
||||
"""Build RunnableConfig from request payload and app-side rules."""
|
||||
config: RunnableConfigDict = {"recursion_limit": 100}
|
||||
|
||||
if request_config:
|
||||
if "context" in request_config:
|
||||
config["context"] = request_config["context"]
|
||||
else:
|
||||
configurable = {"thread_id": thread_id}
|
||||
raw_configurable = request_config.get("configurable")
|
||||
if isinstance(raw_configurable, dict):
|
||||
configurable.update(raw_configurable)
|
||||
config["configurable"] = configurable
|
||||
|
||||
for key, value in request_config.items():
|
||||
if key not in ("configurable", "context"):
|
||||
config[key] = value
|
||||
else:
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
|
||||
configurable = config.get("configurable")
|
||||
if (
|
||||
assistant_id
|
||||
and assistant_id != self.DEFAULT_ASSISTANT_ID
|
||||
and isinstance(configurable, dict)
|
||||
and "agent_name" not in configurable
|
||||
):
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(
|
||||
f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization."
|
||||
)
|
||||
configurable["agent_name"] = normalized
|
||||
|
||||
if metadata:
|
||||
existing_metadata = config.get("metadata")
|
||||
if isinstance(existing_metadata, dict):
|
||||
existing_metadata.update(metadata)
|
||||
else:
|
||||
config["metadata"] = dict(metadata)
|
||||
|
||||
if context and isinstance(configurable, dict):
|
||||
for key in self.CONTEXT_CONFIGURABLE_KEYS:
|
||||
if key in context:
|
||||
configurable.setdefault(key, context[key])
|
||||
|
||||
return config
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Compatibility wrapper for the app-owned storage observer."""
|
||||
|
||||
from app.infra.storage.runs import StorageRunObserver
|
||||
|
||||
__all__ = ["StorageRunObserver"]
|
||||
@@ -1,11 +0,0 @@
|
||||
"""App-owned runs store adapters."""
|
||||
|
||||
from .create_store import AppRunCreateStore
|
||||
from .delete_store import AppRunDeleteStore
|
||||
from .query_store import AppRunQueryStore
|
||||
|
||||
__all__ = [
|
||||
"AppRunCreateStore",
|
||||
"AppRunDeleteStore",
|
||||
"AppRunQueryStore",
|
||||
]
|
||||
@@ -1,38 +0,0 @@
|
||||
"""App-owned durable run creation adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.runtime.runs.store import RunCreateStore
|
||||
from deerflow.runtime.runs.types import RunRecord
|
||||
|
||||
from app.infra.storage import ThreadMetaStorage
|
||||
from app.infra.storage.runs import RunWriteRepository
|
||||
|
||||
|
||||
class AppRunCreateStore(RunCreateStore):
|
||||
"""Write the initial durable row for a newly created run."""
|
||||
|
||||
def __init__(self, repo: RunWriteRepository, thread_meta_storage: ThreadMetaStorage | None = None) -> None:
|
||||
self._repo = repo
|
||||
self._thread_meta_storage = thread_meta_storage
|
||||
|
||||
async def create_run(self, record: RunRecord) -> None:
|
||||
await self._repo.create(
|
||||
run_id=record.run_id,
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
status=str(record.status),
|
||||
metadata=record.metadata,
|
||||
follow_up_to_run_id=record.follow_up_to_run_id,
|
||||
created_at=record.created_at,
|
||||
)
|
||||
if self._thread_meta_storage is not None and record.assistant_id:
|
||||
thread = await self._thread_meta_storage.ensure_thread(
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
)
|
||||
if thread.assistant_id != record.assistant_id:
|
||||
await self._thread_meta_storage.sync_thread_assistant_id(
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
"""App-owned durable run deletion adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.runtime.runs.store import RunDeleteStore
|
||||
|
||||
from app.infra.storage.runs import RunDeleteRepository
|
||||
|
||||
|
||||
class AppRunDeleteStore(RunDeleteStore):
|
||||
"""Delete durable run rows via the app storage adapter."""
|
||||
|
||||
def __init__(self, repo: RunDeleteRepository) -> None:
|
||||
self._repo = repo
|
||||
|
||||
async def delete_run(self, run_id: str) -> bool:
|
||||
return await self._repo.delete(run_id)
|
||||
@@ -1,47 +0,0 @@
|
||||
"""App-owned durable run query adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.runtime.runs.store import RunQueryStore
|
||||
from deerflow.runtime.runs.types import RunRecord, RunStatus
|
||||
|
||||
from app.infra.storage.runs import RunReadRepository, RunRow
|
||||
|
||||
|
||||
class AppRunQueryStore(RunQueryStore):
|
||||
"""Map app-side durable run rows into harness RunRecord DTOs."""
|
||||
|
||||
def __init__(self, repo: RunReadRepository) -> None:
|
||||
self._repo = repo
|
||||
|
||||
async def get_run(self, run_id: str) -> RunRecord | None:
|
||||
row = await self._repo.get(run_id)
|
||||
if row is None:
|
||||
return None
|
||||
return self._to_run_record(row)
|
||||
|
||||
async def list_runs(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
) -> list[RunRecord]:
|
||||
rows = await self._repo.list_by_thread(thread_id, limit=limit)
|
||||
return [self._to_run_record(row) for row in rows]
|
||||
|
||||
def _to_run_record(self, row: RunRow) -> RunRecord:
|
||||
return RunRecord(
|
||||
run_id=row["run_id"],
|
||||
thread_id=row["thread_id"],
|
||||
assistant_id=row.get("assistant_id"),
|
||||
status=RunStatus(row.get("status", "pending")),
|
||||
temporary=False,
|
||||
multitask_strategy=row.get("multitask_strategy", "reject"),
|
||||
metadata=row.get("metadata", {}),
|
||||
follow_up_to_run_id=row.get("follow_up_to_run_id"),
|
||||
created_at=row.get("created_at", ""),
|
||||
updated_at=row.get("updated_at", ""),
|
||||
started_at=row.get("started_at"),
|
||||
ended_at=row.get("ended_at"),
|
||||
error=row.get("error"),
|
||||
)
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Shared utility helpers for the Gateway layer."""
|
||||
|
||||
|
||||
def sanitize_log_param(value: str) -> str:
|
||||
"""Strip control characters to prevent log injection."""
|
||||
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
|
||||
@@ -1 +0,0 @@
|
||||
"""Application-owned infrastructure adapters and wiring."""
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Run event store backends owned by app infrastructure."""
|
||||
|
||||
from .factory import build_run_event_store
|
||||
from .jsonl_store import JsonlRunEventStore
|
||||
|
||||
__all__ = ["JsonlRunEventStore", "build_run_event_store"]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Factory for app-owned run event store backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.infra.storage import AppRunEventStore
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
from .jsonl_store import JsonlRunEventStore
|
||||
|
||||
|
||||
def build_run_event_store(session_factory: async_sessionmaker[AsyncSession]) -> AppRunEventStore | JsonlRunEventStore:
|
||||
"""Build the run event store selected by app configuration."""
|
||||
|
||||
config = get_app_config().run_events
|
||||
if config.backend == "db":
|
||||
return AppRunEventStore(session_factory)
|
||||
if config.backend == "jsonl":
|
||||
return JsonlRunEventStore(
|
||||
base_dir=Path(config.jsonl_base_dir),
|
||||
)
|
||||
raise ValueError(f"Unsupported run event backend: {config.backend}")
|
||||
@@ -1,210 +0,0 @@
|
||||
"""JSONL run event store backend owned by app infrastructure."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from collections.abc import Iterable
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class JsonlRunEventStore:
|
||||
"""Append-only JSONL implementation of the runs RunEventStore protocol."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: Path | str = ".deer-flow/run-events",
|
||||
) -> None:
|
||||
self._base_dir = Path(base_dir)
|
||||
self._locks: dict[str, asyncio.Lock] = {}
|
||||
self._locks_guard = asyncio.Lock()
|
||||
|
||||
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not events:
|
||||
return []
|
||||
|
||||
grouped: dict[str, list[dict[str, Any]]] = {}
|
||||
for event in events:
|
||||
grouped.setdefault(str(event["thread_id"]), []).append(event)
|
||||
|
||||
records_by_thread: dict[str, list[dict[str, Any]]] = {}
|
||||
for thread_id, thread_events in grouped.items():
|
||||
async with await self._thread_lock(thread_id):
|
||||
records_by_thread[thread_id] = self._append_thread_events(thread_id, thread_events)
|
||||
|
||||
indexes = {thread_id: 0 for thread_id in records_by_thread}
|
||||
ordered: list[dict[str, Any]] = []
|
||||
for event in events:
|
||||
thread_id = str(event["thread_id"])
|
||||
index = indexes[thread_id]
|
||||
ordered.append(records_by_thread[thread_id][index])
|
||||
indexes[thread_id] = index + 1
|
||||
return ordered
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
events = [event for event in await self._read_thread_events(thread_id) if event.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||
return events[-limit:]
|
||||
if after_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||
return events[:limit]
|
||||
return events[-limit:]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[dict[str, Any]]:
|
||||
event_type_set = set(event_types or [])
|
||||
events = [
|
||||
event
|
||||
for event in await self._read_thread_events(thread_id)
|
||||
if event.get("run_id") == run_id and (not event_type_set or event.get("event_type") in event_type_set)
|
||||
]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
events = [
|
||||
event
|
||||
for event in await self._read_thread_events(thread_id)
|
||||
if event.get("run_id") == run_id and event.get("category") == "message"
|
||||
]
|
||||
if before_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||
return events[-limit:]
|
||||
if after_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||
return events[:limit]
|
||||
return events[-limit:]
|
||||
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
return len(await self.list_messages(thread_id, limit=10**9))
|
||||
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
async with await self._thread_lock(thread_id):
|
||||
count = len(self._read_thread_events_sync(thread_id))
|
||||
shutil.rmtree(self._thread_dir(thread_id), ignore_errors=True)
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
async with await self._thread_lock(thread_id):
|
||||
events = self._read_thread_events_sync(thread_id)
|
||||
kept = [event for event in events if event.get("run_id") != run_id]
|
||||
deleted = len(events) - len(kept)
|
||||
if deleted:
|
||||
self._write_thread_events(thread_id, kept)
|
||||
return deleted
|
||||
|
||||
async def _thread_lock(self, thread_id: str) -> asyncio.Lock:
|
||||
async with self._locks_guard:
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._locks[thread_id] = lock
|
||||
return lock
|
||||
|
||||
def _append_thread_events(self, thread_id: str, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
seq = self._read_seq(thread_id)
|
||||
records: list[dict[str, Any]] = []
|
||||
with self._events_path(thread_id).open("a", encoding="utf-8") as file:
|
||||
for event in events:
|
||||
seq += 1
|
||||
record = self._normalize_event(event, seq=seq)
|
||||
file.write(json.dumps(record, ensure_ascii=False, default=str))
|
||||
file.write("\n")
|
||||
records.append(record)
|
||||
self._write_seq(thread_id, seq)
|
||||
return records
|
||||
|
||||
def _normalize_event(self, event: dict[str, Any], *, seq: int) -> dict[str, Any]:
|
||||
created_at = event.get("created_at")
|
||||
if isinstance(created_at, datetime):
|
||||
created_at_value = created_at.isoformat()
|
||||
elif created_at:
|
||||
created_at_value = str(created_at)
|
||||
else:
|
||||
created_at_value = datetime.now(UTC).isoformat()
|
||||
|
||||
return {
|
||||
"thread_id": str(event["thread_id"]),
|
||||
"run_id": str(event["run_id"]),
|
||||
"seq": seq,
|
||||
"event_type": str(event["event_type"]),
|
||||
"category": str(event["category"]),
|
||||
"content": event.get("content", ""),
|
||||
"metadata": dict(event.get("metadata") or {}),
|
||||
"created_at": created_at_value,
|
||||
}
|
||||
|
||||
async def _read_thread_events(self, thread_id: str) -> list[dict[str, Any]]:
|
||||
async with await self._thread_lock(thread_id):
|
||||
return self._read_thread_events_sync(thread_id)
|
||||
|
||||
def _read_thread_events_sync(self, thread_id: str) -> list[dict[str, Any]]:
|
||||
path = self._events_path(thread_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as file:
|
||||
for line in file:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
events.append(json.loads(stripped))
|
||||
return events
|
||||
|
||||
def _write_thread_events(self, thread_id: str, events: Iterable[dict[str, Any]]) -> None:
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = self._events_path(thread_id).with_suffix(".jsonl.tmp")
|
||||
with temp_path.open("w", encoding="utf-8") as file:
|
||||
for event in events:
|
||||
file.write(json.dumps(event, ensure_ascii=False, default=str))
|
||||
file.write("\n")
|
||||
temp_path.replace(self._events_path(thread_id))
|
||||
|
||||
def _read_seq(self, thread_id: str) -> int:
|
||||
path = self._seq_path(thread_id)
|
||||
if not path.exists():
|
||||
return 0
|
||||
try:
|
||||
return int(path.read_text(encoding="utf-8").strip() or "0")
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
def _write_seq(self, thread_id: str, seq: int) -> None:
|
||||
self._seq_path(thread_id).write_text(str(seq), encoding="utf-8")
|
||||
|
||||
def _thread_dir(self, thread_id: str) -> Path:
|
||||
return self._base_dir / "threads" / thread_id
|
||||
|
||||
def _events_path(self, thread_id: str) -> Path:
|
||||
return self._thread_dir(thread_id) / "events.jsonl"
|
||||
|
||||
def _seq_path(self, thread_id: str) -> Path:
|
||||
return self._thread_dir(thread_id) / "seq"
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Storage-facing adapters owned by the app layer."""
|
||||
|
||||
from .run_events import AppRunEventStore
|
||||
from .runs import FeedbackStoreAdapter, RunStoreAdapter, StorageRunObserver
|
||||
from .thread_meta import ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||
|
||||
__all__ = [
|
||||
"AppRunEventStore",
|
||||
"FeedbackStoreAdapter",
|
||||
"RunStoreAdapter",
|
||||
"StorageRunObserver",
|
||||
"ThreadMetaStorage",
|
||||
"ThreadMetaStoreAdapter",
|
||||
]
|
||||
@@ -1,166 +0,0 @@
|
||||
"""App-owned adapter from runs callbacks to storage run event repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from store.repositories import RunEvent, RunEventCreate, build_run_event_repository, build_thread_meta_repository
|
||||
|
||||
from deerflow.runtime.actor_context import get_actor_context
|
||||
|
||||
|
||||
class AppRunEventStore:
|
||||
"""Implements the harness RunEventStore protocol using storage repositories."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not events:
|
||||
return []
|
||||
|
||||
denied = {str(event["thread_id"]) for event in events if not await self._thread_visible(str(event["thread_id"]))}
|
||||
if denied:
|
||||
raise PermissionError(f"actor is not allowed to append events for thread(s): {', '.join(sorted(denied))}")
|
||||
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.append_batch([_event_create_from_dict(event) for event in events])
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return []
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.list_messages(
|
||||
thread_id,
|
||||
limit=limit,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return []
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.list_events(thread_id, run_id, event_types=event_types, limit=limit)
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return []
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.list_messages_by_run(
|
||||
thread_id,
|
||||
run_id,
|
||||
limit=limit,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return 0
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
return await repo.count_messages(thread_id)
|
||||
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return 0
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
repo = build_run_event_repository(session)
|
||||
count = await repo.delete_by_thread(thread_id)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return 0
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
repo = build_run_event_repository(session)
|
||||
count = await repo.delete_by_run(thread_id, run_id)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
return count
|
||||
|
||||
async def _thread_visible(self, thread_id: str) -> bool:
|
||||
actor = get_actor_context()
|
||||
if actor is None or actor.user_id is None:
|
||||
return True
|
||||
|
||||
async with self._session_factory() as session:
|
||||
thread_repo = build_thread_meta_repository(session)
|
||||
thread = await thread_repo.get_thread_meta(thread_id)
|
||||
|
||||
if thread is None:
|
||||
return True
|
||||
return thread.user_id is None or thread.user_id == actor.user_id
|
||||
|
||||
|
||||
def _event_create_from_dict(event: dict[str, Any]) -> RunEventCreate:
|
||||
created_at = event.get("created_at")
|
||||
return RunEventCreate(
|
||||
thread_id=str(event["thread_id"]),
|
||||
run_id=str(event["run_id"]),
|
||||
event_type=str(event["event_type"]),
|
||||
category=str(event["category"]),
|
||||
content=event.get("content", ""),
|
||||
metadata=dict(event.get("metadata") or {}),
|
||||
created_at=datetime.fromisoformat(created_at) if isinstance(created_at, str) else created_at,
|
||||
)
|
||||
|
||||
|
||||
def _event_to_dict(event: RunEvent) -> dict[str, Any]:
|
||||
return {
|
||||
"thread_id": event.thread_id,
|
||||
"run_id": event.run_id,
|
||||
"event_type": event.event_type,
|
||||
"category": event.category,
|
||||
"content": event.content,
|
||||
"metadata": event.metadata,
|
||||
"seq": event.seq,
|
||||
"created_at": event.created_at.isoformat(),
|
||||
}
|
||||
@@ -1,515 +0,0 @@
|
||||
"""Run lifecycle persistence adapters owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, TypedDict, Unpack, cast
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from store.repositories import FeedbackCreate, Run, RunCreate, build_feedback_repository, build_run_repository
|
||||
|
||||
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
from deerflow.runtime.runs.observer import LifecycleEventType, RunLifecycleEvent, RunObserver
|
||||
from deerflow.runtime.stream_bridge import JSONValue
|
||||
|
||||
from .thread_meta import ThreadMetaStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunCreateFields(TypedDict, total=False):
|
||||
status: str
|
||||
created_at: str
|
||||
started_at: str
|
||||
ended_at: str
|
||||
assistant_id: str | None
|
||||
user_id: str | None
|
||||
follow_up_to_run_id: str | None
|
||||
metadata: dict[str, JSONValue]
|
||||
kwargs: dict[str, JSONValue]
|
||||
|
||||
|
||||
class RunStatusUpdateFields(TypedDict, total=False):
|
||||
started_at: str
|
||||
ended_at: str
|
||||
metadata: dict[str, JSONValue]
|
||||
|
||||
|
||||
class RunCompletionFields(TypedDict, total=False):
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_tokens: int
|
||||
llm_call_count: int
|
||||
lead_agent_tokens: int
|
||||
subagent_tokens: int
|
||||
middleware_tokens: int
|
||||
message_count: int
|
||||
last_ai_message: str | None
|
||||
first_human_message: str | None
|
||||
error: str | None
|
||||
|
||||
|
||||
class RunRow(TypedDict, total=False):
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None
|
||||
status: str
|
||||
multitask_strategy: str
|
||||
follow_up_to_run_id: str | None
|
||||
metadata: dict[str, JSONValue]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
started_at: str | None
|
||||
ended_at: str | None
|
||||
error: str | None
|
||||
|
||||
|
||||
class RunReadRepository(Protocol):
|
||||
"""Protocol for durable run queries."""
|
||||
|
||||
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None: ...
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | object = AUTO,
|
||||
) -> list[RunRow]: ...
|
||||
|
||||
|
||||
class RunWriteRepository(Protocol):
|
||||
"""Protocol for durable run writes."""
|
||||
|
||||
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None: ...
|
||||
async def update_status(
|
||||
self,
|
||||
run_id: str,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunStatusUpdateFields],
|
||||
) -> None: ...
|
||||
async def set_error(self, run_id: str, error: str) -> None: ...
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunCompletionFields],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class RunDeleteRepository(Protocol):
|
||||
"""Protocol for durable run deletion."""
|
||||
|
||||
async def delete(self, run_id: str) -> bool: ...
|
||||
|
||||
|
||||
class _RepositoryContext:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
build_repo: Callable[[AsyncSession], object],
|
||||
*,
|
||||
commit: bool,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._build_repo = build_repo
|
||||
self._commit = commit
|
||||
self._session: AsyncSession | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self._session_factory()
|
||||
return self._build_repo(self._session)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if self._commit:
|
||||
if exc_type is None:
|
||||
await self._session.commit()
|
||||
else:
|
||||
await self._session.rollback()
|
||||
finally:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
def _run_to_row(row: Run) -> RunRow:
|
||||
return {
|
||||
"run_id": row.run_id,
|
||||
"thread_id": row.thread_id,
|
||||
"assistant_id": row.assistant_id,
|
||||
"user_id": row.user_id,
|
||||
"status": row.status,
|
||||
"model_name": row.model_name,
|
||||
"multitask_strategy": row.multitask_strategy,
|
||||
"follow_up_to_run_id": row.follow_up_to_run_id,
|
||||
"metadata": cast(dict[str, JSONValue], row.metadata),
|
||||
"kwargs": cast(dict[str, JSONValue], row.kwargs),
|
||||
"created_at": row.created_time.isoformat(),
|
||||
"updated_at": row.updated_time.isoformat() if row.updated_time else "",
|
||||
"total_input_tokens": row.total_input_tokens,
|
||||
"total_output_tokens": row.total_output_tokens,
|
||||
"total_tokens": row.total_tokens,
|
||||
"llm_call_count": row.llm_call_count,
|
||||
"lead_agent_tokens": row.lead_agent_tokens,
|
||||
"subagent_tokens": row.subagent_tokens,
|
||||
"middleware_tokens": row.middleware_tokens,
|
||||
"message_count": row.message_count,
|
||||
"first_human_message": row.first_human_message,
|
||||
"last_ai_message": row.last_ai_message,
|
||||
"error": row.error,
|
||||
}
|
||||
|
||||
|
||||
class FeedbackStoreAdapter:
|
||||
"""Expose feedback route semantics on top of storage package repositories."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
effective_user_id = user_id if user_id is not None else owner_id
|
||||
async with self._transaction() as repo:
|
||||
row = await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=rating,
|
||||
user_id=effective_user_id,
|
||||
message_id=message_id,
|
||||
comment=comment,
|
||||
)
|
||||
)
|
||||
return _feedback_to_dict(row)
|
||||
|
||||
async def get(self, feedback_id: str) -> dict[str, object] | None:
|
||||
async with self._read() as repo:
|
||||
row = await repo.get_feedback(feedback_id)
|
||||
return _feedback_to_dict(row) if row is not None else None
|
||||
|
||||
async def list_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
async with self._read() as repo:
|
||||
rows = await repo.list_feedback_by_run(run_id)
|
||||
filtered = [row for row in rows if row.thread_id == thread_id]
|
||||
if user_id is not None:
|
||||
filtered = [row for row in filtered if row.user_id == user_id]
|
||||
return [_feedback_to_dict(row) for row in filtered][:limit]
|
||||
|
||||
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict[str, object]]:
|
||||
async with self._read() as repo:
|
||||
rows = await repo.list_feedback_by_thread(thread_id)
|
||||
return [_feedback_to_dict(row) for row in rows][:limit]
|
||||
|
||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict[str, object]:
|
||||
rows = await self.list_by_run(thread_id, run_id)
|
||||
positive = sum(1 for row in rows if row["rating"] == 1)
|
||||
negative = sum(1 for row in rows if row["rating"] == -1)
|
||||
return {"run_id": run_id, "total": len(rows), "positive": positive, "negative": negative}
|
||||
|
||||
async def delete(self, feedback_id: str) -> bool:
|
||||
async with self._transaction() as repo:
|
||||
return await repo.delete_feedback(feedback_id)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str,
|
||||
comment: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
async with self._transaction() as repo:
|
||||
rows = await repo.list_feedback_by_run(run_id)
|
||||
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
|
||||
feedback_id = existing.feedback_id if existing is not None else str(uuid.uuid4())
|
||||
if existing is not None:
|
||||
await repo.delete_feedback(existing.feedback_id)
|
||||
row = await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id=feedback_id,
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=rating,
|
||||
user_id=user_id,
|
||||
comment=comment,
|
||||
)
|
||||
)
|
||||
return _feedback_to_dict(row)
|
||||
|
||||
async def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str) -> bool:
|
||||
async with self._transaction() as repo:
|
||||
rows = await repo.list_feedback_by_run(run_id)
|
||||
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
|
||||
if existing is None:
|
||||
return False
|
||||
return await repo.delete_feedback(existing.feedback_id)
|
||||
|
||||
async def list_by_thread_grouped(self, thread_id: str, *, user_id: str) -> dict[str, dict[str, object]]:
|
||||
rows = await self.list_by_thread(thread_id)
|
||||
return {
|
||||
row["run_id"]: row
|
||||
for row in rows
|
||||
if row["user_id"] == user_id
|
||||
}
|
||||
|
||||
def _read(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=False)
|
||||
|
||||
def _transaction(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=True)
|
||||
|
||||
|
||||
def _feedback_to_dict(row) -> dict[str, object]:
|
||||
return {
|
||||
"feedback_id": row.feedback_id,
|
||||
"run_id": row.run_id,
|
||||
"thread_id": row.thread_id,
|
||||
"user_id": row.user_id,
|
||||
"owner_id": row.user_id,
|
||||
"message_id": row.message_id,
|
||||
"rating": row.rating,
|
||||
"comment": row.comment,
|
||||
"created_at": row.created_time.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class RunStoreAdapter:
|
||||
"""Expose runs facade storage semantics on top of storage package repositories."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
|
||||
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.get")
|
||||
async with self._read() as repo:
|
||||
row = await repo.get_run(run_id)
|
||||
if row is None:
|
||||
return None
|
||||
if effective_user_id is not None and row.user_id != effective_user_id:
|
||||
return None
|
||||
return _run_to_row(row)
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | object = AUTO,
|
||||
) -> list[RunRow]:
|
||||
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.list_by_thread")
|
||||
async with self._read() as repo:
|
||||
rows = await repo.list_runs_by_thread(thread_id, limit=limit, offset=0)
|
||||
if effective_user_id is not None:
|
||||
rows = [row for row in rows if row.user_id == effective_user_id]
|
||||
return [_run_to_row(row) for row in rows]
|
||||
|
||||
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None:
|
||||
metadata = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("metadata") or {}))
|
||||
run_kwargs = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("kwargs") or {}))
|
||||
effective_user_id = resolve_user_id(kwargs.get("user_id", AUTO), method_name="RunStoreAdapter.create")
|
||||
async with self._transaction() as repo:
|
||||
await repo.create_run(
|
||||
RunCreate(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=kwargs.get("assistant_id"),
|
||||
user_id=effective_user_id,
|
||||
status=kwargs.get("status", "pending"),
|
||||
metadata=dict(metadata),
|
||||
kwargs=dict(run_kwargs),
|
||||
follow_up_to_run_id=kwargs.get("follow_up_to_run_id"),
|
||||
)
|
||||
)
|
||||
|
||||
async def delete(self, run_id: str, *, user_id: str | None | object = AUTO) -> bool:
|
||||
async with self._transaction() as repo:
|
||||
existing = await repo.get_run(run_id)
|
||||
if existing is None:
|
||||
return False
|
||||
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.delete")
|
||||
if effective_user_id is not None and existing.user_id != effective_user_id:
|
||||
return False
|
||||
await repo.delete_run(run_id)
|
||||
return True
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
run_id: str,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunStatusUpdateFields],
|
||||
) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_run_status(run_id, status)
|
||||
|
||||
async def set_error(self, run_id: str, error: str) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_run_status(run_id, "error", error=error)
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunCompletionFields],
|
||||
) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_run_completion(
|
||||
run_id,
|
||||
status=status,
|
||||
total_input_tokens=kwargs.get("total_input_tokens", 0),
|
||||
total_output_tokens=kwargs.get("total_output_tokens", 0),
|
||||
total_tokens=kwargs.get("total_tokens", 0),
|
||||
llm_call_count=kwargs.get("llm_call_count", 0),
|
||||
lead_agent_tokens=kwargs.get("lead_agent_tokens", 0),
|
||||
subagent_tokens=kwargs.get("subagent_tokens", 0),
|
||||
middleware_tokens=kwargs.get("middleware_tokens", 0),
|
||||
message_count=kwargs.get("message_count", 0),
|
||||
last_ai_message=kwargs.get("last_ai_message"),
|
||||
first_human_message=kwargs.get("first_human_message"),
|
||||
error=kwargs.get("error"),
|
||||
)
|
||||
|
||||
def _read(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_run_repository, commit=False)
|
||||
|
||||
def _transaction(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_run_repository, commit=True)
|
||||
|
||||
|
||||
class StorageRunObserver(RunObserver):
|
||||
"""Persist run lifecycle state into app-owned repositories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_write_repo: RunWriteRepository | None = None,
|
||||
thread_meta_storage: ThreadMetaStorage | None = None,
|
||||
) -> None:
|
||||
self._run_write_repo = run_write_repo
|
||||
self._thread_meta_storage = thread_meta_storage
|
||||
|
||||
async def on_event(self, event: RunLifecycleEvent) -> None:
|
||||
try:
|
||||
await self._dispatch(event)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"StorageRunObserver failed to persist event %s for run %s",
|
||||
event.event_type,
|
||||
event.run_id,
|
||||
)
|
||||
|
||||
async def _dispatch(self, event: RunLifecycleEvent) -> None:
|
||||
handlers = {
|
||||
LifecycleEventType.RUN_STARTED: self._handle_run_started,
|
||||
LifecycleEventType.RUN_COMPLETED: self._handle_run_completed,
|
||||
LifecycleEventType.RUN_FAILED: self._handle_run_failed,
|
||||
LifecycleEventType.RUN_CANCELLED: self._handle_run_cancelled,
|
||||
LifecycleEventType.THREAD_STATUS_UPDATED: self._handle_thread_status,
|
||||
}
|
||||
|
||||
handler = handlers.get(event.event_type)
|
||||
if handler:
|
||||
await handler(event)
|
||||
|
||||
async def _handle_run_started(self, event: RunLifecycleEvent) -> None:
|
||||
if self._run_write_repo:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="running",
|
||||
started_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
|
||||
async def _handle_run_completed(self, event: RunLifecycleEvent) -> None:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
if self._run_write_repo:
|
||||
completion_data = payload.get("completion_data")
|
||||
if isinstance(completion_data, dict):
|
||||
await self._run_write_repo.update_run_completion(
|
||||
run_id=event.run_id,
|
||||
status="success",
|
||||
**cast(RunCompletionFields, completion_data),
|
||||
)
|
||||
else:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="success",
|
||||
ended_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
|
||||
if self._thread_meta_storage and "title" in payload:
|
||||
await self._thread_meta_storage.sync_thread_title(
|
||||
thread_id=event.thread_id,
|
||||
title=payload["title"],
|
||||
)
|
||||
|
||||
async def _handle_run_failed(self, event: RunLifecycleEvent) -> None:
|
||||
if self._run_write_repo:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
error = payload.get("error", "Unknown error")
|
||||
completion_data = payload.get("completion_data")
|
||||
if isinstance(completion_data, dict):
|
||||
await self._run_write_repo.update_run_completion(
|
||||
run_id=event.run_id,
|
||||
status="error",
|
||||
error=str(error),
|
||||
**cast(RunCompletionFields, completion_data),
|
||||
)
|
||||
else:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="error",
|
||||
ended_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
await self._run_write_repo.set_error(run_id=event.run_id, error=str(error))
|
||||
|
||||
async def _handle_run_cancelled(self, event: RunLifecycleEvent) -> None:
|
||||
if self._run_write_repo:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
completion_data = payload.get("completion_data")
|
||||
if isinstance(completion_data, dict):
|
||||
await self._run_write_repo.update_run_completion(
|
||||
run_id=event.run_id,
|
||||
status="interrupted",
|
||||
**cast(RunCompletionFields, completion_data),
|
||||
)
|
||||
else:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="interrupted",
|
||||
ended_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
|
||||
async def _handle_thread_status(self, event: RunLifecycleEvent) -> None:
|
||||
if self._thread_meta_storage:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
status = payload.get("status", "idle")
|
||||
await self._thread_meta_storage.sync_thread_status(
|
||||
thread_id=event.thread_id,
|
||||
status=status,
|
||||
)
|
||||
@@ -1,208 +0,0 @@
|
||||
"""Thread metadata storage adapter owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from store.repositories import build_thread_meta_repository
|
||||
from store.repositories.contracts import (
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||
|
||||
|
||||
class ThreadMetaStoreAdapter:
|
||||
"""Use storage package thread repositories with per-call sessions."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||
async with self._transaction() as repo:
|
||||
return await repo.create_thread_meta(data)
|
||||
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||
async with self._read() as repo:
|
||||
return await repo.get_thread_meta(thread_id)
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_thread_meta(
|
||||
thread_id,
|
||||
assistant_id=assistant_id,
|
||||
display_name=display_name,
|
||||
status=status,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.delete_thread(thread_id)
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
async with self._read() as repo:
|
||||
return await repo.search_threads(
|
||||
metadata=metadata,
|
||||
status=status,
|
||||
user_id=user_id,
|
||||
assistant_id=assistant_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
def _read(self):
|
||||
return _ThreadMetaRepositoryContext(self._session_factory, commit=False)
|
||||
|
||||
def _transaction(self):
|
||||
return _ThreadMetaRepositoryContext(self._session_factory, commit=True)
|
||||
|
||||
|
||||
class _ThreadMetaRepositoryContext:
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, commit: bool) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._commit = commit
|
||||
self._session: AsyncSession | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self._session_factory()
|
||||
return build_thread_meta_repository(self._session)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if self._commit:
|
||||
if exc_type is None:
|
||||
await self._session.commit()
|
||||
else:
|
||||
await self._session.rollback()
|
||||
finally:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
class ThreadMetaStorage:
|
||||
"""App-facing adapter around the storage thread metadata contract."""
|
||||
|
||||
def __init__(self, repo: ThreadMetaRepositoryProtocol) -> None:
|
||||
self._repo = repo
|
||||
|
||||
async def get_thread(self, thread_id: str, *, user_id: str | None | object = AUTO) -> ThreadMeta | None:
|
||||
thread = await self._repo.get_thread_meta(thread_id)
|
||||
if thread is None:
|
||||
return None
|
||||
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.get_thread")
|
||||
if effective_user_id is not None and thread.user_id != effective_user_id:
|
||||
return None
|
||||
return thread
|
||||
|
||||
async def ensure_thread(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
user_id: str | None | object = AUTO,
|
||||
) -> ThreadMeta:
|
||||
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.ensure_thread")
|
||||
existing = await self.get_thread(thread_id, user_id=effective_user_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
return await self._repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=effective_user_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
)
|
||||
|
||||
async def ensure_thread_running(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ThreadMeta | None:
|
||||
existing = await self._repo.get_thread_meta(thread_id)
|
||||
if existing is None:
|
||||
return await self._repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status="running",
|
||||
metadata=metadata or {},
|
||||
)
|
||||
)
|
||||
|
||||
await self._repo.update_thread_meta(thread_id, status="running")
|
||||
return await self._repo.get_thread_meta(thread_id)
|
||||
|
||||
async def sync_thread_title(self, *, thread_id: str, title: str) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, display_name=title)
|
||||
|
||||
async def sync_thread_assistant_id(self, *, thread_id: str, assistant_id: str) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, assistant_id=assistant_id)
|
||||
|
||||
async def sync_thread_status(self, *, thread_id: str, status: str) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, status=status)
|
||||
|
||||
async def sync_thread_metadata(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, metadata=metadata)
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
await self._repo.delete_thread(thread_id)
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None | object = AUTO,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
normalized_status = status.strip() if status is not None else None
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.search_threads")
|
||||
normalized_user_id = resolved_user_id.strip() if resolved_user_id is not None else None
|
||||
normalized_assistant_id = (
|
||||
assistant_id.strip() if assistant_id is not None else None
|
||||
)
|
||||
|
||||
return await self._repo.search_threads(
|
||||
metadata=metadata,
|
||||
status=normalized_status or None,
|
||||
user_id=normalized_user_id or None,
|
||||
assistant_id=normalized_assistant_id or None,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ThreadMetaStorage", "ThreadMetaStoreAdapter"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""App-owned stream bridge adapters and factory."""
|
||||
|
||||
from .factory import build_stream_bridge
|
||||
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||
|
||||
__all__ = ["MemoryStreamBridge", "RedisStreamBridge", "build_stream_bridge"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Concrete stream bridge adapters owned by the app layer."""
|
||||
|
||||
from .memory import MemoryStreamBridge
|
||||
from .redis import RedisStreamBridge
|
||||
|
||||
__all__ = ["MemoryStreamBridge", "RedisStreamBridge"]
|
||||
@@ -1,450 +0,0 @@
|
||||
"""In-memory stream bridge implementation owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from deerflow.runtime.stream_bridge import (
|
||||
CANCELLED_SENTINEL,
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
TERMINAL_STATES,
|
||||
ResumeResult,
|
||||
StreamBridge,
|
||||
StreamEvent,
|
||||
StreamStatus,
|
||||
)
|
||||
from deerflow.runtime.stream_bridge.exceptions import (
|
||||
BridgeClosedError,
|
||||
StreamCapacityExceededError,
|
||||
StreamTerminatedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RunStream:
|
||||
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||
events: list[StreamEvent] = field(default_factory=list)
|
||||
id_to_offset: dict[str, int] = field(default_factory=dict)
|
||||
start_offset: int = 0
|
||||
current_bytes: int = 0
|
||||
seq: int = 0
|
||||
status: StreamStatus = StreamStatus.ACTIVE
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
last_publish_at: float | None = None
|
||||
ended_at: float | None = None
|
||||
subscriber_count: int = 0
|
||||
last_subscribe_at: float | None = None
|
||||
awaiting_input: bool = False
|
||||
awaiting_since: float | None = None
|
||||
|
||||
|
||||
class MemoryStreamBridge(StreamBridge):
|
||||
"""Per-run in-memory event log implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_events_per_stream: int = 256,
|
||||
max_bytes_per_stream: int = 10 * 1024 * 1024,
|
||||
max_active_streams: int = 1000,
|
||||
stream_eviction_policy: Literal["reject", "lru"] = "lru",
|
||||
terminal_retention_ttl: float = 300.0,
|
||||
active_no_publish_timeout: float = 600.0,
|
||||
orphan_timeout: float = 60.0,
|
||||
max_stream_age: float = 86400.0,
|
||||
hitl_extended_timeout: float = 7200.0,
|
||||
cleanup_interval: float = 30.0,
|
||||
queue_maxsize: int | None = None,
|
||||
) -> None:
|
||||
if queue_maxsize is not None:
|
||||
max_events_per_stream = queue_maxsize
|
||||
|
||||
self._max_events = max_events_per_stream
|
||||
self._max_bytes = max_bytes_per_stream
|
||||
self._max_streams = max_active_streams
|
||||
self._eviction_policy = stream_eviction_policy
|
||||
self._terminal_ttl = terminal_retention_ttl
|
||||
self._active_timeout = active_no_publish_timeout
|
||||
self._orphan_timeout = orphan_timeout
|
||||
self._max_age = max_stream_age
|
||||
self._hitl_timeout = hitl_extended_timeout
|
||||
self._cleanup_interval = cleanup_interval
|
||||
self._streams: dict[str, _RunStream] = {}
|
||||
self._registry_lock = asyncio.Lock()
|
||||
self._closed = False
|
||||
self._cleanup_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(
|
||||
"MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)",
|
||||
self._max_events,
|
||||
self._max_bytes,
|
||||
self._max_streams,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._registry_lock:
|
||||
self._closed = True
|
||||
if self._cleanup_task is not None:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
for stream in self._streams.values():
|
||||
async with stream.condition:
|
||||
stream.status = StreamStatus.CLOSED
|
||||
stream.condition.notify_all()
|
||||
|
||||
self._streams.clear()
|
||||
logger.info("MemoryStreamBridge closed")
|
||||
|
||||
async def _get_or_create_stream(self, run_id: str) -> _RunStream:
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is not None:
|
||||
return stream
|
||||
|
||||
async with self._registry_lock:
|
||||
if self._closed:
|
||||
raise BridgeClosedError("Stream bridge is closed")
|
||||
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is not None:
|
||||
return stream
|
||||
|
||||
if len(self._streams) >= self._max_streams:
|
||||
if self._eviction_policy == "reject":
|
||||
raise StreamCapacityExceededError(
|
||||
f"Max {self._max_streams} active streams reached"
|
||||
)
|
||||
evicted = self._evict_oldest_terminal()
|
||||
if evicted is None:
|
||||
raise StreamCapacityExceededError("All streams active, cannot evict")
|
||||
logger.info("Evicted stream %s to make room", evicted)
|
||||
|
||||
stream = _RunStream()
|
||||
self._streams[run_id] = stream
|
||||
logger.debug("Created stream for run %s", run_id)
|
||||
return stream
|
||||
|
||||
def _evict_oldest_terminal(self) -> str | None:
|
||||
oldest_run_id: str | None = None
|
||||
oldest_ended_at: float = float("inf")
|
||||
for run_id, stream in self._streams.items():
|
||||
if stream.status in TERMINAL_STATES and stream.ended_at is not None:
|
||||
if stream.ended_at < oldest_ended_at:
|
||||
oldest_ended_at = stream.ended_at
|
||||
oldest_run_id = run_id
|
||||
if oldest_run_id is not None:
|
||||
del self._streams[oldest_run_id]
|
||||
return oldest_run_id
|
||||
return None
|
||||
|
||||
def _next_id(self, stream: _RunStream) -> str:
|
||||
stream.seq += 1
|
||||
return f"{int(time.time() * 1000)}-{stream.seq}"
|
||||
|
||||
def _estimate_size(self, event: StreamEvent) -> int:
|
||||
base = len(event.id) + len(event.event) + 100
|
||||
if event.data is None:
|
||||
return base
|
||||
if isinstance(event.data, str):
|
||||
return base + len(event.data)
|
||||
if isinstance(event.data, (dict, list)):
|
||||
try:
|
||||
return base + len(json.dumps(event.data, default=str))
|
||||
except (TypeError, ValueError):
|
||||
return base + 200
|
||||
return base + 50
|
||||
|
||||
def _evict_overflow(self, stream: _RunStream) -> None:
|
||||
while len(stream.events) > self._max_events or stream.current_bytes > self._max_bytes:
|
||||
if not stream.events:
|
||||
break
|
||||
evicted = stream.events.pop(0)
|
||||
stream.id_to_offset.pop(evicted.id, None)
|
||||
stream.current_bytes -= self._estimate_size(evicted)
|
||||
stream.start_offset += 1
|
||||
|
||||
async def publish(self, run_id: str, event: str, data: Any) -> str:
|
||||
stream = await self._get_or_create_stream(run_id)
|
||||
async with stream.condition:
|
||||
if stream.status != StreamStatus.ACTIVE:
|
||||
raise StreamTerminatedError(
|
||||
f"Cannot publish to {stream.status.value} stream"
|
||||
)
|
||||
|
||||
entry = StreamEvent(id=self._next_id(stream), event=event, data=data)
|
||||
absolute_offset = stream.start_offset + len(stream.events)
|
||||
stream.events.append(entry)
|
||||
stream.id_to_offset[entry.id] = absolute_offset
|
||||
stream.current_bytes += self._estimate_size(entry)
|
||||
stream.last_publish_at = time.monotonic()
|
||||
self._evict_overflow(stream)
|
||||
stream.condition.notify_all()
|
||||
return entry.id
|
||||
|
||||
async def publish_end(self, run_id: str) -> str:
|
||||
return await self.publish_terminal(run_id, StreamStatus.ENDED)
|
||||
|
||||
async def publish_terminal(
|
||||
self,
|
||||
run_id: str,
|
||||
kind: StreamStatus,
|
||||
data: Any = None,
|
||||
) -> str:
|
||||
if kind not in TERMINAL_STATES:
|
||||
raise ValueError(f"Invalid terminal kind: {kind}")
|
||||
|
||||
stream = await self._get_or_create_stream(run_id)
|
||||
async with stream.condition:
|
||||
if stream.status != StreamStatus.ACTIVE:
|
||||
for evt in reversed(stream.events):
|
||||
if evt.event in ("end", "cancel", "error", "dead_letter"):
|
||||
return evt.id
|
||||
return ""
|
||||
|
||||
event_name = {
|
||||
StreamStatus.ENDED: "end",
|
||||
StreamStatus.CANCELLED: "cancel",
|
||||
StreamStatus.ERRORED: "error",
|
||||
}[kind]
|
||||
entry = StreamEvent(id=self._next_id(stream), event=event_name, data=data)
|
||||
absolute_offset = stream.start_offset + len(stream.events)
|
||||
stream.events.append(entry)
|
||||
stream.id_to_offset[entry.id] = absolute_offset
|
||||
stream.current_bytes += self._estimate_size(entry)
|
||||
stream.status = kind
|
||||
stream.ended_at = time.monotonic()
|
||||
stream.awaiting_input = False
|
||||
stream.condition.notify_all()
|
||||
logger.debug("Stream %s terminal: %s", run_id, kind.value)
|
||||
return entry.id
|
||||
|
||||
async def cancel(self, run_id: str) -> None:
|
||||
await self.publish_terminal(run_id, StreamStatus.CANCELLED)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
stream = await self._get_or_create_stream(run_id)
|
||||
resume = self._resolve_resume_point(stream, last_event_id)
|
||||
next_offset = resume.next_offset
|
||||
|
||||
async with stream.condition:
|
||||
stream.subscriber_count += 1
|
||||
stream.last_subscribe_at = time.monotonic()
|
||||
|
||||
try:
|
||||
while True:
|
||||
entry_to_yield: StreamEvent | None = None
|
||||
sentinel_to_yield: StreamEvent | None = None
|
||||
should_return = False
|
||||
should_wait = False
|
||||
|
||||
async with stream.condition:
|
||||
if self._closed or stream.status == StreamStatus.CLOSED:
|
||||
sentinel_to_yield = CANCELLED_SENTINEL
|
||||
should_return = True
|
||||
elif next_offset < stream.start_offset:
|
||||
next_offset = stream.start_offset
|
||||
else:
|
||||
local_index = next_offset - stream.start_offset
|
||||
if 0 <= local_index < len(stream.events):
|
||||
entry_to_yield = stream.events[local_index]
|
||||
next_offset += 1
|
||||
if entry_to_yield.event in ("end", "cancel", "error", "dead_letter"):
|
||||
should_return = True
|
||||
elif stream.status in TERMINAL_STATES:
|
||||
sentinel_to_yield = END_SENTINEL
|
||||
should_return = True
|
||||
else:
|
||||
should_wait = True
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
stream.condition.wait(),
|
||||
timeout=heartbeat_interval,
|
||||
)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
if sentinel_to_yield is not None:
|
||||
yield sentinel_to_yield
|
||||
if should_return:
|
||||
return
|
||||
continue
|
||||
|
||||
if entry_to_yield is not None:
|
||||
yield entry_to_yield
|
||||
if should_return:
|
||||
return
|
||||
continue
|
||||
|
||||
if should_wait:
|
||||
async with stream.condition:
|
||||
local_index = next_offset - stream.start_offset
|
||||
has_events = 0 <= local_index < len(stream.events)
|
||||
is_terminal = stream.status in TERMINAL_STATES
|
||||
if not has_events and not is_terminal:
|
||||
yield HEARTBEAT_SENTINEL
|
||||
|
||||
finally:
|
||||
async with stream.condition:
|
||||
stream.subscriber_count = max(0, stream.subscriber_count - 1)
|
||||
|
||||
async def mark_awaiting_input(self, run_id: str) -> None:
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is None:
|
||||
return
|
||||
async with stream.condition:
|
||||
if stream.status == StreamStatus.ACTIVE:
|
||||
stream.awaiting_input = True
|
||||
stream.awaiting_since = time.monotonic()
|
||||
logger.debug("Stream %s marked as awaiting input", run_id)
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
await self._do_cleanup(run_id, "manual")
|
||||
|
||||
async def _do_cleanup(self, run_id: str, reason: str) -> None:
|
||||
async with self._registry_lock:
|
||||
stream = self._streams.pop(run_id, None)
|
||||
if stream is not None:
|
||||
async with stream.condition:
|
||||
stream.status = StreamStatus.CLOSED
|
||||
stream.condition.notify_all()
|
||||
logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason)
|
||||
|
||||
async def _mark_dead_letter(self, run_id: str, reason: str) -> None:
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is None:
|
||||
return
|
||||
async with stream.condition:
|
||||
if stream.status != StreamStatus.ACTIVE:
|
||||
return
|
||||
entry = StreamEvent(
|
||||
id=self._next_id(stream),
|
||||
event="dead_letter",
|
||||
data={"reason": reason, "timestamp": time.time()},
|
||||
)
|
||||
absolute_offset = stream.start_offset + len(stream.events)
|
||||
stream.events.append(entry)
|
||||
stream.id_to_offset[entry.id] = absolute_offset
|
||||
stream.current_bytes += self._estimate_size(entry)
|
||||
stream.status = StreamStatus.ERRORED
|
||||
stream.ended_at = time.monotonic()
|
||||
stream.condition.notify_all()
|
||||
logger.warning("Stream %s marked as dead letter: %s", run_id, reason)
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
now = time.monotonic()
|
||||
to_cleanup: list[tuple[str, str]] = []
|
||||
to_mark_dead: list[tuple[str, str]] = []
|
||||
|
||||
async with self._registry_lock:
|
||||
for run_id, stream in list(self._streams.items()):
|
||||
if now - stream.created_at > self._max_age:
|
||||
to_cleanup.append((run_id, "max_age_exceeded"))
|
||||
continue
|
||||
|
||||
if stream.status == StreamStatus.ACTIVE:
|
||||
timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout
|
||||
last_activity = stream.last_publish_at or stream.created_at
|
||||
if now - last_activity > timeout:
|
||||
to_mark_dead.append((run_id, "no_publish_timeout"))
|
||||
continue
|
||||
|
||||
if stream.status in TERMINAL_STATES and stream.ended_at:
|
||||
if stream.subscriber_count > 0:
|
||||
continue
|
||||
last_sub = stream.last_subscribe_at or stream.ended_at
|
||||
if now - last_sub > self._orphan_timeout:
|
||||
to_cleanup.append((run_id, "orphan"))
|
||||
continue
|
||||
if now - stream.ended_at > self._terminal_ttl:
|
||||
to_cleanup.append((run_id, "ttl_expired"))
|
||||
|
||||
for run_id, reason in to_mark_dead:
|
||||
await self._mark_dead_letter(run_id, reason)
|
||||
for run_id, reason in to_cleanup:
|
||||
await self._do_cleanup(run_id, reason)
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
active = sum(1 for s in self._streams.values() if s.status == StreamStatus.ACTIVE)
|
||||
terminal = sum(1 for s in self._streams.values() if s.status in TERMINAL_STATES)
|
||||
total_events = sum(len(s.events) for s in self._streams.values())
|
||||
total_bytes = sum(s.current_bytes for s in self._streams.values())
|
||||
total_subs = sum(s.subscriber_count for s in self._streams.values())
|
||||
return {
|
||||
"total_streams": len(self._streams),
|
||||
"active_streams": active,
|
||||
"terminal_streams": terminal,
|
||||
"total_events": total_events,
|
||||
"total_bytes": total_bytes,
|
||||
"total_subscribers": total_subs,
|
||||
"closed": self._closed,
|
||||
}
|
||||
|
||||
def _resolve_resume_point(
|
||||
self,
|
||||
stream: _RunStream,
|
||||
last_event_id: str | None,
|
||||
) -> ResumeResult:
|
||||
if last_event_id is None:
|
||||
return ResumeResult(next_offset=stream.start_offset, status="fresh")
|
||||
if last_event_id in stream.id_to_offset:
|
||||
return ResumeResult(
|
||||
next_offset=stream.id_to_offset[last_event_id] + 1,
|
||||
status="resumed",
|
||||
)
|
||||
|
||||
parts = last_event_id.split("-")
|
||||
if len(parts) != 2:
|
||||
return ResumeResult(next_offset=stream.start_offset, status="invalid")
|
||||
try:
|
||||
event_ts = int(parts[0])
|
||||
_event_seq = int(parts[1])
|
||||
except ValueError:
|
||||
return ResumeResult(next_offset=stream.start_offset, status="invalid")
|
||||
|
||||
if stream.events:
|
||||
try:
|
||||
oldest_parts = stream.events[0].id.split("-")
|
||||
oldest_ts = int(oldest_parts[0])
|
||||
if event_ts < oldest_ts:
|
||||
return ResumeResult(
|
||||
next_offset=stream.start_offset,
|
||||
status="evicted",
|
||||
gap_count=stream.start_offset,
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return ResumeResult(next_offset=stream.start_offset, status="unknown")
|
||||
|
||||
|
||||
__all__ = ["MemoryStreamBridge"]
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Redis-backed stream bridge placeholder owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.stream_bridge import StreamBridge, StreamEvent
|
||||
|
||||
|
||||
class RedisStreamBridge(StreamBridge):
|
||||
"""Reserved app-owned Redis implementation.
|
||||
|
||||
Phase 1 intentionally keeps Redis out of the harness package. The concrete
|
||||
implementation will live here once cross-process streaming is introduced.
|
||||
"""
|
||||
|
||||
def __init__(self, *, redis_url: str) -> None:
|
||||
self._redis_url = redis_url
|
||||
|
||||
async def publish(self, run_id: str, event: str, data: Any) -> str:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
|
||||
async def publish_end(self, run_id: str) -> str:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
@@ -1,50 +0,0 @@
|
||||
"""App-owned stream bridge factory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
|
||||
from deerflow.config.stream_bridge_config import get_stream_bridge_config
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_stream_bridge(config=None) -> AbstractAsyncContextManager[StreamBridge]:
|
||||
"""Build the configured app-owned stream bridge."""
|
||||
return _build_stream_bridge_impl(config)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _build_stream_bridge_impl(config=None) -> AsyncIterator[StreamBridge]:
|
||||
if config is None:
|
||||
config = get_stream_bridge_config()
|
||||
|
||||
if config is None or config.type == "memory":
|
||||
maxsize = config.queue_maxsize if config is not None else 256
|
||||
bridge = MemoryStreamBridge(queue_maxsize=maxsize)
|
||||
await bridge.start()
|
||||
logger.info("Stream bridge initialised: memory (queue_maxsize=%d)", maxsize)
|
||||
try:
|
||||
yield bridge
|
||||
finally:
|
||||
await bridge.close()
|
||||
return
|
||||
|
||||
if config.type == "redis":
|
||||
if not config.redis_url:
|
||||
raise ValueError("Redis stream bridge requires redis_url")
|
||||
bridge = RedisStreamBridge(redis_url=config.redis_url)
|
||||
await bridge.start()
|
||||
logger.info("Stream bridge initialised: redis (%s)", config.redis_url)
|
||||
try:
|
||||
yield bridge
|
||||
finally:
|
||||
await bridge.close()
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown stream bridge type: {config.type!r}")
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Entry point for running the Gateway API via `python app/main.py`.
|
||||
|
||||
Useful for IDE debugging (e.g., PyCharm / VS Code debug configurations).
|
||||
Equivalent to: PYTHONPATH=. uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||
"""
|
||||
|
||||
import uvicorn
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"app.gateway.app:app",
|
||||
host="0.0.0.0",
|
||||
port=8001,
|
||||
reload=True,
|
||||
)
|
||||
@@ -1,314 +0,0 @@
|
||||
# app.plugins Design Overview
|
||||
|
||||
This document describes the current role of `backend/app/plugins`, its plugin design contract, dependency boundaries, and how the current `auth` plugin provides services with minimal intrusion into the host application.
|
||||
|
||||
## 1. Overall Role
|
||||
|
||||
`app.plugins` is the application-side plugin boundary.
|
||||
|
||||
Its purpose is not to implement a generic plugin marketplace. Instead, it provides a clear boundary inside `app` for separable business capabilities, so that a capability can:
|
||||
|
||||
1. carry its own domain model, runtime state, and adapters inside the plugin
|
||||
2. interact with the host application only through a limited set of seams
|
||||
3. remain replaceable, removable, and extensible over time
|
||||
|
||||
The only real plugin currently implemented under this directory is [`auth`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth).
|
||||
|
||||
The current direction is not “put all logic into app”. It is:
|
||||
|
||||
1. the host application owns unified bootstrap, shared infrastructure, and top-level router assembly
|
||||
2. each plugin owns its own business contract, persistence definitions, runtime state, and outward-facing adapters
|
||||
|
||||
## 2. Plugin Design Contract
|
||||
|
||||
### 2.1 A plugin should carry its own implementation
|
||||
|
||||
The primary contract visible in the current codebase is:
|
||||
|
||||
A plugin’s own ORM, runtime, domain, and adapters should be implemented inside the plugin itself. Core business behavior should not be scattered into unrelated external modules.
|
||||
|
||||
The `auth` plugin already follows that pattern with a fairly complete internal structure:
|
||||
|
||||
1. `domain`
|
||||
- config, errors, JWT, password logic, domain models, service
|
||||
2. `storage`
|
||||
- plugin-owned ORM models, repository contracts, and repository implementations
|
||||
3. `runtime`
|
||||
- plugin-owned runtime config state
|
||||
4. `api`
|
||||
- plugin-owned HTTP router and schemas
|
||||
5. `security`
|
||||
- plugin-owned middleware, dependencies, CSRF logic, and LangGraph adapter
|
||||
6. `authorization`
|
||||
- plugin-owned permission model, policy resolution, and hooks
|
||||
7. `injection`
|
||||
- plugin-owned route-policy loading, injection, and validation
|
||||
|
||||
In other words, a plugin should be a self-contained capability module, not a bag of helpers.
|
||||
|
||||
### 2.2 The host app should provide shared infrastructure, not plugin internals
|
||||
|
||||
The current contract is not that every plugin must be fully infrastructure-independent.
|
||||
|
||||
It is:
|
||||
|
||||
1. a plugin may reuse the application’s shared `engine`, `session_factory`, FastAPI app, and router tree
|
||||
2. but the plugin must still own its table definitions, repositories, runtime config, and business/auth behavior
|
||||
|
||||
This is stated explicitly in [`auth/plugin.toml`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/plugin.toml):
|
||||
|
||||
1. `storage.mode = "shared_infrastructure"`
|
||||
2. the plugin owns its storage definitions and repositories
|
||||
3. but it reuses the application’s shared persistence infrastructure
|
||||
|
||||
So the real rule is not “never reuse infrastructure”. The real rule is “do not outsource plugin business semantics to the rest of the app”.
|
||||
|
||||
### 2.3 Dependencies should remain one-way
|
||||
|
||||
The intended dependency direction in the current design is:
|
||||
|
||||
```text
|
||||
gateway / app bootstrap
|
||||
-> plugin public adapters
|
||||
-> plugin domain / storage / runtime
|
||||
```
|
||||
|
||||
Not:
|
||||
|
||||
```text
|
||||
plugin domain
|
||||
-> depends on app business modules
|
||||
```
|
||||
|
||||
A plugin may depend on:
|
||||
|
||||
1. shared persistence infrastructure
|
||||
2. `app.state` provided by the host application
|
||||
3. generic framework capabilities such as FastAPI / Starlette
|
||||
|
||||
But its core business rules should not depend on unrelated app business modules, otherwise hot-swappability becomes unrealistic.
|
||||
|
||||
## 3. The Current auth Plugin Structure
|
||||
|
||||
The current `auth` plugin is effectively a self-contained authentication and authorization package with its own models, services, and adapters.
|
||||
|
||||
### 3.1 domain
|
||||
|
||||
[`auth/domain`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/domain) owns:
|
||||
|
||||
1. `config.py`
|
||||
- auth-related configuration definition and loading
|
||||
2. `errors.py`
|
||||
- error codes and response contracts
|
||||
3. `jwt.py`
|
||||
- token encoding and decoding
|
||||
4. `password.py`
|
||||
- password hashing and verification
|
||||
5. `models.py`
|
||||
- auth domain models
|
||||
6. `service.py`
|
||||
- `AuthService` as the core business service
|
||||
|
||||
`AuthService` depends only on the plugin’s own `DbUserRepository` plus the shared session factory. The auth business logic is not reimplemented in `gateway`.
|
||||
|
||||
### 3.2 storage
|
||||
|
||||
[`auth/storage`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage) clearly shows the “ORM is owned by the plugin” contract:
|
||||
|
||||
1. [`models.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/models.py)
|
||||
- defines the plugin-owned `users` table model
|
||||
2. [`contracts.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/contracts.py)
|
||||
- defines `User`, `UserCreate`, and `UserRepositoryProtocol`
|
||||
3. [`repositories.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/repositories.py)
|
||||
- implements `DbUserRepository`
|
||||
|
||||
The key point is:
|
||||
|
||||
1. the plugin defines its own ORM model
|
||||
2. the plugin defines its own repository protocol
|
||||
3. the plugin implements its own repository
|
||||
4. external code only needs to provide a session or session factory
|
||||
|
||||
That is the minimal shared seam the boundary should preserve.
|
||||
|
||||
### 3.3 runtime
|
||||
|
||||
[`auth/runtime/config_state.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/runtime/config_state.py) keeps plugin-owned runtime config state:
|
||||
|
||||
1. `get_auth_config()`
|
||||
2. `set_auth_config()`
|
||||
3. `reset_auth_config()`
|
||||
|
||||
This matters because runtime state is also part of the plugin boundary. If future plugins need their own caches, state holders, or feature flags, they should follow the same pattern and keep them inside the plugin.
|
||||
|
||||
### 3.4 adapters
|
||||
|
||||
The `auth` plugin exposes capability through four main adapter groups:
|
||||
|
||||
1. `api/router.py`
|
||||
- HTTP endpoints
|
||||
2. `security/*`
|
||||
- middleware, dependencies, request-user resolution, actor-context bridge
|
||||
3. `authorization/*`
|
||||
- capabilities, policy evaluators, auth hooks
|
||||
4. `injection/*`
|
||||
- route-policy registry, guard injection, startup validation
|
||||
|
||||
These adapters all follow the same rule:
|
||||
|
||||
1. entry-point behavior is defined inside the plugin
|
||||
2. the host app only assembles and wires it
|
||||
|
||||
## 4. How a Plugin Interacts with the Host App
|
||||
|
||||
### 4.1 The top-level router only includes plugin routers
|
||||
|
||||
[`app/gateway/router.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/router.py) simply:
|
||||
|
||||
1. imports `app.plugins.auth.api.router`
|
||||
2. calls `include_router(auth_router)`
|
||||
|
||||
That means the host app integrates auth HTTP behavior by assembly, not by duplicating login/register logic in `gateway`.
|
||||
|
||||
### 4.2 registrar performs wiring, not takeover
|
||||
|
||||
In [`app/gateway/registrar.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/registrar.py), the host app mainly does this:
|
||||
|
||||
1. `app.state.authz_hooks = build_authz_hooks()`
|
||||
2. loads and validates the route-policy registry
|
||||
3. calls `install_route_guards(app)`
|
||||
4. calls `app.add_middleware(CSRFMiddleware)`
|
||||
5. calls `app.add_middleware(AuthMiddleware)`
|
||||
|
||||
So the host app only wires the plugin in:
|
||||
|
||||
1. register middleware
|
||||
2. install route guards
|
||||
3. expose hooks and registries through `app.state`
|
||||
|
||||
The actual auth logic, authz logic, and route-policy semantics still live inside the plugin.
|
||||
|
||||
### 4.3 The plugin reuses shared sessions, but still owns business repositories
|
||||
|
||||
In [`auth/security/dependencies.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/security/dependencies.py):
|
||||
|
||||
1. the plugin reads the shared session factory from `request.app.state.persistence.session_factory`
|
||||
2. constructs `DbUserRepository` itself
|
||||
3. constructs `AuthService` itself
|
||||
|
||||
This is a good low-intrusion seam:
|
||||
|
||||
1. the outside world provides only shared infrastructure handles
|
||||
2. the plugin decides how to instantiate its internal dependencies
|
||||
|
||||
## 5. Hot-Swappability and Low-Intrusion Principles
|
||||
|
||||
### 5.1 If a plugin serves other modules, it should minimize intrusion
|
||||
|
||||
When a plugin provides services to the rest of the app, the preferred patterns are:
|
||||
|
||||
1. expose a router
|
||||
2. expose middleware or dependencies
|
||||
3. expose hooks or protocols
|
||||
4. inject a small number of shared objects through `app.state`
|
||||
5. use config-driven route policies or capabilities instead of hardcoding checks inside business routes
|
||||
|
||||
Patterns to avoid:
|
||||
|
||||
1. large plugin-specific branches spread across `gateway`
|
||||
2. unrelated business modules importing plugin ORM internals and rebuilding plugin logic themselves
|
||||
3. plugin state being maintained across many global modules
|
||||
|
||||
### 5.2 Low-intrusion seams already visible in auth
|
||||
|
||||
The current `auth` plugin already uses four important low-intrusion seams:
|
||||
|
||||
1. router integration
|
||||
- `gateway.router` only calls `include_router`
|
||||
2. middleware integration
|
||||
- `registrar` only registers `AuthMiddleware` and `CSRFMiddleware`
|
||||
3. policy injection
|
||||
- `install_route_guards(app)` appends `Depends(enforce_route_policy)` uniformly to routes
|
||||
4. hook seam
|
||||
- `authz_hooks` is exposed via `app.state`, so permission providers and policy builders can be replaced
|
||||
|
||||
This structure has three practical benefits:
|
||||
|
||||
1. host-app changes stay concentrated in the assembly layer
|
||||
2. plugin core logic stays concentrated inside the plugin directory
|
||||
3. swapping implementations does not require editing business routes one by one
|
||||
|
||||
### 5.3 Route policy is a key low-intrusion mechanism
|
||||
|
||||
[`auth/injection/registry_loader.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/registry_loader.py), [`validation.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/validation.py), and [`route_injector.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/route_injector.py) together form an important contract:
|
||||
|
||||
1. route policies live in the plugin-owned `route_policies.yaml`
|
||||
2. startup validates that policy entries and real routes stay aligned
|
||||
3. guards are attached by uniform injection instead of manual per-endpoint code
|
||||
|
||||
That allows the plugin to:
|
||||
|
||||
1. describe which routes are public, which capabilities are required, and which owner policies apply
|
||||
2. avoid large invasive changes to the host routing layer
|
||||
3. remain easier to replace or trim down later
|
||||
|
||||
## 6. What “ORM and runtime are implemented inside the plugin” Should Mean
|
||||
|
||||
That contract should be read as three concrete rules:
|
||||
|
||||
1. data models belong to the plugin
|
||||
- the plugin’s own tables, Pydantic contracts, repository protocols, and repository implementations stay inside the plugin directory
|
||||
2. runtime state belongs to the plugin
|
||||
- plugin-owned config caches, context bridges, and plugin-level hooks stay inside the plugin
|
||||
3. the outside world exposes infrastructure, not plugin semantics
|
||||
- for example shared `session_factory`, FastAPI app, and `app.state`
|
||||
|
||||
Using `auth` as the example:
|
||||
|
||||
1. the `users` table is defined inside the plugin, not in `app.infra`
|
||||
2. `AuthService` is implemented inside the plugin, not in `gateway`
|
||||
3. `get_auth_config()` is maintained inside the plugin, not cached elsewhere
|
||||
4. `AuthMiddleware`, `route_guard`, and `AuthzHooks` are all provided by the plugin itself
|
||||
|
||||
This is the structural prerequisite for meaningful pluginization later.
|
||||
|
||||
## 7. Current Scope and Non-Goals
|
||||
|
||||
At the current stage, the role of `app.plugins` is mainly:
|
||||
|
||||
1. to create module boundaries for separable application-side capabilities
|
||||
2. to let each plugin own its own domain/storage/runtime/adapters
|
||||
3. to connect plugins to the host app through assembly-oriented seams
|
||||
|
||||
The current non-goals are also clear:
|
||||
|
||||
1. this is not yet a full generic plugin discovery/installation system
|
||||
2. plugins are not dynamically enabled or disabled at runtime
|
||||
3. shared infrastructure is not being duplicated into every plugin
|
||||
|
||||
So at this stage, “hot-swappable” should be interpreted more precisely as:
|
||||
|
||||
1. plugin boundaries stay as independent as possible
|
||||
2. integration points stay concentrated in the assembly layer
|
||||
3. replacing or removing a plugin should mostly affect a small number of places such as `registrar`, router includes, and `app.state` hooks
|
||||
|
||||
## 8. Suggested Evolution Rules
|
||||
|
||||
If `app.plugins` is going to become a more stable plugin boundary, the codebase should keep following these rules:
|
||||
|
||||
1. each plugin directory should keep a `domain` / `storage` / `runtime` / `adapter` split
|
||||
2. plugin-owned ORM and repositories should not drift into shared business directories
|
||||
3. when a plugin serves the rest of the app, it should prefer exposing protocols, hooks, routers, and middleware over forcing external code to import internal implementation details
|
||||
4. seams between a plugin and the host app should stay mostly limited to:
|
||||
- `router.include_router(...)`
|
||||
- `app.add_middleware(...)`
|
||||
- `app.state.*`
|
||||
- lifespan/bootstrap wiring
|
||||
5. config-driven integration should be preferred over scattered hardcoded integration
|
||||
6. startup validation should be preferred over implicit runtime failure
|
||||
|
||||
## 9. Summary
|
||||
|
||||
The current `app.plugins` contract can be summarized in one sentence:
|
||||
|
||||
Each plugin owns its own business implementation, ORM, and runtime; the host application provides shared infrastructure and assembly seams; and services should be integrated through low-intrusion, replaceable boundaries so the system can evolve toward real hot-swappability.
|
||||
@@ -1,310 +0,0 @@
|
||||
# app.plugins 设计说明
|
||||
|
||||
本文基于当前代码实现,说明 `backend/app/plugins` 的定位、插件设计契约、依赖边界,以及当前 `auth` 插件是如何在尽量少侵入宿主应用的前提下提供服务的。
|
||||
|
||||
## 1. 总体定位
|
||||
|
||||
`app.plugins` 是应用侧插件边界。它的目标不是做一个通用插件市场,而是在 `app` 这一层给可拆分的业务能力预留清晰边界,使某一类能力可以:
|
||||
|
||||
1. 在插件内部自带领域模型、运行时状态和适配器
|
||||
2. 只通过有限的接缝与宿主应用交互
|
||||
3. 在未来保持“可替换、可裁剪、可扩展”
|
||||
|
||||
当前目录下实际落地的插件是 [`auth`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth)。
|
||||
|
||||
从当前实现看,`app.plugins` 的方向不是“所有逻辑都塞进 app”,而是:
|
||||
|
||||
1. 宿主应用负责统一启动、共享基础设施和总路由装配
|
||||
2. 插件负责自己的业务契约、持久化定义、运行时状态和外部适配器
|
||||
|
||||
## 2. 插件设计契约
|
||||
|
||||
### 2.1 插件内部要自带完整能力
|
||||
|
||||
当前代码体现出的首要契约是:
|
||||
|
||||
插件自己的 ORM、runtime、domain、adapter,原则上都应由插件内部实现,不要把核心业务依赖散落到外部模块。
|
||||
|
||||
以 `auth` 插件为例,它内部已经自带了完整分层:
|
||||
|
||||
1. `domain`
|
||||
- 配置、错误、JWT、密码、领域模型、服务
|
||||
2. `storage`
|
||||
- 插件自己的 ORM 模型、仓储契约和仓储实现
|
||||
3. `runtime`
|
||||
- 插件自己的运行时配置状态
|
||||
4. `api`
|
||||
- 插件自己的 HTTP router 和 schema
|
||||
5. `security`
|
||||
- 插件自己的 middleware、dependency、csrf、LangGraph 适配
|
||||
6. `authorization`
|
||||
- 插件自己的权限模型、policy 解析和 hook
|
||||
7. `injection`
|
||||
- 插件自己的路由策略注册、注入和校验逻辑
|
||||
|
||||
换句话说,插件不是一组零散 helper,而应该是一个自闭合的功能模块。
|
||||
|
||||
### 2.2 宿主应用只提供共享基础设施,不承接插件内部逻辑
|
||||
|
||||
当前约束不是“插件完全独立进程”,而是:
|
||||
|
||||
1. 插件可以复用应用共享的 `engine`、`session_factory`、FastAPI app、路由树
|
||||
2. 但插件自己的表结构、仓储、运行时配置、鉴权逻辑,仍然应由插件自己拥有
|
||||
|
||||
这一点在 [`auth/plugin.toml`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/plugin.toml) 里写得很明确:
|
||||
|
||||
1. `storage.mode = "shared_infrastructure"`
|
||||
2. 说明插件拥有自己的 storage definitions 和 repositories
|
||||
3. 但复用应用共享的 persistence infrastructure
|
||||
|
||||
所以这里的契约不是“禁止复用基础设施”,而是“不要把插件内部业务实现外包给 app 其他模块”。
|
||||
|
||||
### 2.3 依赖方向要单向
|
||||
|
||||
按当前实现,比较理想的依赖方向是:
|
||||
|
||||
```text
|
||||
gateway / app bootstrap
|
||||
-> plugin public adapters
|
||||
-> plugin domain / storage / runtime
|
||||
```
|
||||
|
||||
而不是:
|
||||
|
||||
```text
|
||||
plugin domain
|
||||
-> 依赖 app 里的业务模块
|
||||
```
|
||||
|
||||
插件可以使用:
|
||||
|
||||
1. 共享持久化基础设施
|
||||
2. 宿主应用提供的 `app.state`
|
||||
3. FastAPI / Starlette 等通用框架能力
|
||||
|
||||
但不应该把自己的核心业务规则建立在别的业务模块之上,否则后续无法热插拔。
|
||||
|
||||
## 3. 当前 auth 插件的实际结构
|
||||
|
||||
当前 `auth` 插件可以概括为一套“自带模型、自带服务、自带适配器”的认证授权包。
|
||||
|
||||
### 3.1 domain
|
||||
|
||||
[`auth/domain`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/domain) 负责:
|
||||
|
||||
1. `config.py`
|
||||
- 认证相关配置定义与加载
|
||||
2. `errors.py`
|
||||
- 错误码和错误响应契约
|
||||
3. `jwt.py`
|
||||
- token 编解码
|
||||
4. `password.py`
|
||||
- 密码哈希和校验
|
||||
5. `models.py`
|
||||
- auth 域模型
|
||||
6. `service.py`
|
||||
- `AuthService`,作为核心业务服务
|
||||
|
||||
`AuthService` 本身只依赖插件内部的 `DbUserRepository` 和共享 session factory,没有把认证逻辑散到 `gateway`。
|
||||
|
||||
### 3.2 storage
|
||||
|
||||
[`auth/storage`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage) 明确体现了“ORM 由插件自己内部实现”的契约:
|
||||
|
||||
1. [`models.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/models.py)
|
||||
- 定义插件自己的 `users` 表模型
|
||||
2. [`contracts.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/contracts.py)
|
||||
- 定义 `User`、`UserCreate` 和 `UserRepositoryProtocol`
|
||||
3. [`repositories.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/repositories.py)
|
||||
- 实现 `DbUserRepository`
|
||||
|
||||
这里的关键点是:
|
||||
|
||||
1. 插件自己定义 ORM model
|
||||
2. 插件自己定义 repository protocol
|
||||
3. 插件自己实现 repository
|
||||
4. 外部只需要给它 session / session_factory
|
||||
|
||||
这就是插件边界应该保持的最小共享面。
|
||||
|
||||
### 3.3 runtime
|
||||
|
||||
[`auth/runtime/config_state.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/runtime/config_state.py) 维护插件自己的 runtime config state:
|
||||
|
||||
1. `get_auth_config()`
|
||||
2. `set_auth_config()`
|
||||
3. `reset_auth_config()`
|
||||
|
||||
这说明运行时配置状态也属于插件内部,而不是由外部模块代持。后续如果别的插件需要自己的缓存、状态机、feature flag,也应沿这个模式内聚在插件内部。
|
||||
|
||||
### 3.4 adapters
|
||||
|
||||
`auth` 插件对外暴露能力主要通过四类 adapter:
|
||||
|
||||
1. `api/router.py`
|
||||
- HTTP 接口
|
||||
2. `security/*`
|
||||
- middleware、dependency、request user 解析、actor context bridge
|
||||
3. `authorization/*`
|
||||
- capability、policy evaluator、auth hooks
|
||||
4. `injection/*`
|
||||
- route policy registry、guard 注入、启动校验
|
||||
|
||||
这类 adapter 的共同特征是:
|
||||
|
||||
1. 入口能力在插件内定义
|
||||
2. 宿主应用只负责调用和装配
|
||||
|
||||
## 4. 插件如何与宿主应用交互
|
||||
|
||||
### 4.1 总路由只 include,不重写插件逻辑
|
||||
|
||||
[`app/gateway/router.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/router.py) 只是:
|
||||
|
||||
1. 引入 `app.plugins.auth.api.router`
|
||||
2. `include_router(auth_router)`
|
||||
|
||||
这说明宿主应用对 auth HTTP 能力的接入是装配式的,而不是在 `gateway` 里重写一套登录/注册逻辑。
|
||||
|
||||
### 4.2 registrar 负责启动装配,不负责接管插件实现
|
||||
|
||||
[`app/gateway/registrar.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/registrar.py) 里,宿主应用做的事情主要是:
|
||||
|
||||
1. `app.state.authz_hooks = build_authz_hooks()`
|
||||
2. 加载并校验 route policy registry
|
||||
3. `install_route_guards(app)`
|
||||
4. `app.add_middleware(CSRFMiddleware)`
|
||||
5. `app.add_middleware(AuthMiddleware)`
|
||||
|
||||
也就是说,宿主应用只负责把插件接进来:
|
||||
|
||||
1. 注册 middleware
|
||||
2. 安装 route guard
|
||||
3. 把 hooks 和 registry 放到 `app.state`
|
||||
|
||||
真正的鉴权逻辑、认证逻辑、路由策略语义仍然在插件内部。
|
||||
|
||||
### 4.3 共享会话工厂,但业务仓储仍归插件
|
||||
|
||||
在 [`auth/security/dependencies.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/security/dependencies.py) 中:
|
||||
|
||||
1. 插件从 `request.app.state.persistence.session_factory` 取得共享 session factory
|
||||
2. 然后自己构造 `DbUserRepository`
|
||||
3. 再自己构造 `AuthService`
|
||||
|
||||
这就是一个很典型的低侵入接缝:
|
||||
|
||||
1. 外部只提供共享基础设施句柄
|
||||
2. 插件自己决定如何实例化内部依赖
|
||||
|
||||
## 5. 热插拔与低侵入原则
|
||||
|
||||
### 5.1 如果要向其他模块提供服务,应尽量减少入侵
|
||||
|
||||
插件给其他模块提供服务时,优先选下面这些方式:
|
||||
|
||||
1. 暴露 router
|
||||
2. 暴露 middleware / dependency
|
||||
3. 暴露 hook 或 protocol
|
||||
4. 通过 `app.state` 注入少量共享对象
|
||||
5. 使用配置驱动的 route policy / capability,而不是把判断逻辑硬编码进业务路由
|
||||
|
||||
不推荐的方式是:
|
||||
|
||||
1. 在 `gateway` 大量写插件特定分支
|
||||
2. 让别的业务模块直接 import 插件内部 ORM 细节后自行拼逻辑
|
||||
3. 把插件状态散落到全局多个模块中共同维护
|
||||
|
||||
### 5.2 当前 auth 插件已经体现出的低侵入点
|
||||
|
||||
当前 `auth` 插件的低侵入接入点主要有四个:
|
||||
|
||||
1. 路由接入
|
||||
- `gateway.router` 只 `include_router`
|
||||
2. 中间件接入
|
||||
- `registrar` 只注册 `AuthMiddleware` / `CSRFMiddleware`
|
||||
3. 策略注入
|
||||
- `install_route_guards(app)` 给路由统一追加 `Depends(enforce_route_policy)`
|
||||
4. hook 接缝
|
||||
- `authz_hooks` 通过 `app.state` 暴露,策略构建和权限提供器可以替换
|
||||
|
||||
这套结构的好处是:
|
||||
|
||||
1. 宿主应用改动面集中在装配层
|
||||
2. 插件核心实现集中在插件目录内部
|
||||
3. 替换实现时,不需要在业务路由里逐个修改
|
||||
|
||||
### 5.3 route policy 是低侵入的关键机制
|
||||
|
||||
[`auth/injection/registry_loader.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/registry_loader.py)、[`validation.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/validation.py) 和 [`route_injector.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/route_injector.py) 共同形成了一套很关键的契约:
|
||||
|
||||
1. 路由策略写在插件自己的 `route_policies.yaml`
|
||||
2. 启动时会校验策略表和真实路由是否一致
|
||||
3. guard 通过统一注入附着到路由,而不是每个 endpoint 手写一遍
|
||||
|
||||
这使得插件能够:
|
||||
|
||||
1. 用配置描述“哪些路由公开、需要哪些 capability、需要哪些 owner policy”
|
||||
2. 避免对宿主路由层做大规模侵入
|
||||
3. 在未来更容易替换或裁剪某个插件
|
||||
|
||||
## 6. 关于“ORM、runtime 都由自己内部实现”的具体说明
|
||||
|
||||
这条契约建议明确理解为以下三点:
|
||||
|
||||
1. 数据模型归插件
|
||||
- 插件自己的表、Pydantic contract、repository protocol、repository implementation 都放在插件目录内
|
||||
2. 运行时状态归插件
|
||||
- 插件自己的配置缓存、上下文桥、插件级 hooks 都在插件内部维护
|
||||
3. 外部只暴露基础设施,不接管插件语义
|
||||
- 例如共享 `session_factory`、FastAPI app、`app.state`
|
||||
|
||||
拿 `auth` 举例:
|
||||
|
||||
1. `users` 表在插件里定义,不在 `app.infra` 定义
|
||||
2. `AuthService` 在插件里实现,不在 `gateway` 实现
|
||||
3. `get_auth_config()` 在插件里维护,不由别的模块缓存
|
||||
4. `AuthMiddleware`、`route_guard`、`AuthzHooks` 都由插件自己提供
|
||||
|
||||
这是后续做插件化时最重要的结构前提。
|
||||
|
||||
## 7. 当前作用范围与非目标
|
||||
|
||||
就当前实现而言,`app.plugins` 的作用范围主要是:
|
||||
|
||||
1. 为应用侧可拆分能力建立模块边界
|
||||
2. 让插件拥有自己的 domain/storage/runtime/adapter
|
||||
3. 通过装配式接缝接入宿主应用
|
||||
|
||||
当前非目标也很明确:
|
||||
|
||||
1. 还不是一个完整的通用插件发现/安装系统
|
||||
2. 还没有做到运行时动态启停插件
|
||||
3. 也不是把共享基础设施完全复制进每个插件
|
||||
|
||||
所以“热插拔”在当前阶段更准确的含义是:
|
||||
|
||||
1. 插件边界尽量独立
|
||||
2. 接入点尽量集中在装配层
|
||||
3. 替换或移除时,改动尽量局限在 `registrar`、`router include`、`app.state` hooks 这些少数位置
|
||||
|
||||
## 8. 后续演进建议
|
||||
|
||||
如果后续要继续把 `app.plugins` 做成更稳定的插件边界,建议保持这些规则:
|
||||
|
||||
1. 每个插件目录内部都保持 `domain` / `storage` / `runtime` / `adapter` 分层
|
||||
2. 插件自己的 ORM 与 repository 不要下沉到共享业务目录
|
||||
3. 插件向外提供服务时优先暴露 protocol、hook、router、middleware,而不是要求外部 import 内部实现细节
|
||||
4. 插件与宿主应用的接缝尽量限制在:
|
||||
- `router.include_router(...)`
|
||||
- `app.add_middleware(...)`
|
||||
- `app.state.*`
|
||||
- 生命周期装配
|
||||
5. 配置驱动优先于散落的硬编码接入
|
||||
6. 启动期校验优先于运行时隐式失败
|
||||
|
||||
## 9. 设计总结
|
||||
|
||||
可以把当前 `app.plugins` 的契约总结为一句话:
|
||||
|
||||
插件内部拥有自己的业务实现、ORM 和 runtime;宿主应用只提供共享基础设施和装配接缝;对外服务时尽量通过低侵入、可替换的方式接入,以便后续做到真正的热插拔和边界演进。
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user