mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0f82f8a3a2 | |||
| a0ab3a3dd4 | |||
| 4fa2c15613 | |||
| 892a06fe98 | |||
| b5e18f5b47 | |||
| e3e00af51d | |||
| 5f2f1941e9 | |||
| 9d0a42c1fb | |||
| 39a575617b | |||
| 274255b1a5 | |||
| 14892e1463 | |||
| 37fd8b0d7a | |||
| 2fe0856e33 | |||
| 38a6ec496f | |||
| 3a99c4e81c | |||
| 7b9d224b3a | |||
| 0572ef44b9 | |||
| 839563f308 | |||
| 62bdfe3abc | |||
| b61ce3527b | |||
| 2d5f6f1b3d | |||
| 69bf3dafd8 | |||
| 6cbec13495 | |||
| 31e5b586a1 | |||
| e75a2ff29a | |||
| 185f5649dd |
+2
-23
@@ -1,6 +1,3 @@
|
|||||||
# Serper API Key (Google Search) - https://serper.dev
|
|
||||||
SERPER_API_KEY=your-serper-api-key
|
|
||||||
|
|
||||||
# TAVILY API Key
|
# TAVILY API Key
|
||||||
TAVILY_API_KEY=your-tavily-api-key
|
TAVILY_API_KEY=your-tavily-api-key
|
||||||
|
|
||||||
@@ -9,9 +6,8 @@ JINA_API_KEY=your-jina-api-key
|
|||||||
|
|
||||||
# InfoQuest API Key
|
# InfoQuest API Key
|
||||||
INFOQUEST_API_KEY=your-infoquest-api-key
|
INFOQUEST_API_KEY=your-infoquest-api-key
|
||||||
# Browser CORS allowlist for split-origin or port-forwarded deployments (comma-separated exact origins).
|
# CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001
|
||||||
# Leave unset when using the unified nginx endpoint, e.g. http://localhost:2026.
|
# CORS_ORIGINS=http://localhost:3000
|
||||||
# GATEWAY_CORS_ORIGINS=http://localhost:3000,http://127.0.0.1:3000
|
|
||||||
|
|
||||||
# Optional:
|
# Optional:
|
||||||
# FIRECRAWL_API_KEY=your-firecrawl-api-key
|
# FIRECRAWL_API_KEY=your-firecrawl-api-key
|
||||||
@@ -28,7 +24,6 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
# SLACK_BOT_TOKEN=your-slack-bot-token
|
# SLACK_BOT_TOKEN=your-slack-bot-token
|
||||||
# SLACK_APP_TOKEN=your-slack-app-token
|
# SLACK_APP_TOKEN=your-slack-app-token
|
||||||
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
|
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
|
||||||
# DISCORD_BOT_TOKEN=your-discord-bot-token
|
|
||||||
|
|
||||||
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
|
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
|
||||||
# LANGSMITH_TRACING=true
|
# LANGSMITH_TRACING=true
|
||||||
@@ -44,19 +39,3 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
#
|
#
|
||||||
# WECOM_BOT_ID=your-wecom-bot-id
|
# WECOM_BOT_ID=your-wecom-bot-id
|
||||||
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
||||||
# DINGTALK_CLIENT_ID=your-dingtalk-client-id
|
|
||||||
# DINGTALK_CLIENT_SECRET=your-dingtalk-client-secret
|
|
||||||
|
|
||||||
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
|
|
||||||
# GATEWAY_ENABLE_DOCS=false
|
|
||||||
|
|
||||||
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
|
|
||||||
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
|
|
||||||
# /api/* rewrites). They default to localhost values that match `make dev` and
|
|
||||||
# `make start`, so most local users do not need to set them.
|
|
||||||
#
|
|
||||||
# Override only when the Gateway is not on localhost:8001 (e.g. when the
|
|
||||||
# frontend and gateway run on different hosts, in containers with a service
|
|
||||||
# alias, or behind a different port). docker-compose already sets these.
|
|
||||||
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
|
||||||
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
name: Publish Containers
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- "v*"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
|
|
||||||
backend-container:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
packages: write
|
|
||||||
attestations: write
|
|
||||||
id-token: write
|
|
||||||
env:
|
|
||||||
REGISTRY: ghcr.io
|
|
||||||
IMAGE_NAME: ${{ github.repository }}-backend
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
- name: Log in to the Container registry
|
|
||||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
|
||||||
with:
|
|
||||||
registry: ${{ env.REGISTRY }}
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
|
||||||
with:
|
|
||||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
|
||||||
tags: |
|
|
||||||
type=ref,event=tag
|
|
||||||
type=ref,event=branch
|
|
||||||
type=sha
|
|
||||||
type=raw,value=latest,enable={{is_default_branch}}
|
|
||||||
- name: Build and push Docker image
|
|
||||||
id: push
|
|
||||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
file: backend/Dockerfile
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
|
|
||||||
- name: Generate artifact attestation
|
|
||||||
uses: actions/attest-build-provenance@v2
|
|
||||||
with:
|
|
||||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
|
||||||
subject-digest: ${{ steps.push.outputs.digest }}
|
|
||||||
push-to-registry: true
|
|
||||||
|
|
||||||
frontend-container:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
packages: write
|
|
||||||
attestations: write
|
|
||||||
id-token: write
|
|
||||||
env:
|
|
||||||
REGISTRY: ghcr.io
|
|
||||||
IMAGE_NAME: ${{ github.repository }}-frontend
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
- name: Log in to the Container registry
|
|
||||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
|
||||||
with:
|
|
||||||
registry: ${{ env.REGISTRY }}
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
|
||||||
with:
|
|
||||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
|
||||||
tags: |
|
|
||||||
type=ref,event=tag
|
|
||||||
type=ref,event=branch
|
|
||||||
type=sha
|
|
||||||
type=raw,value=latest,enable={{is_default_branch}}
|
|
||||||
- name: Build and push Docker image
|
|
||||||
id: push
|
|
||||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
file: frontend/Dockerfile
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
|
|
||||||
- name: Generate artifact attestation
|
|
||||||
uses: actions/attest-build-provenance@v2
|
|
||||||
with:
|
|
||||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
|
||||||
subject-digest: ${{ steps.push.outputs.digest }}
|
|
||||||
push-to-registry: true
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
name: E2E Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ 'main' ]
|
|
||||||
paths:
|
|
||||||
- 'frontend/**'
|
|
||||||
- '.github/workflows/e2e-tests.yml'
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
paths:
|
|
||||||
- 'frontend/**'
|
|
||||||
- '.github/workflows/e2e-tests.yml'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
e2e-tests:
|
|
||||||
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
timeout-minutes: 15
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: '22'
|
|
||||||
|
|
||||||
- name: Enable Corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Use pinned pnpm version
|
|
||||||
run: corepack prepare pnpm@10.26.2 --activate
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
working-directory: frontend
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Install Playwright Chromium
|
|
||||||
working-directory: frontend
|
|
||||||
run: npx playwright install chromium --with-deps
|
|
||||||
|
|
||||||
- name: Run E2E tests
|
|
||||||
working-directory: frontend
|
|
||||||
run: pnpm exec playwright test
|
|
||||||
env:
|
|
||||||
SKIP_ENV_VALIDATION: '1'
|
|
||||||
|
|
||||||
- name: Upload Playwright report
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
if: ${{ !cancelled() }}
|
|
||||||
with:
|
|
||||||
name: playwright-report
|
|
||||||
path: frontend/playwright-report/
|
|
||||||
retention-days: 7
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
name: Frontend Unit Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ 'main' ]
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
frontend-unit-tests:
|
|
||||||
if: github.event.pull_request.draft == false
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
timeout-minutes: 15
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: '22'
|
|
||||||
|
|
||||||
- name: Enable Corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Use pinned pnpm version
|
|
||||||
run: corepack prepare pnpm@10.26.2 --activate
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
working-directory: frontend
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Run unit tests of frontend
|
|
||||||
working-directory: frontend
|
|
||||||
run: make test
|
|
||||||
@@ -40,7 +40,6 @@ coverage/
|
|||||||
skills/custom/*
|
skills/custom/*
|
||||||
logs/
|
logs/
|
||||||
log/
|
log/
|
||||||
debug.log
|
|
||||||
|
|
||||||
# Local git hooks (keep only on this machine, do not push)
|
# Local git hooks (keep only on this machine, do not push)
|
||||||
.githooks/
|
.githooks/
|
||||||
@@ -56,7 +55,5 @@ web/
|
|||||||
backend/Dockerfile.langgraph
|
backend/Dockerfile.langgraph
|
||||||
config.yaml.bak
|
config.yaml.bak
|
||||||
.playwright-mcp
|
.playwright-mcp
|
||||||
/frontend/test-results/
|
|
||||||
/frontend/playwright-report/
|
|
||||||
.gstack/
|
.gstack/
|
||||||
.worktrees
|
.worktrees
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
repos:
|
|
||||||
# Backend: ruff lint + format via uv (uses the same ruff version as backend deps)
|
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: ruff
|
|
||||||
name: ruff lint
|
|
||||||
entry: bash -c 'cd backend && uv run ruff check --fix "${@/#backend\//}"' --
|
|
||||||
language: system
|
|
||||||
types_or: [python]
|
|
||||||
files: ^backend/
|
|
||||||
- id: ruff-format
|
|
||||||
name: ruff format
|
|
||||||
entry: bash -c 'cd backend && uv run ruff format "${@/#backend\//}"' --
|
|
||||||
language: system
|
|
||||||
types_or: [python]
|
|
||||||
files: ^backend/
|
|
||||||
|
|
||||||
# Frontend: eslint + prettier (must run from frontend/ for node_modules resolution)
|
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: frontend-eslint
|
|
||||||
name: eslint (frontend)
|
|
||||||
entry: bash -c 'cd frontend && npx eslint --fix "${@/#frontend\//}"' --
|
|
||||||
language: system
|
|
||||||
types_or: [javascript, tsx, ts]
|
|
||||||
files: ^frontend/
|
|
||||||
|
|
||||||
- id: frontend-prettier
|
|
||||||
name: prettier (frontend)
|
|
||||||
entry: bash -c 'cd frontend && npx prettier --write "${@/#frontend\//}"' --
|
|
||||||
language: system
|
|
||||||
files: ^frontend/
|
|
||||||
types_or: [javascript, tsx, ts, json, css]
|
|
||||||
+26
-25
@@ -46,12 +46,12 @@ Docker provides a consistent, isolated environment with all dependencies pre-con
|
|||||||
All services will start with hot-reload enabled:
|
All services will start with hot-reload enabled:
|
||||||
- Frontend changes are automatically reloaded
|
- Frontend changes are automatically reloaded
|
||||||
- Backend changes trigger automatic restart
|
- Backend changes trigger automatic restart
|
||||||
- Gateway-hosted LangGraph-compatible runtime supports hot-reload
|
- LangGraph server supports hot-reload
|
||||||
|
|
||||||
4. **Access the application**:
|
4. **Access the application**:
|
||||||
- Web Interface: http://localhost:2026
|
- Web Interface: http://localhost:2026
|
||||||
- API Gateway: http://localhost:2026/api/*
|
- API Gateway: http://localhost:2026/api/*
|
||||||
- LangGraph-compatible API: http://localhost:2026/api/langgraph/*
|
- LangGraph: http://localhost:2026/api/langgraph/*
|
||||||
|
|
||||||
#### Docker Commands
|
#### Docker Commands
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ Use these as practical starting points for development and review environments:
|
|||||||
If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket:
|
If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
unable to get image 'deer-flow-gateway': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock
|
unable to get image 'deer-flow-dev-langgraph': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock
|
||||||
```
|
```
|
||||||
|
|
||||||
Recommended fix: add your current user to the `docker` group so Docker commands work without `sudo`.
|
Recommended fix: add your current user to the `docker` group so Docker commands work without `sudo`.
|
||||||
@@ -131,8 +131,9 @@ Host Machine
|
|||||||
Docker Compose (deer-flow-dev)
|
Docker Compose (deer-flow-dev)
|
||||||
├→ nginx (port 2026) ← Reverse proxy
|
├→ nginx (port 2026) ← Reverse proxy
|
||||||
├→ web (port 3000) ← Frontend with hot-reload
|
├→ web (port 3000) ← Frontend with hot-reload
|
||||||
├→ gateway (port 8001) ← Gateway API + LangGraph-compatible runtime with hot-reload
|
├→ api (port 8001) ← Gateway API with hot-reload
|
||||||
└→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode
|
├→ langgraph (port 2024) ← LangGraph server with hot-reload
|
||||||
|
└→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode
|
||||||
```
|
```
|
||||||
|
|
||||||
**Benefits of Docker Development**:
|
**Benefits of Docker Development**:
|
||||||
@@ -165,7 +166,7 @@ Required tools:
|
|||||||
|
|
||||||
1. **Configure the application** (same as Docker setup above)
|
1. **Configure the application** (same as Docker setup above)
|
||||||
|
|
||||||
2. **Install dependencies** (this also sets up pre-commit hooks):
|
2. **Install dependencies**:
|
||||||
```bash
|
```bash
|
||||||
make install
|
make install
|
||||||
```
|
```
|
||||||
@@ -183,13 +184,17 @@ Required tools:
|
|||||||
|
|
||||||
If you need to start services individually:
|
If you need to start services individually:
|
||||||
|
|
||||||
1. **Start backend service**:
|
1. **Start backend services**:
|
||||||
```bash
|
```bash
|
||||||
# Terminal 1: Start Gateway API + embedded agent runtime (port 8001)
|
# Terminal 1: Start LangGraph Server (port 2024)
|
||||||
cd backend
|
cd backend
|
||||||
make dev
|
make dev
|
||||||
|
|
||||||
# Terminal 2: Start Frontend (port 3000)
|
# Terminal 2: Start Gateway API (port 8001)
|
||||||
|
cd backend
|
||||||
|
make gateway
|
||||||
|
|
||||||
|
# Terminal 3: Start Frontend (port 3000)
|
||||||
cd frontend
|
cd frontend
|
||||||
pnpm dev
|
pnpm dev
|
||||||
```
|
```
|
||||||
@@ -207,10 +212,10 @@ If you need to start services individually:
|
|||||||
|
|
||||||
The nginx configuration provides:
|
The nginx configuration provides:
|
||||||
- Unified entry point on port 2026
|
- Unified entry point on port 2026
|
||||||
- Rewrites `/api/langgraph/*` to Gateway's LangGraph-compatible API (8001)
|
- Routes `/api/langgraph/*` to LangGraph Server (2024)
|
||||||
- Routes other `/api/*` endpoints to Gateway API (8001)
|
- Routes other `/api/*` endpoints to Gateway API (8001)
|
||||||
- Routes non-API requests to Frontend (3000)
|
- Routes non-API requests to Frontend (3000)
|
||||||
- Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist
|
- Centralized CORS handling
|
||||||
- SSE/streaming support for real-time agent responses
|
- SSE/streaming support for real-time agent responses
|
||||||
- Optimized timeouts for long-running operations
|
- Optimized timeouts for long-running operations
|
||||||
|
|
||||||
@@ -230,8 +235,8 @@ deer-flow/
|
|||||||
│ └── nginx.local.conf # Nginx config for local dev
|
│ └── nginx.local.conf # Nginx config for local dev
|
||||||
├── backend/ # Backend application
|
├── backend/ # Backend application
|
||||||
│ ├── src/
|
│ ├── src/
|
||||||
│ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001)
|
│ │ ├── gateway/ # Gateway API (port 8001)
|
||||||
│ │ ├── agents/ # LangGraph agent runtime used by Gateway
|
│ │ ├── agents/ # LangGraph agents (port 2024)
|
||||||
│ │ ├── mcp/ # Model Context Protocol integration
|
│ │ ├── mcp/ # Model Context Protocol integration
|
||||||
│ │ ├── skills/ # Skills system
|
│ │ ├── skills/ # Skills system
|
||||||
│ │ └── sandbox/ # Sandbox execution
|
│ │ └── sandbox/ # Sandbox execution
|
||||||
@@ -251,7 +256,8 @@ Browser
|
|||||||
↓
|
↓
|
||||||
Nginx (port 2026) ← Unified entry point
|
Nginx (port 2026) ← Unified entry point
|
||||||
├→ Frontend (port 3000) ← / (non-API requests)
|
├→ Frontend (port 3000) ← / (non-API requests)
|
||||||
└→ Gateway API (port 8001) ← /api/* and /api/langgraph/* (LangGraph-compatible agent interactions)
|
├→ Gateway API (port 8001) ← /api/models, /api/mcp, /api/skills, /api/threads/*/artifacts
|
||||||
|
└→ LangGraph Server (port 2024) ← /api/langgraph/* (agent interactions)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Development Workflow
|
## Development Workflow
|
||||||
@@ -292,24 +298,19 @@ Nginx (port 2026) ← Unified entry point
|
|||||||
```bash
|
```bash
|
||||||
# Backend tests
|
# Backend tests
|
||||||
cd backend
|
cd backend
|
||||||
make test
|
uv run pytest
|
||||||
|
|
||||||
# Frontend unit tests
|
# Frontend checks
|
||||||
cd frontend
|
cd frontend
|
||||||
make test
|
pnpm check
|
||||||
|
|
||||||
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
|
|
||||||
cd frontend
|
|
||||||
make test-e2e
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### PR Regression Checks
|
### PR Regression Checks
|
||||||
|
|
||||||
Every pull request triggers the following CI workflows:
|
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including:
|
||||||
|
|
||||||
- **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
|
- `tests/test_provisioner_kubeconfig.py`
|
||||||
- **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
|
- `tests/test_docker_sandbox_mode_detection.py`
|
||||||
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
|
|
||||||
|
|
||||||
## Code Style
|
## Code Style
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# DeerFlow - Unified Development Environment
|
# DeerFlow - Unified Development Environment
|
||||||
|
|
||||||
.PHONY: help config config-upgrade check install setup doctor detect-thread-boundaries dev dev-daemon start start-daemon stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
.PHONY: help config config-upgrade check install setup doctor dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
||||||
|
|
||||||
BASH ?= bash
|
BASH ?= bash
|
||||||
BACKEND_UV_RUN = cd backend && uv run
|
BACKEND_UV_RUN = cd backend && uv run
|
||||||
@@ -23,23 +23,28 @@ help:
|
|||||||
@echo " make config - Generate local config files (aborts if config already exists)"
|
@echo " make config - Generate local config files (aborts if config already exists)"
|
||||||
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
|
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
|
||||||
@echo " make check - Check if all required tools are installed"
|
@echo " make check - Check if all required tools are installed"
|
||||||
@echo " make detect-thread-boundaries - Inventory async/thread boundary points"
|
@echo " make install - Install all dependencies (frontend + backend)"
|
||||||
@echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)"
|
|
||||||
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
|
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
|
||||||
@echo " make dev - Start all services in development mode (with hot-reloading)"
|
@echo " make dev - Start all services in development mode (with hot-reloading)"
|
||||||
|
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
|
||||||
@echo " make dev-daemon - Start dev services in background (daemon mode)"
|
@echo " make dev-daemon - Start dev services in background (daemon mode)"
|
||||||
|
@echo " make dev-daemon-pro - Start dev daemon + Gateway mode (experimental)"
|
||||||
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
|
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
|
||||||
|
@echo " make start-pro - Start in prod + Gateway mode (experimental)"
|
||||||
@echo " make start-daemon - Start prod services in background (daemon mode)"
|
@echo " make start-daemon - Start prod services in background (daemon mode)"
|
||||||
|
@echo " make start-daemon-pro - Start prod daemon + Gateway mode (experimental)"
|
||||||
@echo " make stop - Stop all running services"
|
@echo " make stop - Stop all running services"
|
||||||
@echo " make clean - Clean up processes and temporary files"
|
@echo " make clean - Clean up processes and temporary files"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Docker Production Commands:"
|
@echo "Docker Production Commands:"
|
||||||
@echo " make up - Build and start production Docker services (localhost:2026)"
|
@echo " make up - Build and start production Docker services (localhost:2026)"
|
||||||
|
@echo " make up-pro - Build and start production Docker in Gateway mode (experimental)"
|
||||||
@echo " make down - Stop and remove production Docker containers"
|
@echo " make down - Stop and remove production Docker containers"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Docker Development Commands:"
|
@echo "Docker Development Commands:"
|
||||||
@echo " make docker-init - Pull the sandbox image"
|
@echo " make docker-init - Pull the sandbox image"
|
||||||
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
|
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
|
||||||
|
@echo " make docker-start-pro - Start Docker in Gateway mode (experimental, no LangGraph container)"
|
||||||
@echo " make docker-stop - Stop Docker development services"
|
@echo " make docker-stop - Stop Docker development services"
|
||||||
@echo " make docker-logs - View Docker development logs"
|
@echo " make docker-logs - View Docker development logs"
|
||||||
@echo " make docker-logs-frontend - View Docker frontend logs"
|
@echo " make docker-logs-frontend - View Docker frontend logs"
|
||||||
@@ -52,9 +57,6 @@ setup:
|
|||||||
doctor:
|
doctor:
|
||||||
@$(BACKEND_UV_RUN) python ../scripts/doctor.py
|
@$(BACKEND_UV_RUN) python ../scripts/doctor.py
|
||||||
|
|
||||||
detect-thread-boundaries:
|
|
||||||
@$(PYTHON) ./scripts/detect_thread_boundaries.py
|
|
||||||
|
|
||||||
config:
|
config:
|
||||||
@$(PYTHON) ./scripts/configure.py
|
@$(PYTHON) ./scripts/configure.py
|
||||||
|
|
||||||
@@ -71,8 +73,6 @@ install:
|
|||||||
@cd backend && uv sync
|
@cd backend && uv sync
|
||||||
@echo "Installing frontend dependencies..."
|
@echo "Installing frontend dependencies..."
|
||||||
@cd frontend && pnpm install
|
@cd frontend && pnpm install
|
||||||
@echo "Installing pre-commit hooks..."
|
|
||||||
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
|
|
||||||
@echo "✓ All dependencies installed"
|
@echo "✓ All dependencies installed"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "=========================================="
|
@echo "=========================================="
|
||||||
@@ -99,7 +99,7 @@ setup-sandbox:
|
|||||||
echo ""; \
|
echo ""; \
|
||||||
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
|
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
|
||||||
echo "Detected Apple Container on macOS, pulling image..."; \
|
echo "Detected Apple Container on macOS, pulling image..."; \
|
||||||
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
|
container pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
|
||||||
fi; \
|
fi; \
|
||||||
if command -v docker >/dev/null 2>&1; then \
|
if command -v docker >/dev/null 2>&1; then \
|
||||||
echo "Pulling image using Docker..."; \
|
echo "Pulling image using Docker..."; \
|
||||||
@@ -121,21 +121,41 @@ dev:
|
|||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev
|
||||||
|
|
||||||
|
# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
|
||||||
|
dev-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway
|
||||||
|
|
||||||
# Start all services in production mode (with optimizations)
|
# Start all services in production mode (with optimizations)
|
||||||
start:
|
start:
|
||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod
|
||||||
|
|
||||||
|
# Start all services in prod + Gateway mode (experimental)
|
||||||
|
start-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway
|
||||||
|
|
||||||
# Start all services in daemon mode (background)
|
# Start all services in daemon mode (background)
|
||||||
dev-daemon:
|
dev-daemon:
|
||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --daemon
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --daemon
|
||||||
|
|
||||||
|
# Start daemon + Gateway mode (experimental)
|
||||||
|
dev-daemon-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway --daemon
|
||||||
|
|
||||||
# Start prod services in daemon mode (background)
|
# Start prod services in daemon mode (background)
|
||||||
start-daemon:
|
start-daemon:
|
||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --daemon
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --daemon
|
||||||
|
|
||||||
|
# Start prod daemon + Gateway mode (experimental)
|
||||||
|
start-daemon-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway --daemon
|
||||||
|
|
||||||
# Stop all services
|
# Stop all services
|
||||||
stop:
|
stop:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --stop
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --stop
|
||||||
@@ -160,6 +180,10 @@ docker-init:
|
|||||||
docker-start:
|
docker-start:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start
|
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start
|
||||||
|
|
||||||
|
# Start Docker in Gateway mode (experimental)
|
||||||
|
docker-start-pro:
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start --gateway
|
||||||
|
|
||||||
# Stop Docker development environment
|
# Stop Docker development environment
|
||||||
docker-stop:
|
docker-stop:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh stop
|
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh stop
|
||||||
@@ -182,6 +206,10 @@ docker-logs-gateway:
|
|||||||
up:
|
up:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh
|
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh
|
||||||
|
|
||||||
|
# Build and start production services in Gateway mode
|
||||||
|
up-pro:
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh --gateway
|
||||||
|
|
||||||
# Stop and remove production containers
|
# Stop and remove production containers
|
||||||
down:
|
down:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh down
|
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh down
|
||||||
|
|||||||
@@ -243,9 +243,10 @@ make up # Build images and start all production services
|
|||||||
make down # Stop and remove containers
|
make down # Stop and remove containers
|
||||||
```
|
```
|
||||||
|
|
||||||
Access: http://localhost:2026
|
> [!NOTE]
|
||||||
|
> The LangGraph agent server currently runs via `langgraph dev` (the open-source CLI server).
|
||||||
|
|
||||||
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
Access: http://localhost:2026
|
||||||
|
|
||||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||||
|
|
||||||
@@ -253,7 +254,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
|||||||
|
|
||||||
If you prefer running services locally:
|
If you prefer running services locally:
|
||||||
|
|
||||||
Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root. Set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or `DEER_FLOW_CONFIG_PATH` to point at a specific config file. Runtime state defaults to `.deer-flow` under the project root and can be moved with `DEER_FLOW_HOME`; skills default to `skills/` under the project root and can be moved with `DEER_FLOW_SKILLS_PATH`. Run `make doctor` to verify your setup before starting.
|
Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root (the config file path can be overridden via `DEER_FLOW_CONFIG_PATH`). Run `make doctor` to verify your setup before starting.
|
||||||
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
|
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
|
||||||
|
|
||||||
1. **Check prerequisites**:
|
1. **Check prerequisites**:
|
||||||
@@ -263,7 +264,7 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
|
|||||||
|
|
||||||
2. **Install dependencies**:
|
2. **Install dependencies**:
|
||||||
```bash
|
```bash
|
||||||
make install # Install backend + frontend dependencies + pre-commit hooks
|
make install # Install backend + frontend dependencies
|
||||||
```
|
```
|
||||||
|
|
||||||
3. **(Optional) Pre-pull sandbox image**:
|
3. **(Optional) Pre-pull sandbox image**:
|
||||||
@@ -288,31 +289,53 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
|
|||||||
|
|
||||||
#### Startup Modes
|
#### Startup Modes
|
||||||
|
|
||||||
DeerFlow runs the agent runtime inside the Gateway API. Development mode enables hot-reload; production mode uses a pre-built frontend.
|
DeerFlow supports multiple startup modes across two dimensions:
|
||||||
|
|
||||||
|
- **Dev / Prod** — dev enables hot-reload; prod uses pre-built frontend
|
||||||
|
- **Standard / Gateway** — standard uses a separate LangGraph server (4 processes); Gateway mode (experimental) embeds the agent runtime in the Gateway API (3 processes)
|
||||||
|
|
||||||
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
||||||
|---|---|---|---|---|
|
|---|---|---|---|---|
|
||||||
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
||||||
|
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
|
||||||
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
||||||
|
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
|
||||||
|
|
||||||
| Action | Local | Docker Dev | Docker Prod |
|
| Action | Local | Docker Dev | Docker Prod |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
||||||
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
||||||
|
|
||||||
Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx.
|
> **Gateway mode** eliminates the LangGraph server process — the Gateway API handles agent execution directly via async tasks, managing its own concurrency.
|
||||||
|
|
||||||
|
#### Why Gateway Mode?
|
||||||
|
|
||||||
|
In standard mode, DeerFlow runs a dedicated [LangGraph Platform](https://langchain-ai.github.io/langgraph/) server alongside the Gateway API. This architecture works well but has trade-offs:
|
||||||
|
|
||||||
|
| | Standard Mode | Gateway Mode |
|
||||||
|
|---|---|---|
|
||||||
|
| **Architecture** | Gateway (REST API) + LangGraph (agent runtime) | Gateway embeds agent runtime |
|
||||||
|
| **Concurrency** | `--n-jobs-per-worker` per worker (requires license) | `--workers` × async tasks (no per-worker cap) |
|
||||||
|
| **Containers / Processes** | 4 (frontend, gateway, langgraph, nginx) | 3 (frontend, gateway, nginx) |
|
||||||
|
| **Resource usage** | Higher (two Python runtimes) | Lower (single Python runtime) |
|
||||||
|
| **LangGraph Platform license** | Required for production images | Not required |
|
||||||
|
| **Cold start** | Slower (two services to initialize) | Faster |
|
||||||
|
|
||||||
|
Both modes are functionally equivalent — the same agents, tools, and skills work in either mode.
|
||||||
|
|
||||||
#### Docker Production Deployment
|
#### Docker Production Deployment
|
||||||
|
|
||||||
`deploy.sh` supports building and starting separately:
|
`deploy.sh` supports building and starting separately. Images are mode-agnostic — runtime mode is selected at start time:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# One-step (build + start)
|
# One-step (build + start)
|
||||||
deploy.sh
|
deploy.sh # standard mode (default)
|
||||||
|
deploy.sh --gateway # gateway mode
|
||||||
|
|
||||||
# Two-step (build once, start later)
|
# Two-step (build once, start with any mode)
|
||||||
deploy.sh build # build all images
|
deploy.sh build # build all images
|
||||||
deploy.sh start # start pre-built images
|
deploy.sh start # start in standard mode
|
||||||
|
deploy.sh start --gateway # start in gateway mode
|
||||||
|
|
||||||
# Stop
|
# Stop
|
||||||
deploy.sh down
|
deploy.sh down
|
||||||
@@ -347,14 +370,13 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
|
|||||||
| Feishu / Lark | WebSocket | Moderate |
|
| Feishu / Lark | WebSocket | Moderate |
|
||||||
| WeChat | Tencent iLink (long-polling) | Moderate |
|
| WeChat | Tencent iLink (long-polling) | Moderate |
|
||||||
| WeCom | WebSocket | Moderate |
|
| WeCom | WebSocket | Moderate |
|
||||||
| DingTalk | Stream Push (WebSocket) | Moderate |
|
|
||||||
|
|
||||||
**Configuration in `config.yaml`:**
|
**Configuration in `config.yaml`:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api)
|
# LangGraph Server URL (default: http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL (default: http://localhost:8001)
|
# Gateway API URL (default: http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -417,19 +439,11 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # Client ID of your DingTalk application
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # Client Secret of your DingTalk application
|
|
||||||
allowed_users: [] # empty = allow all
|
|
||||||
card_template_id: "" # Optional: AI Card template ID for streaming typewriter effect
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
|
- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
|
||||||
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
|
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
|
||||||
- IM channel workers call Gateway's LangGraph-compatible API internally and automatically attach process-local internal auth plus the CSRF cookie/header pair required for thread and run creation.
|
|
||||||
|
|
||||||
Set the corresponding API keys in your `.env` file:
|
Set the corresponding API keys in your `.env` file:
|
||||||
|
|
||||||
@@ -452,10 +466,6 @@ WECHAT_ILINK_BOT_ID=your_ilink_bot_id
|
|||||||
# WeCom
|
# WeCom
|
||||||
WECOM_BOT_ID=your_bot_id
|
WECOM_BOT_ID=your_bot_id
|
||||||
WECOM_BOT_SECRET=your_bot_secret
|
WECOM_BOT_SECRET=your_bot_secret
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Telegram Setup**
|
**Telegram Setup**
|
||||||
@@ -494,15 +504,7 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
|
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
|
||||||
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
|
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
|
||||||
|
|
||||||
**DingTalk Setup**
|
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://langgraph:2024` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
|
|
||||||
1. Create a DingTalk application in the [DingTalk Developer Console](https://open.dingtalk.com/) and enable **Robot** capability.
|
|
||||||
2. Set the message receiving mode to **Stream Mode** in the robot configuration page.
|
|
||||||
3. Copy the `Client ID` and `Client Secret`, set `DINGTALK_CLIENT_ID` and `DINGTALK_CLIENT_SECRET` in `.env`, and enable the channel in `config.yaml`.
|
|
||||||
4. *(Optional)* To enable streaming AI Card replies (typewriter effect), create an **AI Card** template on the [DingTalk Card Platform](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), then set `card_template_id` in `config.yaml` to the template ID. You also need to apply for the `Card.Streaming.Write` and `Card.Instance.Write` permissions.
|
|
||||||
|
|
||||||
|
|
||||||
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://gateway:8001/api` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
|
||||||
|
|
||||||
**Commands**
|
**Commands**
|
||||||
|
|
||||||
@@ -546,15 +548,6 @@ LANGFUSE_BASE_URL=https://cloud.langfuse.com
|
|||||||
|
|
||||||
If you are using a self-hosted Langfuse instance, set `LANGFUSE_BASE_URL` to your deployment URL.
|
If you are using a self-hosted Langfuse instance, set `LANGFUSE_BASE_URL` to your deployment URL.
|
||||||
|
|
||||||
**Trace correlation fields.** Every agent run is annotated with Langfuse's reserved trace attributes so the Sessions and Users pages light up automatically:
|
|
||||||
|
|
||||||
- `session_id` = LangGraph `thread_id` — groups every trace of the same conversation
|
|
||||||
- `user_id` = effective user from `get_effective_user_id()` (falls back to `default` in no-auth mode)
|
|
||||||
- `trace_name` = assistant id (defaults to `lead-agent`)
|
|
||||||
- `tags` = `[env:<DEER_FLOW_ENV>, model:<model_name>]` (omitted when not set)
|
|
||||||
|
|
||||||
These are injected into `RunnableConfig.metadata` at the graph invocation root for both the gateway path (`runtime/runs/worker.py::run_agent`) and the embedded path (`client.py::DeerFlowClient.stream`), so any LangChain-compatible callback can read them. Set `DEER_FLOW_ENV` (or `ENVIRONMENT`) to tag traces by deployment environment.
|
|
||||||
|
|
||||||
#### Using Both Providers
|
#### Using Both Providers
|
||||||
|
|
||||||
If both LangSmith and Langfuse are enabled, DeerFlow attaches both tracing callbacks and reports the same model activity to both systems.
|
If both LangSmith and Langfuse are enabled, DeerFlow attaches both tracing callbacks and reports the same model activity to both systems.
|
||||||
@@ -637,7 +630,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
|
|||||||
|
|
||||||
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
||||||
|
|
||||||
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
|
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
|
||||||
|
|
||||||
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
||||||
|
|
||||||
@@ -665,8 +658,6 @@ This is the difference between a chatbot with tool access and an agent with an a
|
|||||||
|
|
||||||
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
|
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
|
||||||
|
|
||||||
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
|
|
||||||
|
|
||||||
### Long-Term Memory
|
### Long-Term Memory
|
||||||
|
|
||||||
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
|
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
|
||||||
|
|||||||
+3
-22
@@ -228,7 +228,7 @@ make down # Stop and remove containers
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Le runtime d'agent s'exécute actuellement dans la Gateway. nginx réécrit `/api/langgraph/*` vers l'API compatible LangGraph servie par la Gateway.
|
> Le serveur d'agents LangGraph fonctionne actuellement via `langgraph dev` (le serveur CLI open source).
|
||||||
|
|
||||||
Accès : http://localhost:2026
|
Accès : http://localhost:2026
|
||||||
|
|
||||||
@@ -290,14 +290,13 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca
|
|||||||
| Telegram | Bot API (long-polling) | Facile |
|
| Telegram | Bot API (long-polling) | Facile |
|
||||||
| Slack | Socket Mode | Modérée |
|
| Slack | Socket Mode | Modérée |
|
||||||
| Feishu / Lark | WebSocket | Modérée |
|
| Feishu / Lark | WebSocket | Modérée |
|
||||||
| DingTalk | Stream Push (WebSocket) | Modérée |
|
|
||||||
|
|
||||||
**Configuration dans `config.yaml` :**
|
**Configuration dans `config.yaml` :**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api)
|
# LangGraph Server URL (default: http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL (default: http://localhost:8001)
|
# Gateway API URL (default: http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -342,13 +341,6 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # ClientId depuis DingTalk Open Platform
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret depuis DingTalk Open Platform
|
|
||||||
allowed_users: [] # vide = tout le monde autorisé
|
|
||||||
card_template_id: "" # Optionnel : ID de modèle AI Card pour l'effet machine à écrire en streaming
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Définissez les clés API correspondantes dans votre fichier `.env` :
|
Définissez les clés API correspondantes dans votre fichier `.env` :
|
||||||
@@ -364,10 +356,6 @@ SLACK_APP_TOKEN=xapp-...
|
|||||||
# Feishu / Lark
|
# Feishu / Lark
|
||||||
FEISHU_APP_ID=cli_xxxx
|
FEISHU_APP_ID=cli_xxxx
|
||||||
FEISHU_APP_SECRET=your_app_secret
|
FEISHU_APP_SECRET=your_app_secret
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Configuration Telegram**
|
**Configuration Telegram**
|
||||||
@@ -390,13 +378,6 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
3. Dans **Events**, abonnez-vous à `im.message.receive_v1` et sélectionnez le mode **Long Connection**.
|
3. Dans **Events**, abonnez-vous à `im.message.receive_v1` et sélectionnez le mode **Long Connection**.
|
||||||
4. Copiez l'App ID et l'App Secret. Définissez `FEISHU_APP_ID` et `FEISHU_APP_SECRET` dans `.env` et activez le canal dans `config.yaml`.
|
4. Copiez l'App ID et l'App Secret. Définissez `FEISHU_APP_ID` et `FEISHU_APP_SECRET` dans `.env` et activez le canal dans `config.yaml`.
|
||||||
|
|
||||||
**Configuration DingTalk**
|
|
||||||
|
|
||||||
1. Créez une application sur [DingTalk Open Platform](https://open.dingtalk.com/) et activez la capacité **Robot**.
|
|
||||||
2. Dans la page de configuration du robot, définissez le mode de réception des messages sur **Stream**.
|
|
||||||
3. Copiez le `Client ID` et le `Client Secret`. Définissez `DINGTALK_CLIENT_ID` et `DINGTALK_CLIENT_SECRET` dans `.env` et activez le canal dans `config.yaml`.
|
|
||||||
4. *(Optionnel)* Pour activer les réponses en streaming AI Card (effet machine à écrire), créez un modèle **AI Card** sur la [plateforme de cartes DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), puis définissez `card_template_id` dans `config.yaml` avec l'ID du modèle. Vous devez également demander les permissions `Card.Streaming.Write` et `Card.Instance.Write`.
|
|
||||||
|
|
||||||
**Commandes**
|
**Commandes**
|
||||||
|
|
||||||
Une fois un canal connecté, vous pouvez interagir avec DeerFlow directement depuis le chat :
|
Une fois un canal connecté, vous pouvez interagir avec DeerFlow directement depuis le chat :
|
||||||
|
|||||||
+3
-22
@@ -181,7 +181,7 @@ make down # コンテナを停止して削除
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Agentランタイムは現在Gateway内で実行されます。`/api/langgraph/*`はnginxによってGatewayのLangGraph-compatible APIへ書き換えられます。
|
> LangGraphエージェントサーバーは現在`langgraph dev`(オープンソースCLIサーバー)経由で実行されます。
|
||||||
|
|
||||||
アクセス: http://localhost:2026
|
アクセス: http://localhost:2026
|
||||||
|
|
||||||
@@ -243,14 +243,13 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート
|
|||||||
| Telegram | Bot API(ロングポーリング) | 簡単 |
|
| Telegram | Bot API(ロングポーリング) | 簡単 |
|
||||||
| Slack | Socket Mode | 中程度 |
|
| Slack | Socket Mode | 中程度 |
|
||||||
| Feishu / Lark | WebSocket | 中程度 |
|
| Feishu / Lark | WebSocket | 中程度 |
|
||||||
| DingTalk | Stream Push(WebSocket) | 中程度 |
|
|
||||||
|
|
||||||
**`config.yaml`での設定:**
|
**`config.yaml`での設定:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL(デフォルト: http://localhost:8001/api)
|
# LangGraphサーバーURL(デフォルト: http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL(デフォルト: http://localhost:8001)
|
# Gateway API URL(デフォルト: http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -295,13 +294,6 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # DingTalk Open PlatformのClientId
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # DingTalk Open PlatformのClientSecret
|
|
||||||
allowed_users: [] # 空 = 全員許可
|
|
||||||
card_template_id: "" # オプション:ストリーミングタイプライター効果用のAIカードテンプレートID
|
|
||||||
```
|
```
|
||||||
|
|
||||||
対応するAPIキーを`.env`ファイルに設定します:
|
対応するAPIキーを`.env`ファイルに設定します:
|
||||||
@@ -317,10 +309,6 @@ SLACK_APP_TOKEN=xapp-...
|
|||||||
# Feishu / Lark
|
# Feishu / Lark
|
||||||
FEISHU_APP_ID=cli_xxxx
|
FEISHU_APP_ID=cli_xxxx
|
||||||
FEISHU_APP_SECRET=your_app_secret
|
FEISHU_APP_SECRET=your_app_secret
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Telegramのセットアップ**
|
**Telegramのセットアップ**
|
||||||
@@ -343,13 +331,6 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
3. **イベント**で`im.message.receive_v1`を購読し、**ロングコネクション**モードを選択。
|
3. **イベント**で`im.message.receive_v1`を購読し、**ロングコネクション**モードを選択。
|
||||||
4. App IDとApp Secretをコピー。`.env`に`FEISHU_APP_ID`と`FEISHU_APP_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
|
4. App IDとApp Secretをコピー。`.env`に`FEISHU_APP_ID`と`FEISHU_APP_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
|
||||||
|
|
||||||
**DingTalkのセットアップ**
|
|
||||||
|
|
||||||
1. [DingTalk Open Platform](https://open.dingtalk.com/)でアプリを作成し、**ロボット**機能を有効化します。
|
|
||||||
2. ロボット設定ページでメッセージ受信モードを**Streamモード**に設定します。
|
|
||||||
3. `Client ID`と`Client Secret`をコピー。`.env`に`DINGTALK_CLIENT_ID`と`DINGTALK_CLIENT_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
|
|
||||||
4. *(オプション)* ストリーミングAIカード返信(タイプライター効果)を有効にするには、[DingTalkカードプラットフォーム](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)で**AIカード**テンプレートを作成し、`config.yaml`の`card_template_id`にテンプレートIDを設定します。`Card.Streaming.Write` および `Card.Instance.Write` 権限の申請も必要です。
|
|
||||||
|
|
||||||
**コマンド**
|
**コマンド**
|
||||||
|
|
||||||
チャネル接続後、チャットから直接DeerFlowと対話できます:
|
チャネル接続後、チャットから直接DeerFlowと対話できます:
|
||||||
|
|||||||
@@ -256,7 +256,6 @@ DeerFlow принимает задачи прямо из мессенджеро
|
|||||||
| Telegram | Bot API (long-polling) | Просто |
|
| Telegram | Bot API (long-polling) | Просто |
|
||||||
| Slack | Socket Mode | Средне |
|
| Slack | Socket Mode | Средне |
|
||||||
| Feishu / Lark | WebSocket | Средне |
|
| Feishu / Lark | WebSocket | Средне |
|
||||||
| DingTalk | Stream Push (WebSocket) | Средне |
|
|
||||||
|
|
||||||
**Конфигурация в `config.yaml`:**
|
**Конфигурация в `config.yaml`:**
|
||||||
|
|
||||||
@@ -279,13 +278,6 @@ channels:
|
|||||||
enabled: true
|
enabled: true
|
||||||
bot_token: $TELEGRAM_BOT_TOKEN
|
bot_token: $TELEGRAM_BOT_TOKEN
|
||||||
allowed_users: []
|
allowed_users: []
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # ClientId с DingTalk Open Platform
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret с DingTalk Open Platform
|
|
||||||
allowed_users: [] # пусто = разрешить всем
|
|
||||||
card_template_id: "" # Опционально: ID шаблона AI Card для потокового эффекта печатной машинки
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Настройка Telegram**
|
**Настройка Telegram**
|
||||||
@@ -293,13 +285,6 @@ channels:
|
|||||||
1. Напишите [@BotFather](https://t.me/BotFather), отправьте `/newbot` и скопируйте HTTP API-токен.
|
1. Напишите [@BotFather](https://t.me/BotFather), отправьте `/newbot` и скопируйте HTTP API-токен.
|
||||||
2. Укажите `TELEGRAM_BOT_TOKEN` в `.env` и включите канал в `config.yaml`.
|
2. Укажите `TELEGRAM_BOT_TOKEN` в `.env` и включите канал в `config.yaml`.
|
||||||
|
|
||||||
**Настройка DingTalk**
|
|
||||||
|
|
||||||
1. Создайте приложение на [DingTalk Open Platform](https://open.dingtalk.com/) и включите возможность **Робот**.
|
|
||||||
2. На странице настроек робота установите режим приёма сообщений на **Stream**.
|
|
||||||
3. Скопируйте `Client ID` и `Client Secret`. Укажите `DINGTALK_CLIENT_ID` и `DINGTALK_CLIENT_SECRET` в `.env` и включите канал в `config.yaml`.
|
|
||||||
4. *(Опционально)* Для включения потоковых ответов AI Card (эффект печатной машинки) создайте шаблон **AI Card** на [платформе карточек DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), затем укажите `card_template_id` в `config.yaml` с ID шаблона. Также необходимо запросить разрешения `Card.Streaming.Write` и `Card.Instance.Write`.
|
|
||||||
|
|
||||||
**Доступные команды**
|
**Доступные команды**
|
||||||
|
|
||||||
| Команда | Описание |
|
| Команда | Описание |
|
||||||
|
|||||||
+4
-23
@@ -184,7 +184,7 @@ make down # 停止并移除容器
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 当前 Agent 运行时嵌入在 Gateway 中运行,`/api/langgraph/*` 会由 nginx 重写到 Gateway 的 LangGraph-compatible API。
|
> 当前 LangGraph agent server 通过开源 CLI 服务 `langgraph dev` 运行。
|
||||||
|
|
||||||
访问地址:http://localhost:2026
|
访问地址:http://localhost:2026
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ make down # 停止并移除容器
|
|||||||
|
|
||||||
如果你更希望直接在本地启动各个服务:
|
如果你更希望直接在本地启动各个服务:
|
||||||
|
|
||||||
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`。可以用 `DEER_FLOW_PROJECT_ROOT` 显式指定项目根目录,也可以用 `DEER_FLOW_CONFIG_PATH` 指向某个具体配置文件。运行期状态默认写到项目根目录下的 `.deer-flow`,可用 `DEER_FLOW_HOME` 覆盖;skills 默认读取项目根目录下的 `skills/`,可用 `DEER_FLOW_SKILLS_PATH` 覆盖。
|
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`,也可以通过 `DEER_FLOW_CONFIG_PATH` 覆盖。
|
||||||
在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。
|
在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。
|
||||||
|
|
||||||
1. **检查依赖环境**:
|
1. **检查依赖环境**:
|
||||||
@@ -248,14 +248,13 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应
|
|||||||
| Slack | Socket Mode | 中等 |
|
| Slack | Socket Mode | 中等 |
|
||||||
| Feishu / Lark | WebSocket | 中等 |
|
| Feishu / Lark | WebSocket | 中等 |
|
||||||
| 企业微信智能机器人 | WebSocket | 中等 |
|
| 企业微信智能机器人 | WebSocket | 中等 |
|
||||||
| 钉钉 | Stream Push(WebSocket) | 中等 |
|
|
||||||
|
|
||||||
**`config.yaml` 中的配置示例:**
|
**`config.yaml` 中的配置示例:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL(默认:http://localhost:8001/api)
|
# LangGraph Server URL(默认:http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL(默认:http://localhost:8001)
|
# Gateway API URL(默认:http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -305,13 +304,6 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # 钉钉开放平台 ClientId
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # 钉钉开放平台 ClientSecret
|
|
||||||
allowed_users: [] # 留空表示允许所有人
|
|
||||||
card_template_id: "" # 可选:AI 卡片模板 ID,用于流式打字机效果
|
|
||||||
```
|
```
|
||||||
|
|
||||||
说明:
|
说明:
|
||||||
@@ -335,10 +327,6 @@ FEISHU_APP_SECRET=your_app_secret
|
|||||||
# 企业微信智能机器人
|
# 企业微信智能机器人
|
||||||
WECOM_BOT_ID=your_bot_id
|
WECOM_BOT_ID=your_bot_id
|
||||||
WECOM_BOT_SECRET=your_bot_secret
|
WECOM_BOT_SECRET=your_bot_secret
|
||||||
|
|
||||||
# 钉钉
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Telegram 配置**
|
**Telegram 配置**
|
||||||
@@ -369,13 +357,6 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
|
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
|
||||||
5. 当前支持文本、图片和文件入站消息;agent 生成的最终图片/文件也会回传到企业微信会话中。
|
5. 当前支持文本、图片和文件入站消息;agent 生成的最终图片/文件也会回传到企业微信会话中。
|
||||||
|
|
||||||
**钉钉配置**
|
|
||||||
|
|
||||||
1. 在 [钉钉开放平台](https://open.dingtalk.com/) 创建应用,并启用 **机器人** 能力。
|
|
||||||
2. 在机器人配置页面设置消息接收模式为 **Stream模式**。
|
|
||||||
3. 复制 `Client ID` 和 `Client Secret`,在 `.env` 中设置 `DINGTALK_CLIENT_ID` 和 `DINGTALK_CLIENT_SECRET`,并在 `config.yaml` 中启用该渠道。
|
|
||||||
4. *(可选)* 如需开启流式 AI 卡片回复(打字机效果),请在[钉钉卡片平台](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)创建 **AI 卡片**模板,然后在 `config.yaml` 中将 `card_template_id` 设为该模板 ID。同时需要申请 `Card.Streaming.Write` 和 `Card.Instance.Write` 权限。
|
|
||||||
|
|
||||||
**命令**
|
**命令**
|
||||||
|
|
||||||
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:
|
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:
|
||||||
|
|||||||
+52
-91
@@ -7,13 +7,15 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
DeerFlow is a LangGraph-based AI super agent system with a full-stack architecture. The backend provides a "super agent" with sandbox execution, persistent memory, subagent delegation, and extensible tool integration - all operating in per-thread isolated environments.
|
DeerFlow is a LangGraph-based AI super agent system with a full-stack architecture. The backend provides a "super agent" with sandbox execution, persistent memory, subagent delegation, and extensible tool integration - all operating in per-thread isolated environments.
|
||||||
|
|
||||||
**Architecture**:
|
**Architecture**:
|
||||||
- **Gateway API** (port 8001): REST API plus embedded LangGraph-compatible agent runtime
|
- **LangGraph Server** (port 2024): Agent runtime and workflow execution
|
||||||
|
- **Gateway API** (port 8001): REST API for models, MCP, skills, memory, artifacts, uploads, and local thread cleanup
|
||||||
- **Frontend** (port 3000): Next.js web interface
|
- **Frontend** (port 3000): Next.js web interface
|
||||||
- **Nginx** (port 2026): Unified reverse proxy entry point
|
- **Nginx** (port 2026): Unified reverse proxy entry point
|
||||||
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
|
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
|
||||||
|
|
||||||
**Runtime**:
|
**Runtime Modes**:
|
||||||
- `make dev`, Docker dev, and production all run the agent runtime in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Nginx exposes that runtime at `/api/langgraph/*` and rewrites it to Gateway's native `/api/*` routers.
|
- **Standard mode** (`make dev`): LangGraph Server handles agent execution as a separate process. 4 processes total.
|
||||||
|
- **Gateway mode** (`make dev-pro`, experimental): Agent runtime embedded in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Service manages its own concurrency via async tasks. 3 processes total, no LangGraph Server.
|
||||||
|
|
||||||
**Project Structure**:
|
**Project Structure**:
|
||||||
```
|
```
|
||||||
@@ -23,7 +25,7 @@ deer-flow/
|
|||||||
├── extensions_config.json # MCP servers and skills configuration
|
├── extensions_config.json # MCP servers and skills configuration
|
||||||
├── backend/ # Backend application (this directory)
|
├── backend/ # Backend application (this directory)
|
||||||
│ ├── Makefile # Backend-only commands (dev, gateway, lint)
|
│ ├── Makefile # Backend-only commands (dev, gateway, lint)
|
||||||
│ ├── langgraph.json # LangGraph Studio graph configuration
|
│ ├── langgraph.json # LangGraph server configuration
|
||||||
│ ├── packages/
|
│ ├── packages/
|
||||||
│ │ └── harness/ # deerflow-harness package (import: deerflow.*)
|
│ │ └── harness/ # deerflow-harness package (import: deerflow.*)
|
||||||
│ │ ├── pyproject.toml
|
│ │ ├── pyproject.toml
|
||||||
@@ -81,15 +83,16 @@ When making code changes, you MUST update the relevant documentation:
|
|||||||
```bash
|
```bash
|
||||||
make check # Check system requirements
|
make check # Check system requirements
|
||||||
make install # Install all dependencies (frontend + backend)
|
make install # Install all dependencies (frontend + backend)
|
||||||
make dev # Start all services (Gateway + Frontend + Nginx), with config.yaml preflight
|
make dev # Start all services (LangGraph + Gateway + Frontend + Nginx), with config.yaml preflight
|
||||||
make start # Start production services locally
|
make dev-pro # Gateway mode (experimental): skip LangGraph, agent runtime embedded in Gateway
|
||||||
|
make start-pro # Production + Gateway mode (experimental)
|
||||||
make stop # Stop all services
|
make stop # Stop all services
|
||||||
```
|
```
|
||||||
|
|
||||||
**Backend directory** (for backend development only):
|
**Backend directory** (for backend development only):
|
||||||
```bash
|
```bash
|
||||||
make install # Install backend dependencies
|
make install # Install backend dependencies
|
||||||
make dev # Run Gateway API with reload (port 8001)
|
make dev # Run LangGraph server only (port 2024)
|
||||||
make gateway # Run Gateway API only (port 8001)
|
make gateway # Run Gateway API only (port 8001)
|
||||||
make test # Run all backend tests
|
make test # Run all backend tests
|
||||||
make lint # Lint with ruff
|
make lint # Lint with ruff
|
||||||
@@ -112,7 +115,7 @@ CI runs these regression tests for every pull request via [.github/workflows/bac
|
|||||||
The backend is split into two layers with a strict dependency direction:
|
The backend is split into two layers with a strict dependency direction:
|
||||||
|
|
||||||
- **Harness** (`packages/harness/deerflow/`): Publishable agent framework package (`deerflow-harness`). Import prefix: `deerflow.*`. Contains agent orchestration, tools, sandbox, models, MCP, skills, config — everything needed to build and run agents.
|
- **Harness** (`packages/harness/deerflow/`): Publishable agent framework package (`deerflow-harness`). Import prefix: `deerflow.*`. Contains agent orchestration, tools, sandbox, models, MCP, skills, config — everything needed to build and run agents.
|
||||||
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram, DingTalk).
|
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram).
|
||||||
|
|
||||||
**Dependency rule**: App imports deerflow, but deerflow never imports app. This boundary is enforced by `tests/test_harness_boundary.py` which runs in CI.
|
**Dependency rule**: App imports deerflow, but deerflow never imports app. This boundary is enforced by `tests/test_harness_boundary.py` which runs in CI.
|
||||||
|
|
||||||
@@ -153,26 +156,20 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
### Middleware Chain
|
### Middleware Chain
|
||||||
|
|
||||||
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
|
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||||
|
|
||||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion first removes the LangGraph thread, then has Gateway clean up the local thread directory
|
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion first removes the LangGraph thread, then has Gateway clean up the local thread directory
|
||||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
|
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||||
5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
|
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
||||||
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||||
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
|
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||||
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
||||||
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||||
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||||
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
|
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
|
||||||
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
||||||
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
|
||||||
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
|
||||||
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
|
|
||||||
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
|
|
||||||
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
|
|
||||||
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
|
||||||
|
|
||||||
### Configuration System
|
### Configuration System
|
||||||
|
|
||||||
@@ -184,18 +181,6 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
|
|||||||
|
|
||||||
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
||||||
|
|
||||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
|
|
||||||
|
|
||||||
| Field | Why a restart is required |
|
|
||||||
|---|---|
|
|
||||||
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
|
|
||||||
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
|
|
||||||
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
|
|
||||||
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
|
|
||||||
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
|
|
||||||
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
|
|
||||||
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
|
|
||||||
|
|
||||||
Configuration priority:
|
Configuration priority:
|
||||||
1. Explicit `config_path` argument
|
1. Explicit `config_path` argument
|
||||||
2. `DEER_FLOW_CONFIG_PATH` environment variable
|
2. `DEER_FLOW_CONFIG_PATH` environment variable
|
||||||
@@ -217,9 +202,7 @@ Configuration priority:
|
|||||||
|
|
||||||
### Gateway API (`app/gateway/`)
|
### Gateway API (`app/gateway/`)
|
||||||
|
|
||||||
FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWAY_ENABLE_DOCS=false` to disable `/docs`, `/redoc`, and `/openapi.json` in production (default: enabled).
|
FastAPI application on port 8001 with health check at `GET /health`.
|
||||||
|
|
||||||
CORS is same-origin by default when requests enter through nginx on port 2026. Split-origin or port-forwarded browser clients must opt in with `GATEWAY_CORS_ORIGINS` (comma-separated exact origins); Gateway `CORSMiddleware` and `CSRFMiddleware` both read that variable so browser CORS and auth-origin checks stay aligned.
|
|
||||||
|
|
||||||
**Routers**:
|
**Routers**:
|
||||||
|
|
||||||
@@ -237,33 +220,27 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
|
|||||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||||
|
|
||||||
**RunManager / RunStore contract**:
|
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||||
- `RunManager.get()` is async; direct callers must `await` it.
|
|
||||||
- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
|
|
||||||
- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
|
|
||||||
- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
|
|
||||||
|
|
||||||
Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.
|
|
||||||
|
|
||||||
### Sandbox System (`packages/harness/deerflow/sandbox/`)
|
### Sandbox System (`packages/harness/deerflow/sandbox/`)
|
||||||
|
|
||||||
**Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir`
|
**Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle
|
||||||
**Implementations**:
|
**Implementations**:
|
||||||
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
- `LocalSandboxProvider` - Singleton local filesystem execution with path mappings
|
||||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
||||||
|
|
||||||
**Virtual Path System**:
|
**Virtual Path System**:
|
||||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||||
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||||
- Translation: `LocalSandboxProvider` builds per-thread `PathMapping`s for the user-data prefixes at acquire time; `tools.py` keeps `replace_virtual_path()` / `replace_virtual_paths_in_command()` as a defense-in-depth layer (and for path validation). AIO has the directories volume-mounted at the same virtual paths inside its container, so both implementations accept `/mnt/user-data/...` natively.
|
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||||
- Detection: `is_local_sandbox()` accepts both `sandbox_id == "local"` (legacy / no-thread) and `sandbox_id.startswith("local:")` (per-thread)
|
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||||
|
|
||||||
**Sandbox Tools** (in `packages/harness/deerflow/sandbox/tools.py`):
|
**Sandbox Tools** (in `packages/harness/deerflow/sandbox/tools.py`):
|
||||||
- `bash` - Execute commands with path translation and error handling
|
- `bash` - Execute commands with path translation and error handling
|
||||||
- `ls` - Directory listing (tree format, max 2 levels)
|
- `ls` - Directory listing (tree format, max 2 levels)
|
||||||
- `read_file` - Read file contents with optional line range
|
- `read_file` - Read file contents with optional line range
|
||||||
- `write_file` - Write/append to files, creates directories; overwrites by default and exposes the `append` argument in the model-facing schema for end-of-file writes
|
- `write_file` - Write/append to files, creates directories
|
||||||
- `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process
|
- `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process
|
||||||
|
|
||||||
### Subagent System (`packages/harness/deerflow/subagents/`)
|
### Subagent System (`packages/harness/deerflow/subagents/`)
|
||||||
@@ -283,10 +260,8 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
|
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
|
||||||
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
|
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
|
||||||
- `view_image` - Read image as base64 (added only if model supports vision)
|
- `view_image` - Read image as base64 (added only if model supports vision)
|
||||||
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
|
|
||||||
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
|
|
||||||
4. **Subagent tool** (if enabled):
|
4. **Subagent tool** (if enabled):
|
||||||
- `task` - Delegate to subagent (description, prompt, subagent_type)
|
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
|
||||||
|
|
||||||
**Community tools** (`packages/harness/deerflow/community/`):
|
**Community tools** (`packages/harness/deerflow/community/`):
|
||||||
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
|
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
|
||||||
@@ -334,10 +309,9 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
|
|
||||||
### IM Channels System (`app/channels/`)
|
### IM Channels System (`app/channels/`)
|
||||||
|
|
||||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via the LangGraph Server.
|
||||||
|
|
||||||
|
**Architecture**: Channels communicate with the LangGraph Server through `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side.
|
||||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
|
||||||
|
|
||||||
**Components**:
|
**Components**:
|
||||||
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
||||||
@@ -345,25 +319,23 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
||||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||||
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
- `slack.py` / `feishu.py` / `telegram.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place)
|
||||||
|
|
||||||
**Message Flow**:
|
**Message Flow**:
|
||||||
1. External platform → Channel impl → `MessageBus.publish_inbound()`
|
1. External platform → Channel impl → `MessageBus.publish_inbound()`
|
||||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||||
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
3. For chat: look up/create thread on LangGraph Server
|
||||||
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||||
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||||
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||||
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
7. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||||
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
8. Outbound → channel callbacks → platform reply
|
||||||
9. Outbound → channel callbacks → platform reply
|
|
||||||
|
|
||||||
**Configuration** (`config.yaml` -> `channels`):
|
**Configuration** (`config.yaml` -> `channels`):
|
||||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
- `langgraph_url` - LangGraph Server URL (default: `http://localhost:2024`)
|
||||||
- `gateway_url` - Gateway API URL for auxiliary commands (default: `http://localhost:8001`)
|
- `gateway_url` - Gateway API URL for auxiliary commands (default: `http://localhost:8001`)
|
||||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://langgraph:2024` / `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token)
|
||||||
|
|
||||||
|
|
||||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||||
|
|
||||||
@@ -376,11 +348,10 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
**Per-User Isolation**:
|
**Per-User Isolation**:
|
||||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||||
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
|
|
||||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||||
- Absolute `storage_path` in config opts out of per-user isolation
|
- Absolute `storage_path` in config opts out of per-user isolation
|
||||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
|
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||||
|
|
||||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||||
@@ -409,24 +380,6 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
- `resolve_variable(path)` - Import module and return variable (e.g., `module.path:variable_name`)
|
- `resolve_variable(path)` - Import module and return variable (e.g., `module.path:variable_name`)
|
||||||
- `resolve_class(path, base_class)` - Import and validate class against base class
|
- `resolve_class(path, base_class)` - Import and validate class against base class
|
||||||
|
|
||||||
### Tracing System (`packages/harness/deerflow/tracing/`)
|
|
||||||
|
|
||||||
LangSmith and Langfuse are both supported. The wiring lives in two layers:
|
|
||||||
|
|
||||||
- `factory.py::build_tracing_callbacks()` — returns the LangChain `CallbackHandler` list for the providers currently enabled via env vars (`LANGSMITH_TRACING`, `LANGFUSE_TRACING`, etc.). The handlers are attached at the **graph invocation root** for in-graph runs (`make_lead_agent` and `DeerFlowClient.stream` both append them to `config["callbacks"]` before invoking the graph) so a single run produces one trace with all node / LLM / tool calls as child spans. Standalone callers — anything that invokes a model outside such a graph (e.g. `MemoryUpdater`) — keep `create_chat_model`'s default `attach_tracing=True`, which falls back to model-level callback attachment.
|
|
||||||
- `metadata.py::build_langfuse_trace_metadata()` — builds the Langfuse-reserved trace attributes for `RunnableConfig.metadata`. The Langfuse v4 `langchain.CallbackHandler` lifts these onto the root trace (see its `_parse_langfuse_trace_attributes`), but only when it sees `on_chain_start(parent_run_id=None)` — which is why the callbacks have to live at the graph root, not the model.
|
|
||||||
|
|
||||||
**Trace-attribute injection points**: both `runtime/runs/worker.py::run_agent` (gateway path) and `client.py::DeerFlowClient.stream` (embedded path) merge the metadata into `config["metadata"]` right before constructing the graph. Caller-supplied keys win via `setdefault`, so an external `session_id` override is preserved. Field mapping:
|
|
||||||
|
|
||||||
| Langfuse field | Source |
|
|
||||||
|-----------------------|----------------------------------------------|
|
|
||||||
| `langfuse_session_id` | LangGraph `thread_id` |
|
|
||||||
| `langfuse_user_id` | `get_effective_user_id()` (`default` in no-auth) |
|
|
||||||
| `langfuse_trace_name` | `RunRecord.assistant_id` / client `agent_name` (defaults to `lead-agent`) |
|
|
||||||
| `langfuse_tags` | `env:<DEER_FLOW_ENV>` + `model:<model_name>` |
|
|
||||||
|
|
||||||
Returns `{}` when Langfuse is not in the enabled providers — LangSmith-only deployments are unaffected. Set `DEER_FLOW_ENV` (or `ENVIRONMENT`) to tag traces by deployment environment. Tests live in `tests/test_tracing_factory.py`, `tests/test_tracing_metadata.py`, `tests/test_worker_langfuse_metadata.py`, and `tests/test_client_langfuse_metadata.py`.
|
|
||||||
|
|
||||||
### Config Schema
|
### Config Schema
|
||||||
|
|
||||||
**`config.yaml`** key sections:
|
**`config.yaml`** key sections:
|
||||||
@@ -451,9 +404,9 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
|
|||||||
|
|
||||||
`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes.
|
`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes.
|
||||||
|
|
||||||
**Architecture**: Imports the same `deerflow` modules that Gateway API uses. Shares the same config files and data directories. No FastAPI dependency.
|
**Architecture**: Imports the same `deerflow` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
|
||||||
|
|
||||||
**Agent Conversation**:
|
**Agent Conversation** (replaces LangGraph Server):
|
||||||
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
|
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
|
||||||
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
|
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
|
||||||
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
|
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
|
||||||
@@ -516,15 +469,20 @@ This starts all services and makes the application available at `http://localhos
|
|||||||
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
||||||
|---|---|---|---|---|
|
|---|---|---|---|---|
|
||||||
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
||||||
|
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
|
||||||
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
||||||
|
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
|
||||||
|
|
||||||
| Action | Local | Docker Dev | Docker Prod |
|
| Action | Local | Docker Dev | Docker Prod |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
||||||
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
||||||
|
|
||||||
|
Gateway mode embeds the agent runtime in Gateway, so no LangGraph server is started.
|
||||||
|
|
||||||
**Nginx routing**:
|
**Nginx routing**:
|
||||||
- `/api/langgraph/*` → Gateway embedded runtime (8001), rewritten to `/api/*`
|
- Standard mode: `/api/langgraph/*` → LangGraph Server (2024)
|
||||||
|
- Gateway mode: `/api/langgraph/*` → Gateway embedded runtime (8001) (via envsubst)
|
||||||
- `/api/*` (other) → Gateway API (8001)
|
- `/api/*` (other) → Gateway API (8001)
|
||||||
- `/` (non-API) → Frontend (3000)
|
- `/` (non-API) → Frontend (3000)
|
||||||
|
|
||||||
@@ -533,11 +491,15 @@ This starts all services and makes the application available at `http://localhos
|
|||||||
From the **backend** directory:
|
From the **backend** directory:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Gateway API
|
# Terminal 1: LangGraph server
|
||||||
|
make dev
|
||||||
|
|
||||||
|
# Terminal 2: Gateway API
|
||||||
make gateway
|
make gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
Direct access (without nginx):
|
Direct access (without nginx):
|
||||||
|
- LangGraph: `http://localhost:2024`
|
||||||
- Gateway: `http://localhost:8001`
|
- Gateway: `http://localhost:8001`
|
||||||
|
|
||||||
### Frontend Configuration
|
### Frontend Configuration
|
||||||
@@ -558,7 +520,6 @@ Multi-file upload with automatic document conversion:
|
|||||||
- Rejects directory inputs before copying so uploads stay all-or-nothing
|
- Rejects directory inputs before copying so uploads stay all-or-nothing
|
||||||
- Reuses one conversion worker per request when called from an active event loop
|
- Reuses one conversion worker per request when called from an active event loop
|
||||||
- Files stored in thread-isolated directories
|
- Files stored in thread-isolated directories
|
||||||
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
|
|
||||||
- Agent receives uploaded file list via `UploadsMiddleware`
|
- Agent receives uploaded file list via `UploadsMiddleware`
|
||||||
|
|
||||||
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
|
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
|
||||||
|
|||||||
@@ -56,8 +56,11 @@ export OPENAI_API_KEY="your-api-key"
|
|||||||
### Run the Development Server
|
### Run the Development Server
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Gateway API + embedded agent runtime
|
# Terminal 1: LangGraph server
|
||||||
make dev
|
make dev
|
||||||
|
|
||||||
|
# Terminal 2: Gateway API
|
||||||
|
make gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|||||||
@@ -50,12 +50,6 @@ COPY backend ./backend
|
|||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
||||||
|
|
||||||
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
|
|
||||||
# containers where locale configuration may be missing and the default encoding is not UTF-8.
|
|
||||||
ENV LANG=C.UTF-8
|
|
||||||
ENV LC_ALL=C.UTF-8
|
|
||||||
ENV PYTHONIOENCODING=utf-8
|
|
||||||
|
|
||||||
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
||||||
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
||||||
# source distributions in development containers.
|
# source distributions in development containers.
|
||||||
@@ -72,10 +66,6 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
|
|||||||
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
||||||
FROM python:3.12-slim-bookworm
|
FROM python:3.12-slim-bookworm
|
||||||
|
|
||||||
ENV LANG=C.UTF-8
|
|
||||||
ENV LC_ALL=C.UTF-8
|
|
||||||
ENV PYTHONIOENCODING=utf-8
|
|
||||||
|
|
||||||
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
||||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||||
|
|||||||
+3
-3
@@ -2,13 +2,13 @@ install:
|
|||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
dev:
|
dev:
|
||||||
PYTHONPATH=. PYTHONIOENCODING=utf-8 PYTHONUTF8=1 uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload
|
uv run langgraph dev --no-browser --no-reload --n-jobs-per-worker 10
|
||||||
|
|
||||||
gateway:
|
gateway:
|
||||||
PYTHONPATH=. PYTHONIOENCODING=utf-8 PYTHONUTF8=1 uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||||
|
|
||||||
test:
|
test:
|
||||||
PYTHONPATH=. PYTHONIOENCODING=utf-8 PYTHONUTF8=1 uv run pytest tests/ -v
|
PYTHONPATH=. uv run pytest tests/unittest -v
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
uvx ruff check .
|
uvx ruff check .
|
||||||
|
|||||||
+34
-30
@@ -11,26 +11,31 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent
|
|||||||
│ Nginx (Port 2026) │
|
│ Nginx (Port 2026) │
|
||||||
│ Unified reverse proxy │
|
│ Unified reverse proxy │
|
||||||
└───────┬──────────────────┬───────────┘
|
└───────┬──────────────────┬───────────┘
|
||||||
│
|
│ │
|
||||||
/api/langgraph/* │ /api/* (other)
|
/api/langgraph/* │ │ /api/* (other)
|
||||||
rewritten to /api/* │
|
▼ ▼
|
||||||
▼
|
┌────────────────────┐ ┌────────────────────────┐
|
||||||
┌────────────────────────────────────────┐
|
│ LangGraph Server │ │ Gateway API (8001) │
|
||||||
│ Gateway API (8001) │
|
│ (Port 2024) │ │ FastAPI REST │
|
||||||
│ FastAPI REST + agent runtime │
|
│ │ │ │
|
||||||
│ │
|
│ ┌────────────────┐ │ │ Models, MCP, Skills, │
|
||||||
│ Models, MCP, Skills, Memory, Uploads, │
|
│ │ Lead Agent │ │ │ Memory, Uploads, │
|
||||||
│ Artifacts, Threads, Runs, Streaming │
|
│ │ ┌──────────┐ │ │ │ Artifacts │
|
||||||
│ │
|
│ │ │Middleware│ │ │ └────────────────────────┘
|
||||||
│ ┌────────────────────────────────────┐ │
|
│ │ │ Chain │ │ │
|
||||||
│ │ Lead Agent │ │
|
│ │ └──────────┘ │ │
|
||||||
│ │ Middleware Chain, Tools, Subagents │ │
|
│ │ ┌──────────┐ │ │
|
||||||
│ └────────────────────────────────────┘ │
|
│ │ │ Tools │ │ │
|
||||||
└────────────────────────────────────────┘
|
│ │ └──────────┘ │ │
|
||||||
|
│ │ ┌──────────┐ │ │
|
||||||
|
│ │ │Subagents │ │ │
|
||||||
|
│ │ └──────────┘ │ │
|
||||||
|
│ └────────────────┘ │
|
||||||
|
└────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
**Request Routing** (via Nginx):
|
**Request Routing** (via Nginx):
|
||||||
- `/api/langgraph/*` → Gateway LangGraph-compatible API - agent interactions, threads, streaming
|
- `/api/langgraph/*` → LangGraph Server - agent interactions, threads, streaming
|
||||||
- `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup
|
- `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup
|
||||||
- `/` (non-API) → Frontend - Next.js web interface
|
- `/` (non-API) → Frontend - Next.js web interface
|
||||||
|
|
||||||
@@ -69,12 +74,12 @@ Middlewares execute in strict order, each handling a specific concern:
|
|||||||
Per-thread isolated execution with virtual path translation:
|
Per-thread isolated execution with virtual path translation:
|
||||||
|
|
||||||
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
|
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/)
|
||||||
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
||||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||||
- **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match
|
- **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match
|
||||||
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`write_file` overwrites by default and exposes `append` for end-of-file writes; `bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
|
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
|
||||||
|
|
||||||
### Subagent System
|
### Subagent System
|
||||||
|
|
||||||
@@ -119,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
|
|||||||
| `POST /api/memory/reload` | Force memory reload |
|
| `POST /api/memory/reload` | Force memory reload |
|
||||||
| `GET /api/memory/config` | Memory configuration |
|
| `GET /api/memory/config` | Memory configuration |
|
||||||
| `GET /api/memory/status` | Combined config + data |
|
| `GET /api/memory/status` | Combined config + data |
|
||||||
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
|
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
|
||||||
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
|
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
|
||||||
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||||
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
|
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
|
||||||
@@ -188,7 +193,7 @@ export OPENAI_API_KEY="your-api-key-here"
|
|||||||
**Full Application** (from project root):
|
**Full Application** (from project root):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make dev # Starts Gateway + Frontend + Nginx
|
make dev # Starts LangGraph + Gateway + Frontend + Nginx
|
||||||
```
|
```
|
||||||
|
|
||||||
Access at: http://localhost:2026
|
Access at: http://localhost:2026
|
||||||
@@ -196,11 +201,14 @@ Access at: http://localhost:2026
|
|||||||
**Backend Only** (from backend directory):
|
**Backend Only** (from backend directory):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Gateway API + embedded agent runtime
|
# Terminal 1: LangGraph server
|
||||||
make dev
|
make dev
|
||||||
|
|
||||||
|
# Terminal 2: Gateway API
|
||||||
|
make gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
Direct access: Gateway at http://localhost:8001
|
Direct access: LangGraph at http://localhost:2024, Gateway at http://localhost:8001
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -236,16 +244,12 @@ backend/
|
|||||||
│ └── utils/ # Utilities
|
│ └── utils/ # Utilities
|
||||||
├── docs/ # Documentation
|
├── docs/ # Documentation
|
||||||
├── tests/ # Test suite
|
├── tests/ # Test suite
|
||||||
├── langgraph.json # LangGraph graph registry for tooling/Studio compatibility
|
├── langgraph.json # LangGraph server configuration
|
||||||
├── pyproject.toml # Python dependencies
|
├── pyproject.toml # Python dependencies
|
||||||
├── Makefile # Development commands
|
├── Makefile # Development commands
|
||||||
└── Dockerfile # Container build
|
└── Dockerfile # Container build
|
||||||
```
|
```
|
||||||
|
|
||||||
`langgraph.json` is not the default service entrypoint. The scripts and Docker
|
|
||||||
deployments run the Gateway embedded runtime; the file is kept for LangGraph
|
|
||||||
tooling, Studio, or direct LangGraph Server compatibility.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
@@ -358,8 +362,8 @@ If a provider is explicitly enabled but required credentials are missing, or the
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
make install # Install dependencies
|
make install # Install dependencies
|
||||||
make dev # Run Gateway API + embedded agent runtime (port 8001)
|
make dev # Run LangGraph server (port 2024)
|
||||||
make gateway # Run Gateway API without reload (port 8001)
|
make gateway # Run Gateway API (port 8001)
|
||||||
make lint # Run linter (ruff)
|
make lint # Run linter (ruff)
|
||||||
make format # Format code (ruff)
|
make format # Format code (ruff)
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Provides a pluggable channel system that connects external messaging platforms
|
Provides a pluggable channel system that connects external messaging platforms
|
||||||
(Feishu/Lark, Slack, Telegram) to the DeerFlow agent via the ChannelManager,
|
(Feishu/Lark, Slack, Telegram) to the DeerFlow agent via the ChannelManager,
|
||||||
which uses ``langgraph-sdk`` to communicate with Gateway's LangGraph-compatible API.
|
which uses ``langgraph-sdk`` to communicate with the underlying LangGraph Server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
|||||||
@@ -31,10 +31,6 @@ class Channel(ABC):
|
|||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
return self._running
|
return self._running
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# -- lifecycle ---------------------------------------------------------
|
# -- lifecycle ---------------------------------------------------------
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -1,740 +0,0 @@
|
|||||||
"""DingTalk channel implementation."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from app.channels.base import Channel
|
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DINGTALK_API_BASE = "https://api.dingtalk.com"
|
|
||||||
|
|
||||||
_TOKEN_REFRESH_MARGIN_SECONDS = 300
|
|
||||||
|
|
||||||
_CONVERSATION_TYPE_P2P = "1"
|
|
||||||
_CONVERSATION_TYPE_GROUP = "2"
|
|
||||||
|
|
||||||
_MAX_UPLOAD_SIZE_BYTES = 20 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_conversation_type(raw: Any) -> str:
    """Map a raw ``conversationType`` value onto the P2P or group marker.

    The value is stringified and stripped before comparison, so an int ``2``
    and a padded string ``" 2 "`` both count as group. ``None`` and every
    other value fall back to the P2P marker.
    """
    is_group = raw is not None and str(raw).strip() == _CONVERSATION_TYPE_GROUP
    return _CONVERSATION_TYPE_GROUP if is_group else _CONVERSATION_TYPE_P2P
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|
||||||
if allowed_users is None:
|
|
||||||
return set()
|
|
||||||
if isinstance(allowed_users, str):
|
|
||||||
values = [allowed_users]
|
|
||||||
elif isinstance(allowed_users, (list, tuple, set)):
|
|
||||||
values = allowed_users
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"DingTalk allowed_users should be a list of user IDs; treating %s as one string value",
|
|
||||||
type(allowed_users).__name__,
|
|
||||||
)
|
|
||||||
values = [allowed_users]
|
|
||||||
return {str(uid) for uid in values if str(uid)}
|
|
||||||
|
|
||||||
|
|
||||||
def _is_dingtalk_command(text: str) -> bool:
|
|
||||||
if not text.startswith("/"):
|
|
||||||
return False
|
|
||||||
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_rich_text(rich_text_list: list) -> str:
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in rich_text_list:
|
|
||||||
if isinstance(item, dict) and "text" in item:
|
|
||||||
parts.append(item["text"])
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
_FENCED_CODE_BLOCK_RE = re.compile(r"```(\w*)\n(.*?)```", re.DOTALL)
|
|
||||||
_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
|
|
||||||
_HORIZONTAL_RULE_RE = re.compile(r"^-{3,}$", re.MULTILINE)
|
|
||||||
_TABLE_SEPARATOR_RE = re.compile(r"^\|[-:| ]+\|$", re.MULTILINE)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_markdown_table(text: str) -> str:
|
|
||||||
# DingTalk sampleMarkdown does not render pipe-delimited tables.
|
|
||||||
lines = text.split("\n")
|
|
||||||
result: list[str] = []
|
|
||||||
i = 0
|
|
||||||
while i < len(lines):
|
|
||||||
line = lines[i]
|
|
||||||
# Detect table: header row followed by separator row
|
|
||||||
if i + 1 < len(lines) and line.strip().startswith("|") and _TABLE_SEPARATOR_RE.match(lines[i + 1].strip()):
|
|
||||||
headers = [h.strip() for h in line.strip().strip("|").split("|")]
|
|
||||||
i += 2 # skip header + separator
|
|
||||||
while i < len(lines) and lines[i].strip().startswith("|"):
|
|
||||||
cells = [c.strip() for c in lines[i].strip().strip("|").split("|")]
|
|
||||||
for h, c in zip(headers, cells):
|
|
||||||
result.append(f"> **{h}**: {c}")
|
|
||||||
result.append("")
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
result.append(line)
|
|
||||||
i += 1
|
|
||||||
return "\n".join(result)
|
|
||||||
|
|
||||||
|
|
||||||
def _adapt_markdown_for_dingtalk(text: str) -> str:
    """Rewrite markdown constructs for DingTalk's limited sampleMarkdown renderer.

    Applied in this order: fenced code blocks become block quotes (with the
    language tag, if any, as a bold first line), inline code becomes bold,
    pipe tables are flattened by ``_convert_markdown_table``, and dash-only
    rule lines are replaced by a line of box-drawing characters.
    """

    def _quote_fenced_block(found: re.Match) -> str:
        language = found.group(1)
        body = found.group(2).rstrip("\n")
        quoted = "\n".join(f"> {row}" for row in body.split("\n"))
        if language:
            return f"> **{language}**\n{quoted}\n"
        return f"{quoted}\n"

    without_fences = _FENCED_CODE_BLOCK_RE.sub(_quote_fenced_block, text)
    without_inline_code = _INLINE_CODE_RE.sub(r"**\1**", without_fences)
    without_tables = _convert_markdown_table(without_inline_code)
    return _HORIZONTAL_RULE_RE.sub("───────────", without_tables)
|
|
||||||
|
|
||||||
|
|
||||||
class DingTalkChannel(Channel):
|
|
||||||
"""DingTalk IM channel using Stream Push (WebSocket, no public IP needed)."""
|
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
    """Initialise the channel's state from ``config``.

    Reads ``allowed_users`` and ``card_template_id`` from ``config``; the
    credentials and client handles start out empty.
    """
    super().__init__(name="dingtalk", bus=bus, config=config)

    # Settings taken from the channel config.
    self._allowed_users: set[str] = _normalize_allowed_users(config.get("allowed_users"))
    self._card_template_id: str = config.get("card_template_id", "")

    # Credentials; empty until start() copies them in from the config.
    self._client_id: str = ""
    self._client_secret: str = ""

    # Cached access token plus the lock guarding it.
    self._cached_token: str = ""
    self._token_expires_at: float = 0.0
    self._token_lock = asyncio.Lock()

    # Background thread and the event loop recorded by start().
    self._thread: threading.Thread | None = None
    self._main_loop: asyncio.AbstractEventLoop | None = None

    # Client handles; both are reset to None by stop().
    self._dingtalk_client: Any = None
    self._stream_client: Any = None

    # Inbound-message bookkeeping; stop() clears it under the lock.
    self._incoming_messages: dict[str, Any] = {}
    self._incoming_messages_lock = threading.Lock()

    # Card bookkeeping used by send().
    self._card_track_ids: dict[str, str] = {}
    self._card_repliers: dict[str, Any] = {}
|
|
||||||
|
|
||||||
@property
def supports_streaming(self) -> bool:
    """Whether streaming is enabled, i.e. a non-empty ``card_template_id`` is set."""
    return True if self._card_template_id else False
|
|
||||||
|
|
||||||
async def start(self) -> None:
    """Start the channel's background stream thread.

    Does nothing if the channel is already running. Logs an error and returns
    without starting when the ``dingtalk_stream`` package cannot be imported
    or when ``client_id`` / ``client_secret`` are missing from the config.
    Otherwise records the credentials and the running event loop, subscribes
    to outbound messages, and launches ``_run_stream`` in a daemon thread.
    """
    if self._running:
        return

    # Availability probe only; the module itself is not used in this method.
    try:
        import dingtalk_stream  # noqa: F401
    except ImportError:
        logger.error("dingtalk-stream is not installed. Install it with: uv add dingtalk-stream")
        return

    app_key = self.config.get("client_id", "")
    app_secret = self.config.get("client_secret", "")
    if not (app_key and app_secret):
        logger.error("DingTalk channel requires client_id and client_secret")
        return

    self._client_id, self._client_secret = app_key, app_secret
    self._main_loop = asyncio.get_running_loop()

    if self._card_template_id:
        logger.info("[DingTalk] AI Card mode enabled (template=%s)", self._card_template_id)

    self._running = True
    self.bus.subscribe_outbound(self._on_outbound)

    worker = threading.Thread(target=self._run_stream, args=(app_key, app_secret), daemon=True)
    self._thread = worker
    worker.start()
    logger.info("DingTalk channel started")
|
|
||||||
|
|
||||||
async def stop(self) -> None:
    """Stop the channel and release its per-run state.

    Marks the channel as not running, unsubscribes from outbound messages,
    makes a best-effort ``disconnect()`` call on the stream client, clears
    the message and card bookkeeping, and waits up to five seconds for the
    background thread to exit.
    """
    self._running = False
    self.bus.unsubscribe_outbound(self._on_outbound)

    stream_client = self._stream_client
    if stream_client is not None:
        # Best effort: a failing disconnect must not abort the rest of shutdown.
        try:
            if hasattr(stream_client, "disconnect"):
                stream_client.disconnect()
        except Exception:
            logger.debug("[DingTalk] error disconnecting stream client", exc_info=True)

    self._dingtalk_client = None
    self._stream_client = None
    with self._incoming_messages_lock:
        self._incoming_messages.clear()
    self._card_repliers.clear()
    self._card_track_ids.clear()

    # Detach the thread reference before awaiting so a start() that runs
    # while we wait cannot have its new thread overwritten with None.
    thread, self._thread = self._thread, None
    if thread is not None:
        # Thread.join() blocks; run it off-loop so the event loop stays
        # responsive for the (up to) five seconds the wait may take.
        await asyncio.to_thread(thread.join, timeout=5)
    logger.info("DingTalk channel stopped")
|
|
||||||
|
|
||||||
def _resolve_routing(self, msg: OutboundMessage) -> tuple[str, str, str]:
    """Return ``(conversation_type, sender_staff_id, conversation_id)`` for ``msg``.

    ``msg.chat_id`` is the primary routing key: it supplies the conversation
    id for group chats and the sender staff id otherwise. The matching
    metadata entry is used only when ``chat_id`` is empty; the other id
    always comes from metadata (default ``""``).
    """
    meta = msg.metadata
    conversation_type = _normalize_conversation_type(meta.get("conversation_type"))

    if conversation_type == _CONVERSATION_TYPE_GROUP:
        return (
            conversation_type,
            meta.get("sender_staff_id", ""),
            msg.chat_id or meta.get("conversation_id", ""),
        )
    return (
        conversation_type,
        msg.chat_id or meta.get("sender_staff_id", ""),
        meta.get("conversation_id", ""),
    )
|
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
    """Deliver an outbound message through an AI card or as sampleMarkdown.

    When a card track id is registered for the message's source key, the text
    is streamed into that card; if streaming fails, the text is sent once via
    the markdown fallback instead. Without a card, the message is sent as
    sampleMarkdown with exponential back-off between attempts.

    Args:
        msg: The outbound message to deliver.
        _max_retries: Number of sampleMarkdown attempts before giving up.

    Raises:
        Exception: The last error raised by the sampleMarkdown attempts, or
            whatever the markdown fallback raises in card mode.
        RuntimeError: If no attempt was made at all (``_max_retries`` < 1).
    """
    conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
    robot_code = self._client_id

    # Card mode: stream update to existing AI card
    source_key = self._make_card_source_key_from_outbound(msg)
    out_track_id = self._card_track_ids.get(source_key)

    # ``card_template_id`` enables ``runs.stream`` (non-final + final outbounds).
    # If card creation failed, skip non-final chunks to avoid duplicate messages.
    if self._card_template_id and not out_track_id and not msg.is_final:
        return

    if out_track_id:
        try:
            await self._stream_update_card(
                out_track_id,
                msg.text,
                is_finalize=msg.is_final,
            )
        except Exception:
            # Keep the traceback: without it the reason for the fallback is lost.
            logger.warning("[DingTalk] card stream failed, falling back to sampleMarkdown", exc_info=True)
            streamed = False
        else:
            streamed = True

        # The final outbound ends the card's life regardless of the outcome.
        if msg.is_final:
            self._card_track_ids.pop(source_key, None)
            self._card_repliers.pop(out_track_id, None)
        if not streamed:
            await self._send_markdown_fallback(robot_code, conversation_type, sender_staff_id, conversation_id, msg.text)
        return

    # Non-card mode: send sampleMarkdown with retry
    last_exc: Exception | None = None
    for attempt in range(_max_retries):
        try:
            if conversation_type == _CONVERSATION_TYPE_GROUP:
                await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
            else:
                await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
            return
        except Exception as exc:
            last_exc = exc
            if attempt < _max_retries - 1:
                delay = 2**attempt  # 1s, 2s, 4s, ...
                logger.warning(
                    "[DingTalk] send failed (attempt %d/%d), retrying in %ds: %s",
                    attempt + 1,
                    _max_retries,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)

    logger.error("[DingTalk] send failed after %d attempts: %s", _max_retries, last_exc)
    if last_exc is None:
        raise RuntimeError("DingTalk send failed without an exception from any attempt")
    raise last_exc
|
|
||||||
|
|
||||||
async def _send_markdown_fallback(
    self,
    robot_code: str,
    conversation_type: str,
    sender_staff_id: str,
    conversation_id: str,
    text: str,
) -> None:
    """Send ``text`` as a single markdown message to the group or the user.

    Any failure is logged with its traceback and then re-raised to the caller.
    """
    to_group = conversation_type == _CONVERSATION_TYPE_GROUP
    try:
        if to_group:
            await self._send_group_message(robot_code, conversation_id, text)
            return
        await self._send_p2p_message(robot_code, sender_staff_id, text)
    except Exception:
        logger.exception("[DingTalk] markdown fallback also failed")
        raise
|
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
    """Upload an attachment and post it to the message's conversation.

    Images are sent as ``sampleImageMsg``, everything else as ``sampleFile``.

    Returns:
        ``True`` once the message request succeeded; ``False`` when the file
        is larger than ``_MAX_UPLOAD_SIZE_BYTES``, the upload produced no
        media id, or one of the handled errors occurred (it is logged).
    """
    if attachment.size > _MAX_UPLOAD_SIZE_BYTES:
        logger.warning("[DingTalk] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
        return False

    conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
    robot_code = self._client_id

    try:
        media_kind = "image" if attachment.is_image else "file"
        media_id = await self._upload_media(attachment.actual_path, media_kind)
        if not media_id:
            return False

        if attachment.is_image:
            msg_key = "sampleImageMsg"
            params = {"photoURL": media_id}
        else:
            msg_key = "sampleFile"
            params = {
                "fileUrl": media_id,
                "fileName": attachment.filename,
                "fileSize": str(attachment.size),
            }

        # Shared part of the request body; the recipient field is added below.
        body = {
            "msgKey": msg_key,
            "msgParam": json.dumps(params),
            "robotCode": robot_code,
        }
        if conversation_type == _CONVERSATION_TYPE_GROUP:
            endpoint = f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send"
            body["openConversationId"] = conversation_id
        else:
            endpoint = f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend"
            body["userIds"] = [sender_staff_id]

        token = await self._get_access_token()
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
            response = await client.post(endpoint, headers=self._api_headers(token), json=body)
            response.raise_for_status()

        logger.info("[DingTalk] file sent: %s", attachment.filename)
        return True
    except (httpx.HTTPError, OSError, ValueError, TypeError, AttributeError):
        logger.exception("[DingTalk] failed to send file: %s", attachment.filename)
        return False
|
|
||||||
|
|
||||||
# -- stream client (runs in dedicated thread) --------------------------
|
|
||||||
|
|
||||||
def _run_stream(self, client_id: str, client_secret: str) -> None:
    """Thread target: build the stream client, register the handler, run forever.

    Errors are logged only while the channel is still marked as running.
    ``_stream_client`` is reset to ``None`` when this function exits.
    """
    try:
        import dingtalk_stream as sdk

        stream = sdk.DingTalkStreamClient(sdk.Credential(client_id, client_secret))
        self._stream_client = stream
        topic = sdk.chatbot.ChatbotMessage.TOPIC
        stream.register_callback_handler(topic, _DingTalkMessageHandler(self))
        stream.start_forever()
    except Exception:
        if self._running:
            logger.exception("DingTalk Stream Push error")
    finally:
        self._stream_client = None
|
|
||||||
|
|
||||||
def _on_chatbot_message(self, message: Any) -> None:
    """Turn an SDK chatbot message into an inbound message and schedule it.

    Messages are dropped when the channel is stopped, the sender is not in the
    allow-list (if one is configured) or no text can be extracted. Otherwise
    ``_prepare_inbound`` is scheduled on the main loop with
    ``asyncio.run_coroutine_threadsafe``. Any error is logged, never raised.

    In card mode the raw SDK message is remembered under the card source key
    so the running reply can create the AI card. It is stored only when the
    main loop is running, so an entry is never left behind for a message that
    cannot be scheduled.
    """
    if not self._running:
        return
    try:
        sender_staff_id = message.sender_staff_id or ""
        conversation_type = _normalize_conversation_type(message.conversation_type)
        conversation_id = message.conversation_id or ""
        msg_id = message.message_id or ""
        sender_nick = message.sender_nick or ""

        if self._allowed_users and sender_staff_id not in self._allowed_users:
            logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
            return

        text = self._extract_text(message)
        if not text:
            logger.info("[DingTalk] empty text, ignoring message")
            return

        logger.info(
            "[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text=%r",
            conversation_type,
            msg_id,
            sender_staff_id,
            sender_nick,
            text[:100],
        )

        msg_type = InboundMessageType.COMMAND if _is_dingtalk_command(text) else InboundMessageType.CHAT
        is_group = conversation_type == _CONVERSATION_TYPE_GROUP

        # P2P: topic_id=None (single thread per user, like Telegram private chat)
        # Group: topic_id=msg_id (each new message starts a new topic, like Feishu)
        topic_id: str | None = msg_id if is_group else None

        # chat_id uses conversation_id for groups, sender_staff_id for P2P
        chat_id = conversation_id if is_group else sender_staff_id

        inbound = self._make_inbound(
            chat_id=chat_id,
            user_id=sender_staff_id,
            text=text,
            msg_type=msg_type,
            thread_ts=msg_id,
            metadata={
                "conversation_type": conversation_type,
                "conversation_id": conversation_id,
                "sender_staff_id": sender_staff_id,
                "sender_nick": sender_nick,
                "message_id": msg_id,
            },
        )
        inbound.topic_id = topic_id

        main_loop = self._main_loop
        if not (main_loop and main_loop.is_running()):
            logger.warning("[DingTalk] main loop not running, cannot publish inbound message")
            return

        # Must be stored before scheduling: _send_running_reply pops it.
        if self._card_template_id:
            source_key = self._make_card_source_key(inbound)
            with self._incoming_messages_lock:
                self._incoming_messages[source_key] = message

        logger.info("[DingTalk] publishing inbound message to bus (type=%s, msg_id=%s)", msg_type.value, msg_id)
        fut = asyncio.run_coroutine_threadsafe(
            self._prepare_inbound(chat_id, inbound),
            main_loop,
        )
        fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "prepare_inbound", mid))
    except Exception:
        logger.exception("[DingTalk] error processing chatbot message")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_text(message: Any) -> str:
|
|
||||||
msg_type = message.message_type
|
|
||||||
if msg_type == "text" and message.text:
|
|
||||||
return message.text.content.strip()
|
|
||||||
if msg_type == "richText" and message.rich_text_content:
|
|
||||||
return _extract_text_from_rich_text(message.rich_text_content.rich_text_list).strip()
|
|
||||||
return ""
|
|
||||||
|
|
||||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
|
||||||
# Running reply must finish before publish_inbound so AI card tracks are
|
|
||||||
# registered before the manager emits streaming outbounds.
|
|
||||||
await self._send_running_reply(chat_id, inbound)
|
|
||||||
await self.bus.publish_inbound(inbound)
|
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
    """Acknowledge an inbound message with a "working on it" placeholder.

    With a card template configured, an AI card is created and its track id
    is stored under the message's card source key. If no card could be
    created, or no template is configured, a plain text message is sent to
    the group or the user instead. Failures are logged and not re-raised.
    """
    meta = inbound.metadata
    conversation_type = meta.get("conversation_type", _CONVERSATION_TYPE_P2P)
    sender_staff_id = meta.get("sender_staff_id", "")
    conversation_id = meta.get("conversation_id", "")
    text = "\u23f3 Working on it..."

    try:
        if self._card_template_id:
            source_key = self._make_card_source_key(inbound)
            # Take ownership of the SDK message stored when it arrived.
            with self._incoming_messages_lock:
                chatbot_message = self._incoming_messages.pop(source_key, None)
            out_track_id = await self._create_and_deliver_card(text, chatbot_message=chatbot_message)
            if out_track_id:
                self._card_track_ids[source_key] = out_track_id
                logger.info("[DingTalk] AI card running reply sent for chat=%s", chat_id)
                return

        robot_code = self._client_id
        if conversation_type == _CONVERSATION_TYPE_GROUP:
            delivery = self._send_text_message_to_group(robot_code, conversation_id, text)
        else:
            delivery = self._send_text_message_to_user(robot_code, sender_staff_id, text)
        await delivery
        logger.info("[DingTalk] 'Working on it...' reply sent for chat=%s", chat_id)
    except Exception:
        logger.exception("[DingTalk] failed to send running reply for chat=%s", chat_id)
|
|
||||||
|
|
||||||
# -- DingTalk API helpers ----------------------------------------------
|
|
||||||
|
|
||||||
async def _get_access_token(self) -> str:
    """Return the cached access token, fetching a new one when it is stale.

    The refresh runs under ``_token_lock``; the cache is checked again after
    the lock is acquired so a token refreshed in the meantime is reused.

    Raises:
        ValueError: If the response is not a JSON object or carries no usable
            ``accessToken``.
        httpx.HTTPError: If the token request fails.
    """

    def _cache_is_valid() -> bool:
        return bool(self._cached_token) and time.monotonic() < self._token_expires_at

    if _cache_is_valid():
        return self._cached_token

    async with self._token_lock:
        if _cache_is_valid():
            return self._cached_token

        credentials = {"appKey": self._client_id, "appSecret": self._client_secret}  # DingTalk API field names
        async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as http:
            response = await http.post(f"{DINGTALK_API_BASE}/v1.0/oauth2/accessToken", json=credentials)
            response.raise_for_status()
            data = response.json()

        if not isinstance(data, dict):
            raise ValueError(f"DingTalk access token response must be a JSON object, got {type(data).__name__}")

        access_token = data.get("accessToken")
        if not (isinstance(access_token, str) and access_token.strip()):
            raise ValueError("DingTalk access token response did not contain a usable accessToken")

        raw_expires_in = data.get("expireIn", 7200)
        try:
            expires_in = int(raw_expires_in)
        except (TypeError, ValueError):
            logger.warning("[DingTalk] invalid expireIn value %r, using default 7200s", raw_expires_in)
            expires_in = 7200

        self._cached_token = access_token.strip()
        # Expire early by the refresh margin so the token is renewed before it lapses.
        self._token_expires_at = time.monotonic() + expires_in - _TOKEN_REFRESH_MARGIN_SECONDS
        return self._cached_token
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _api_headers(token: str) -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
"x-acs-dingtalk-access-token": token,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _send_text_message_to_user(self, robot_code: str, user_id: str, text: str) -> None:
    """POST a plain ``sampleText`` message to one user via the batchSend endpoint.

    Raises whatever ``raise_for_status`` raises on a non-success response.
    """
    access_token = await self._get_access_token()
    endpoint = f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend"
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as http:
        body = {
            "msgKey": "sampleText",
            "msgParam": json.dumps({"content": text}),
            "robotCode": robot_code,
            "userIds": [user_id],
        }
        reply = await http.post(endpoint, headers=self._api_headers(access_token), json=body)
        reply.raise_for_status()
|
|
||||||
|
|
||||||
async def _send_text_message_to_group(self, robot_code: str, conversation_id: str, text: str) -> None:
    """POST a plain ``sampleText`` message to a group conversation.

    Raises whatever ``raise_for_status`` raises on a non-success response.
    """
    access_token = await self._get_access_token()
    endpoint = f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send"
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as http:
        body = {
            "msgKey": "sampleText",
            "msgParam": json.dumps({"content": text}),
            "robotCode": robot_code,
            "openConversationId": conversation_id,
        }
        reply = await http.post(endpoint, headers=self._api_headers(access_token), json=body)
        reply.raise_for_status()
|
|
||||||
|
|
||||||
async def _send_p2p_message(self, robot_code: str, user_id: str, text: str) -> None:
    """Send ``text`` to one user as a ``sampleMarkdown`` message.

    The text is first passed through ``_adapt_markdown_for_dingtalk``. Request
    failures and non-success statuses propagate to the caller. Once the
    request has succeeded, the response body is inspected for logging only:
    a body that is not valid JSON or not a JSON object is logged as a warning
    rather than raised, so the caller's retry logic cannot resend a message
    that was already accepted.
    """
    text = _adapt_markdown_for_dingtalk(text)
    token = await self._get_access_token()
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        response = await client.post(
            f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
            headers=self._api_headers(token),
            json={
                "msgKey": "sampleMarkdown",
                "msgParam": json.dumps({"title": "DeerFlow", "text": text}),
                "robotCode": robot_code,
                "userIds": [user_id],
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError:
        # JSON and Unicode decode errors are both ValueError subclasses.
        data = response.text
    if isinstance(data, dict) and data.get("processQueryKey"):
        logger.info("[DingTalk] P2P message sent to user=%s", user_id)
    else:
        logger.warning("[DingTalk] P2P send response: %s", data)
|
|
||||||
|
|
||||||
async def _send_group_message(
    self,
    robot_code: str,
    conversation_id: str,
    text: str,
    *,
    at_user_ids: list[str] | None = None,  # noqa: ARG002
) -> None:
    """Send ``text`` to a group conversation as a ``sampleMarkdown`` message.

    ``at_user_ids`` is accepted for call-site compatibility but not passed to
    the API (sampleMarkdown does not support @mentions).

    Request failures and non-success statuses propagate to the caller. Once
    the request has succeeded, the response body is inspected for logging
    only: a body that is not valid JSON or not a JSON object is logged as a
    warning rather than raised, so the caller's retry logic cannot resend a
    message that was already accepted.
    """
    text = _adapt_markdown_for_dingtalk(text)
    token = await self._get_access_token()

    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        response = await client.post(
            f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
            headers=self._api_headers(token),
            json={
                "msgKey": "sampleMarkdown",
                "msgParam": json.dumps({"title": "DeerFlow", "text": text}),
                "robotCode": robot_code,
                "openConversationId": conversation_id,
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError:
        # JSON and Unicode decode errors are both ValueError subclasses.
        data = response.text
    if isinstance(data, dict) and data.get("processQueryKey"):
        logger.info("[DingTalk] group message sent to conversation=%s", conversation_id)
    else:
        logger.warning("[DingTalk] group send response: %s", data)
|
|
||||||
|
|
||||||
# -- AI Card streaming helpers -------------------------------------------
|
|
||||||
|
|
||||||
def _make_card_source_key(self, inbound: InboundMessage) -> str:
|
|
||||||
m = inbound.metadata
|
|
||||||
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{m.get('message_id', '')}"
|
|
||||||
|
|
||||||
def _make_card_source_key_from_outbound(self, msg: OutboundMessage) -> str:
|
|
||||||
m = msg.metadata
|
|
||||||
correlation_id = m.get("message_id") or msg.thread_ts or ""
|
|
||||||
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{correlation_id}"
|
|
||||||
|
|
||||||
async def _create_and_deliver_card(
    self,
    initial_text: str,
    *,
    chatbot_message: Any = None,
) -> str | None:
    """Create an AI card for ``chatbot_message`` and return its instance id.

    The replier used to create the card is kept in ``_card_repliers`` under
    the returned id so later streaming updates can find it.

    Returns:
        The card instance id, or ``None`` when the SDK client or message is
        missing, the card replier cannot be imported, no id was returned, or
        card creation raised (the error is logged).
    """
    if chatbot_message is None or self._dingtalk_client is None:
        logger.warning("[DingTalk] SDK client or chatbot_message unavailable, skipping AI card")
        return None

    try:
        from dingtalk_stream.card_replier import AICardReplier
    except ImportError:
        logger.warning("[DingTalk] dingtalk-stream card_replier not available")
        return None

    try:
        replier = AICardReplier(self._dingtalk_client, chatbot_message)
        card_instance_id = await replier.async_create_and_deliver_card(
            card_template_id=self._card_template_id,
            card_data={"content": initial_text},
        )
        if card_instance_id:
            self._card_repliers[card_instance_id] = replier
            logger.info("[DingTalk] AI card created: outTrackId=%s", card_instance_id)
            return card_instance_id
        return None
    except Exception:
        logger.exception("[DingTalk] failed to create AI card")
        return None
|
|
||||||
|
|
||||||
async def _stream_update_card(
|
|
||||||
self,
|
|
||||||
out_track_id: str,
|
|
||||||
content: str,
|
|
||||||
*,
|
|
||||||
is_finalize: bool = False,
|
|
||||||
is_error: bool = False,
|
|
||||||
) -> None:
|
|
||||||
replier = self._card_repliers.get(out_track_id)
|
|
||||||
if not replier:
|
|
||||||
raise RuntimeError(f"No AICardReplier found for track ID {out_track_id}")
|
|
||||||
|
|
||||||
await replier.async_streaming(
|
|
||||||
card_instance_id=out_track_id,
|
|
||||||
content_key="content",
|
|
||||||
content_value=content,
|
|
||||||
append=False,
|
|
||||||
finished=is_finalize,
|
|
||||||
failed=is_error,
|
|
||||||
)
|
|
||||||
|
|
||||||
# -- media upload --------------------------------------------------------
|
|
||||||
|
|
||||||
async def _upload_media(self, file_path: str | Path, media_type: str) -> str | None:
    """Upload a local file to DingTalk and return the resulting ``mediaId``.

    The file is read in a worker thread and posted as multipart form data
    together with ``media_type``.

    Returns:
        The ``mediaId`` value of the response, or ``None`` when the file
        cannot be read, the request fails, or the response is not a JSON
        object (each case is logged).
    """
    try:
        file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
        token = await self._get_access_token()
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as http:
            reply = await http.post(
                f"{DINGTALK_API_BASE}/v1.0/files/upload",
                headers={"x-acs-dingtalk-access-token": token},
                files={"file": ("upload", file_bytes)},
                data={"type": media_type},
            )
            reply.raise_for_status()
            try:
                payload = reply.json()
            except json.JSONDecodeError:
                logger.exception("[DingTalk] failed to decode upload response JSON: %s", file_path)
                return None
            if isinstance(payload, dict):
                return payload.get("mediaId")
            logger.warning("[DingTalk] unexpected upload response type %s for %s", type(payload).__name__, file_path)
            return None
    except (httpx.HTTPError, OSError):
        logger.exception("[DingTalk] failed to upload media: %s", file_path)
        return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _log_future_error(fut: Any, name: str, msg_id: str) -> None:
|
|
||||||
try:
|
|
||||||
exc = fut.exception()
|
|
||||||
if exc:
|
|
||||||
logger.error("[DingTalk] %s failed for msg_id=%s: %s", name, msg_id, exc)
|
|
||||||
except (asyncio.CancelledError, asyncio.InvalidStateError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _DingTalkMessageHandler:
|
|
||||||
"""Callback handler registered with dingtalk-stream."""
|
|
||||||
|
|
||||||
def __init__(self, channel: DingTalkChannel) -> None:
|
|
||||||
self._channel = channel
|
|
||||||
|
|
||||||
def pre_start(self) -> None:
|
|
||||||
if hasattr(self, "dingtalk_client") and self.dingtalk_client is not None:
|
|
||||||
self._channel._dingtalk_client = self.dingtalk_client
|
|
||||||
|
|
||||||
async def raw_process(self, callback_message: Any) -> Any:
|
|
||||||
import dingtalk_stream
|
|
||||||
from dingtalk_stream.frames import Headers
|
|
||||||
|
|
||||||
code, message = await self.process(callback_message)
|
|
||||||
ack_message = dingtalk_stream.AckMessage()
|
|
||||||
ack_message.code = code
|
|
||||||
ack_message.headers.message_id = callback_message.headers.message_id
|
|
||||||
ack_message.headers.content_type = Headers.CONTENT_TYPE_APPLICATION_JSON
|
|
||||||
ack_message.data = {"response": message}
|
|
||||||
return ack_message
|
|
||||||
|
|
||||||
async def process(self, callback: Any) -> tuple[int, str]:
|
|
||||||
import dingtalk_stream
|
|
||||||
|
|
||||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
|
||||||
self._channel._on_chatbot_message(incoming_message)
|
|
||||||
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
|
|
||||||
@@ -1,553 +0,0 @@
|
|||||||
"""Discord channel integration using discord.py."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.channels.base import Channel
|
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_DISCORD_MAX_MESSAGE_LEN = 2000
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannel(Channel):
|
|
||||||
"""Discord bot channel.
|
|
||||||
|
|
||||||
Configuration keys (in ``config.yaml`` under ``channels.discord``):
|
|
||||||
- ``bot_token``: Discord Bot token.
|
|
||||||
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
|
|
||||||
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
|
|
||||||
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
|
|
||||||
(even when mention_only is true). Use for channels where you want the bot to respond
|
|
||||||
without mentions. Empty = mention_only applies everywhere.
|
|
||||||
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
|
|
||||||
Default: same as ``mention_only``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
    """Read the Discord settings from ``config`` and initialise runtime state.

    Keys that are present but empty (``None``) are treated like missing keys:
    an empty ``bot_token`` becomes ``""`` rather than the string ``"None"``,
    and empty ``allowed_guilds`` / ``allowed_channels`` become empty sets.
    Guild ids that cannot be converted to ``int`` are skipped.
    """
    super().__init__(name="discord", bus=bus, config=config)
    self._bot_token = str(config.get("bot_token") or "").strip()

    self._allowed_guilds: set[int] = set()
    for guild_id in config.get("allowed_guilds") or []:
        try:
            self._allowed_guilds.add(int(guild_id))
        except (TypeError, ValueError):
            continue

    self._mention_only: bool = bool(config.get("mention_only", False))
    # Coerced to bool for consistency with ``mention_only``.
    self._thread_mode: bool = bool(config.get("thread_mode", self._mention_only))
    self._allowed_channels: set[str] = {str(channel_id) for channel_id in config.get("allowed_channels") or []}

    # Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
    # Uses a dedicated JSON file separate from ChannelStore, which maps IM
    # conversations to DeerFlow thread IDs — a different concern.
    self._active_threads: dict[str, str] = {}
    # Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
    self._active_thread_ids: set[str] = set()
    # Lock protecting _active_threads and the JSON file from concurrent access.
    # _run_client (Discord loop thread) and the main thread both read/write.
    self._thread_store_lock = threading.Lock()
    store = config.get("channel_store")
    if store is not None:
        self._thread_store_path = store._path.parent / "discord_threads.json"
    else:
        self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"

    # Typing indicator management
    self._typing_tasks: dict[str, asyncio.Task] = {}

    self._client = None
    self._thread: threading.Thread | None = None
    self._discord_loop: asyncio.AbstractEventLoop | None = None
    self._main_loop: asyncio.AbstractEventLoop | None = None
    self._discord_module = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
    """Create the Discord client and run it in a background thread.

    Does nothing when the channel is already running. Logs an error and
    returns without starting when discord.py is not installed or no bot token
    is configured.

    Persisted thread mappings are restored before the client thread is
    started: loading clears the in-memory maps, so doing it afterwards could
    race with message handling on the client thread.
    """
    if self._running:
        return

    try:
        import discord
    except ImportError:
        logger.error("discord.py is not installed. Install it with: uv add discord.py")
        return

    if not self._bot_token:
        logger.error("Discord channel requires bot_token")
        return

    intents = discord.Intents.default()
    intents.messages = True
    intents.guilds = True
    intents.message_content = True

    client = discord.Client(
        intents=intents,
        allowed_mentions=discord.AllowedMentions.none(),
    )
    self._client = client
    self._discord_module = discord
    # This coroutine always executes inside a running loop.
    self._main_loop = asyncio.get_running_loop()

    @client.event
    async def on_message(message) -> None:
        await self._on_message(message)

    self._load_active_threads()

    self._running = True
    self.bus.subscribe_outbound(self._on_outbound)

    self._thread = threading.Thread(target=self._run_client, daemon=True)
    self._thread.start()
    logger.info("Discord channel started")
|
|
||||||
|
|
||||||
def _load_active_threads(self) -> None:
|
|
||||||
"""Restore Discord thread mappings from the dedicated JSON file on startup."""
|
|
||||||
with self._thread_store_lock:
|
|
||||||
try:
|
|
||||||
if not self._thread_store_path.exists():
|
|
||||||
logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
|
|
||||||
return
|
|
||||||
data = json.loads(self._thread_store_path.read_text())
|
|
||||||
self._active_threads.clear()
|
|
||||||
self._active_thread_ids.clear()
|
|
||||||
for channel_id, thread_id in data.items():
|
|
||||||
self._active_threads[channel_id] = thread_id
|
|
||||||
self._active_thread_ids.add(thread_id)
|
|
||||||
if self._active_threads:
|
|
||||||
logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to load thread mappings")
|
|
||||||
|
|
||||||
def _save_thread(self, channel_id: str, thread_id: str) -> None:
|
|
||||||
"""Persist a Discord thread mapping to the dedicated JSON file."""
|
|
||||||
with self._thread_store_lock:
|
|
||||||
try:
|
|
||||||
data: dict[str, str] = {}
|
|
||||||
if self._thread_store_path.exists():
|
|
||||||
data = json.loads(self._thread_store_path.read_text())
|
|
||||||
old_id = data.get(channel_id)
|
|
||||||
data[channel_id] = thread_id
|
|
||||||
# Update reverse-lookup set
|
|
||||||
if old_id:
|
|
||||||
self._active_thread_ids.discard(old_id)
|
|
||||||
self._active_thread_ids.add(thread_id)
|
|
||||||
self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._thread_store_path.write_text(json.dumps(data, indent=2))
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
|
|
||||||
|
|
||||||
async def stop(self) -> None:
    """Stop the channel: cancel typing tasks, close the client, join the thread.

    Typing tasks are awaited on the Discord loop elsewhere in this class
    (``_stop_typing`` is scheduled there), so while that loop is running their
    cancellation is handed over with ``call_soon_threadsafe`` instead of being
    called directly from this thread. The client thread is joined through
    ``asyncio.to_thread`` so the (up to ten second) wait does not stall the
    event loop running this coroutine.
    """
    self._running = False
    self.bus.unsubscribe_outbound(self._on_outbound)

    discord_loop = self._discord_loop
    loop_running = discord_loop is not None and discord_loop.is_running()

    # Cancel all active typing indicator tasks
    for target_id, task in list(self._typing_tasks.items()):
        if not task.done():
            if loop_running:
                try:
                    discord_loop.call_soon_threadsafe(task.cancel)
                except RuntimeError:
                    # Loop closed between the check and the call.
                    task.cancel()
            else:
                task.cancel()
            logger.debug("[Discord] cancelled typing task for target %s", target_id)
    self._typing_tasks.clear()

    if self._client and loop_running:
        close_future = asyncio.run_coroutine_threadsafe(self._client.close(), discord_loop)
        try:
            await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
        except TimeoutError:
            logger.warning("[Discord] client close timed out after 10s")
        except Exception:
            logger.exception("[Discord] error while closing client")

    worker = self._thread
    if worker:
        # Thread.join blocks; run it off-loop so other coroutines keep running.
        await asyncio.to_thread(worker.join, 10)
        self._thread = None

    self._client = None
    self._discord_loop = None
    self._discord_module = None
    logger.info("Discord channel stopped")
|
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
    """Stop the typing indicator, then post ``msg.text`` to the resolved target.

    The text is cut into pieces by ``_split_text`` and each piece is sent on
    the Discord loop, one after the other. When no target can be resolved an
    error is logged and nothing is sent.
    """
    # Stop typing indicator once we're sending the response
    typing_stopped = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
    await asyncio.wrap_future(typing_stopped)

    target = await self._resolve_target(msg)
    if target is None:
        logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
        return

    for piece in self._split_text(msg.text or ""):
        delivered = asyncio.run_coroutine_threadsafe(target.send(piece), self._discord_loop)
        await asyncio.wrap_future(delivered)
|
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
    """Upload one resolved attachment to the message's Discord thread or channel.

    Args:
        msg: Outbound message; ``thread_ts`` and ``chat_id`` select the target.
        attachment: File to upload; ``actual_path`` is read and ``filename``
            is used as the upload name.

    Returns:
        True when the upload succeeded, False when the target or the Discord
        module is unavailable or the upload failed.
    """
    discord_loop = self._discord_loop
    if discord_loop is None:
        # Not started (or already stopped): there is no loop to schedule on.
        logger.error("[Discord] cannot upload file, client loop unavailable for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
        return False

    stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), discord_loop)
    await asyncio.wrap_future(stop_future)

    target = await self._resolve_target(msg)
    if target is None:
        logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
        return False

    if self._discord_module is None:
        return False

    try:
        # The handle is opened here, so it is closed here: keep it open until
        # the upload scheduled on the Discord loop has finished.
        with open(str(attachment.actual_path), "rb") as fp:
            upload = self._discord_module.File(fp, filename=attachment.filename)
            send_future = asyncio.run_coroutine_threadsafe(target.send(file=upload), discord_loop)
            await asyncio.wrap_future(send_future)
        logger.info("[Discord] file uploaded: %s", attachment.filename)
        return True
    except Exception:
        logger.exception("[Discord] failed to upload file: %s", attachment.filename)
        return False
|
|
||||||
|
|
||||||
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
|
|
||||||
"""Starts a loop to send periodic typing indicators."""
|
|
||||||
target_id = thread_ts or chat_id
|
|
||||||
if target_id in self._typing_tasks:
|
|
||||||
return # Already typing for this target
|
|
||||||
|
|
||||||
async def _typing_loop():
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await channel.trigger_typing()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
task = asyncio.create_task(_typing_loop())
|
|
||||||
self._typing_tasks[target_id] = task
|
|
||||||
|
|
||||||
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
|
|
||||||
"""Stops the typing loop for a specific target."""
|
|
||||||
target_id = thread_ts or chat_id
|
|
||||||
task = self._typing_tasks.pop(target_id, None)
|
|
||||||
if task and not task.done():
|
|
||||||
task.cancel()
|
|
||||||
logger.debug("[Discord] stopped typing indicator for target %s", target_id)
|
|
||||||
|
|
||||||
async def _add_reaction(self, message) -> None:
|
|
||||||
"""Add a checkmark reaction to acknowledge the message was received."""
|
|
||||||
try:
|
|
||||||
await message.add_reaction("✅")
|
|
||||||
except Exception:
|
|
||||||
logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)
|
|
||||||
|
|
||||||
async def _on_message(self, message) -> None:
    """Filter an incoming Discord message, pick its conversation thread and publish it.

    A message is ignored when the channel is not running, when its author is
    a bot (including this client), when ``allowed_guilds`` is set and the
    guild is not listed, or when its stripped content is empty.

    Routing rules, first match wins:
      * Inside a tracked thread: continue that conversation.
      * Inside an untracked thread: treated like a plain channel message.
      * ``mention_only`` is on, the bot is not mentioned and the channel is
        not in ``allowed_channels``: skipped.
      * ``mention_only`` is on and the bot is mentioned: a new thread is
        created (replacing the channel's tracked thread, if any).
      * The channel has a tracked thread: route to that thread.
      * ``thread_mode`` is on: a thread is created to group the conversation.
      * Otherwise: reply directly in the channel.

    Whenever thread creation fails, replies fall back to the channel itself.
    """
    if not self._running or not self._client:
        return
    if message.author.bot:
        return
    if self._client.user and message.author.id == self._client.user.id:
        return

    guild = message.guild
    if self._allowed_guilds and (guild is None or guild.id not in self._allowed_guilds):
        return

    text = (message.content or "").strip()
    if not text:
        return
    if self._discord_module is None:
        return

    # The text may be empty after stripping the mention; a bare mention is
    # still processed (e.g. to create a thread).
    has_mention, text = self._strip_bot_mentions(message.content, text)

    channel = message.channel
    if isinstance(channel, self._discord_module.Thread):
        thread_id = str(channel.id)
        if thread_id in self._active_thread_ids:
            chat_id = str(channel.parent_id or channel.id)
            await self._emit_inbound(message, text, chat_id, thread_id, channel)
            return
        # Thread not tracked (orphaned) — handled like a channel message below
        logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)

    channel_id = str(channel.id)
    has_active_thread = channel_id in self._active_threads

    if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
        if has_active_thread:
            logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
        else:
            logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
        return

    new_thread_on_mention = self._mention_only and has_mention

    if has_active_thread and not new_thread_on_mention:
        # Existing session → route to the existing thread
        thread_id = self._active_threads[channel_id]
        logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, thread_id)
        typing_target = await self._get_channel_or_thread(thread_id)
    elif new_thread_on_mention or self._thread_mode:
        thread_obj = await self._create_thread(message)
        if thread_obj is None:
            # Thread creation failed (disabled/permissions), reply in channel
            logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
            thread_id = channel_id
            typing_target = channel
        else:
            thread_id = str(thread_obj.id)
            self._active_threads[channel_id] = thread_id
            self._save_thread(channel_id, thread_id)
            typing_target = thread_obj
            if new_thread_on_mention and has_active_thread:
                logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", thread_id, channel_id)
            elif new_thread_on_mention:
                logger.info("[Discord] created thread %s in channel %s for user %s", thread_id, channel_id, message.author.display_name)
    else:
        # No threading — reply directly in channel
        thread_id = channel_id
        typing_target = channel

    await self._emit_inbound(message, text, channel_id, thread_id, typing_target)


def _strip_bot_mentions(self, content: str, text: str) -> tuple[bool, str]:
    """Return ``(mentioned, text)`` with every form of the bot mention removed from *text*.

    *content* is the raw message content that is searched for a mention;
    *text* is returned unchanged when the bot is not mentioned.
    """
    user = self._client.user if self._client else None
    if not user:
        return False, text
    # user.mention, the "<@!ID>" nickname form and the plain "<@ID>" form.
    variants = (user.mention, f"<@!{user.id}>", f"<@{user.id}>")
    if not any(variant and variant in content for variant in variants):
        return False, text
    for variant in variants:
        text = text.replace(variant or "", "")
    return True, text.strip()


async def _emit_inbound(self, message, text: str, chat_id: str, thread_id: str, typing_target) -> None:
    """Build the inbound message, start the typing indicator, publish and acknowledge."""
    guild = message.guild
    msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
    inbound = self._make_inbound(
        chat_id=chat_id,
        user_id=str(message.author.id),
        text=text,
        msg_type=msg_type,
        thread_ts=thread_id,
        metadata={
            "guild_id": str(guild.id) if guild else None,
            "channel_id": str(message.channel.id),
            "message_id": str(message.id),
        },
    )
    inbound.topic_id = thread_id

    # Register the typing loop before publishing, so the stop issued when the
    # reply is sent can never run ahead of the start.
    if typing_target:
        await self._start_typing(typing_target, chat_id, thread_id)

    self._publish(inbound)
    self._spawn_background(self._add_reaction(message))


def _spawn_background(self, coro) -> None:
    """Run *coro* as a fire-and-forget task while keeping a strong reference to it."""
    tasks = getattr(self, "_background_tasks", None)
    if tasks is None:
        tasks = self._background_tasks = set()
    task = asyncio.create_task(coro)
    # The event loop only keeps weak references to tasks; hold one here until
    # the task is done so it cannot be garbage collected mid-flight.
    tasks.add(task)
    task.add_done_callback(tasks.discard)
|
|
||||||
|
|
||||||
def _publish(self, inbound) -> None:
    """Publish an inbound message to the main event loop.

    The publish coroutine is scheduled on ``self._main_loop`` from the
    calling thread. The message is dropped (with a warning) when that loop
    is missing or not running; a failed or cancelled publish is logged.
    """
    main_loop = self._main_loop
    if not main_loop or not main_loop.is_running():
        logger.warning("[Discord] main event loop is not running, dropping inbound message")
        return

    def _log_outcome(done) -> None:
        # Future.exception() raises CancelledError on a cancelled future, so
        # check for cancellation first.
        if done.cancelled():
            logger.warning("[Discord] publish_inbound was cancelled")
            return
        exc = done.exception()
        if exc is not None:
            logger.error("[Discord] publish_inbound failed", exc_info=exc)

    future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), main_loop)
    future.add_done_callback(_log_outcome)
|
|
||||||
|
|
||||||
def _run_client(self) -> None:
|
|
||||||
self._discord_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._discord_loop)
|
|
||||||
try:
|
|
||||||
self._discord_loop.run_until_complete(self._client.start(self._bot_token))
|
|
||||||
except Exception:
|
|
||||||
if self._running:
|
|
||||||
logger.exception("Discord client error")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
if self._client and not self._client.is_closed():
|
|
||||||
self._discord_loop.run_until_complete(self._client.close())
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error during Discord shutdown")
|
|
||||||
|
|
||||||
async def _create_thread(self, message):
|
|
||||||
try:
|
|
||||||
if self._discord_module is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Only TextChannel (type 0) and NewsChannel (type 10) support threads
|
|
||||||
channel_type = message.channel.type
|
|
||||||
if channel_type not in (
|
|
||||||
self._discord_module.ChannelType.text,
|
|
||||||
self._discord_module.ChannelType.news,
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"[Discord] channel type %s (%s) does not support threads",
|
|
||||||
channel_type.value,
|
|
||||||
channel_type.name,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
|
|
||||||
return await message.create_thread(name=thread_name)
|
|
||||||
except self._discord_module.errors.HTTPException as exc:
|
|
||||||
if exc.code == 50024:
|
|
||||||
logger.info(
|
|
||||||
"[Discord] cannot create thread in channel %s (error code 50024): %s",
|
|
||||||
message.channel.id,
|
|
||||||
channel_type.name if (channel_type := message.channel.type) else "unknown",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.exception(
|
|
||||||
"[Discord] failed to create thread for message=%s (HTTPException %s)",
|
|
||||||
message.id,
|
|
||||||
exc.code,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _resolve_target(self, msg: OutboundMessage):
|
|
||||||
if not self._client or not self._discord_loop:
|
|
||||||
return None
|
|
||||||
|
|
||||||
target_ids: list[str] = []
|
|
||||||
if msg.thread_ts:
|
|
||||||
target_ids.append(msg.thread_ts)
|
|
||||||
if msg.chat_id and msg.chat_id not in target_ids:
|
|
||||||
target_ids.append(msg.chat_id)
|
|
||||||
|
|
||||||
for raw_id in target_ids:
|
|
||||||
target = await self._get_channel_or_thread(raw_id)
|
|
||||||
if target is not None:
|
|
||||||
return target
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _get_channel_or_thread(self, raw_id: str):
|
|
||||||
if not self._client or not self._discord_loop:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
target_id = int(raw_id)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), self._discord_loop)
|
|
||||||
try:
|
|
||||||
return await asyncio.wrap_future(get_future)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to resolve target id=%s", raw_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _fetch_channel(self, target_id: int):
|
|
||||||
if not self._client:
|
|
||||||
return None
|
|
||||||
|
|
||||||
channel = self._client.get_channel(target_id)
|
|
||||||
if channel is not None:
|
|
||||||
return channel
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await self._client.fetch_channel(target_id)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _split_text(text: str) -> list[str]:
|
|
||||||
if not text:
|
|
||||||
return [""]
|
|
||||||
|
|
||||||
chunks: list[str] = []
|
|
||||||
remaining = text
|
|
||||||
while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
|
|
||||||
split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
|
|
||||||
if split_at <= 0:
|
|
||||||
split_at = _DISCORD_MAX_MESSAGE_LEN
|
|
||||||
chunks.append(remaining[:split_at])
|
|
||||||
remaining = remaining[split_at:].lstrip("\n")
|
|
||||||
|
|
||||||
if remaining:
|
|
||||||
chunks.append(remaining)
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
@@ -9,11 +9,12 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_user_actor_context
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -63,10 +64,6 @@ class FeishuChannel(Channel):
|
|||||||
self._GetMessageResourceRequest = None
|
self._GetMessageResourceRequest = None
|
||||||
self._thread_lock = threading.Lock()
|
self._thread_lock = threading.Lock()
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
@@ -302,15 +299,35 @@ class FeishuChannel(Channel):
|
|||||||
text = msg.text
|
text = msg.text
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.get("image_key"):
|
if file.get("image_key"):
|
||||||
virtual_path = await self._receive_single_file(msg.thread_ts, file["image_key"], "image", thread_id)
|
virtual_path = await self._receive_single_file(
|
||||||
|
msg.thread_ts,
|
||||||
|
file["image_key"],
|
||||||
|
"image",
|
||||||
|
thread_id,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
text = text.replace("[image]", virtual_path, 1)
|
text = text.replace("[image]", virtual_path, 1)
|
||||||
elif file.get("file_key"):
|
elif file.get("file_key"):
|
||||||
virtual_path = await self._receive_single_file(msg.thread_ts, file["file_key"], "file", thread_id)
|
virtual_path = await self._receive_single_file(
|
||||||
|
msg.thread_ts,
|
||||||
|
file["file_key"],
|
||||||
|
"file",
|
||||||
|
thread_id,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
text = text.replace("[file]", virtual_path, 1)
|
text = text.replace("[file]", virtual_path, 1)
|
||||||
msg.text = text
|
msg.text = text
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def _receive_single_file(self, message_id: str, file_key: str, type: Literal["image", "file"], thread_id: str) -> str:
|
async def _receive_single_file(
|
||||||
|
self,
|
||||||
|
message_id: str,
|
||||||
|
file_key: str,
|
||||||
|
type: Literal["image", "file"],
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
|
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
|
||||||
|
|
||||||
def inner():
|
def inner():
|
||||||
@@ -349,50 +366,51 @@ class FeishuChannel(Channel):
|
|||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
user_id = get_effective_user_id()
|
with bind_user_actor_context(user_id):
|
||||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
effective_user_id = get_effective_user_id()
|
||||||
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
|
paths.ensure_thread_dirs(thread_id, user_id=effective_user_id)
|
||||||
|
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=effective_user_id).resolve()
|
||||||
|
|
||||||
ext = "png" if type == "image" else "bin"
|
ext = "png" if type == "image" else "bin"
|
||||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||||
|
|
||||||
# Sanitize filename: preserve extension, replace path chars in name part
|
# Sanitize filename: preserve extension, replace path chars in name part
|
||||||
if "." in raw_filename:
|
if "." in raw_filename:
|
||||||
name_part, ext = raw_filename.rsplit(".", 1)
|
name_part, ext = raw_filename.rsplit(".", 1)
|
||||||
name_part = re.sub(r"[./\\]", "_", name_part)
|
name_part = re.sub(r"[./\\]", "_", name_part)
|
||||||
filename = f"{name_part}.{ext}"
|
filename = f"{name_part}.{ext}"
|
||||||
else:
|
else:
|
||||||
filename = re.sub(r"[./\\]", "_", raw_filename)
|
filename = re.sub(r"[./\\]", "_", raw_filename)
|
||||||
resolved_target = uploads_dir / filename
|
resolved_target = uploads_dir / filename
|
||||||
|
|
||||||
def down_load():
|
def down_load():
|
||||||
# use thread_lock to avoid filename conflicts when writing
|
# use thread_lock to avoid filename conflicts when writing
|
||||||
with self._thread_lock:
|
with self._thread_lock:
|
||||||
resolved_target.write_bytes(content)
|
resolved_target.write_bytes(content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(down_load)
|
await asyncio.to_thread(down_load)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
|
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
|
||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
|
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
if sandbox_id != "local":
|
if sandbox_id != "local":
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
if sandbox is None:
|
if sandbox is None:
|
||||||
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
|
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
|
||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
sandbox.update_file(virtual_path, content)
|
sandbox.update_file(virtual_path, content)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
|
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
|
||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
|
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
|
||||||
return virtual_path
|
return virtual_path
|
||||||
|
|
||||||
# -- message formatting ------------------------------------------------
|
# -- message formatting ------------------------------------------------
|
||||||
|
|
||||||
|
|||||||
+70
-118
@@ -1,4 +1,4 @@
|
|||||||
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via Gateway."""
|
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via LangGraph Server."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -14,16 +14,15 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
from langgraph_sdk.errors import ConflictError
|
from langgraph_sdk.errors import ConflictError
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_user_actor_context
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
from app.gateway.internal_auth import create_internal_auth_headers
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_LANGGRAPH_URL = "http://localhost:8001/api"
|
DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
|
||||||
DEFAULT_GATEWAY_URL = "http://localhost:8001"
|
DEFAULT_GATEWAY_URL = "http://localhost:8001"
|
||||||
DEFAULT_ASSISTANT_ID = "lead_agent"
|
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
||||||
@@ -38,8 +37,6 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
|||||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||||
|
|
||||||
CHANNEL_CAPABILITIES = {
|
CHANNEL_CAPABILITIES = {
|
||||||
"dingtalk": {"supports_streaming": False},
|
|
||||||
"discord": {"supports_streaming": False},
|
|
||||||
"feishu": {"supports_streaming": True},
|
"feishu": {"supports_streaming": True},
|
||||||
"slack": {"supports_streaming": False},
|
"slack": {"supports_streaming": False},
|
||||||
"telegram": {"supports_streaming": False},
|
"telegram": {"supports_streaming": False},
|
||||||
@@ -49,13 +46,6 @@ CHANNEL_CAPABILITIES = {
|
|||||||
|
|
||||||
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
|
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
|
||||||
|
|
||||||
_METADATA_DROP_KEYS = frozenset({"raw_message", "ref_msg"})
|
|
||||||
|
|
||||||
|
|
||||||
def _slim_metadata(meta: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Return a shallow copy of *meta* with known-large keys removed."""
|
|
||||||
return {k: v for k, v in meta.items() if k not in _METADATA_DROP_KEYS}
|
|
||||||
|
|
||||||
|
|
||||||
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
|
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
|
||||||
|
|
||||||
@@ -155,6 +145,7 @@ def _extract_response_text(result: dict | list) -> str:
|
|||||||
Handles special cases:
|
Handles special cases:
|
||||||
- Regular AI text responses
|
- Regular AI text responses
|
||||||
- Clarification interrupts (``ask_clarification`` tool messages)
|
- Clarification interrupts (``ask_clarification`` tool messages)
|
||||||
|
- AI messages with tool_calls but no text content
|
||||||
"""
|
"""
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
messages = result
|
messages = result
|
||||||
@@ -338,7 +329,7 @@ def _format_artifact_text(artifacts: list[str]) -> str:
|
|||||||
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
||||||
|
|
||||||
|
|
||||||
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
def _resolve_attachments(thread_id: str, artifacts: list[str], *, user_id: str | None = None) -> list[ResolvedAttachment]:
|
||||||
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
||||||
|
|
||||||
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
|
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
|
||||||
@@ -352,39 +343,40 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
|
|||||||
|
|
||||||
attachments: list[ResolvedAttachment] = []
|
attachments: list[ResolvedAttachment] = []
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
user_id = get_effective_user_id()
|
with bind_user_actor_context(user_id):
|
||||||
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
|
effective_user_id = get_effective_user_id()
|
||||||
for virtual_path in artifacts:
|
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=effective_user_id).resolve()
|
||||||
# Security: only allow files from the agent outputs directory
|
for virtual_path in artifacts:
|
||||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
# Security: only allow files from the agent outputs directory
|
||||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||||
continue
|
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||||
try:
|
continue
|
||||||
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
|
|
||||||
# Verify the resolved path is actually under the outputs directory
|
|
||||||
# (guards against path-traversal even after prefix check)
|
|
||||||
try:
|
try:
|
||||||
actual.resolve().relative_to(outputs_dir)
|
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=effective_user_id)
|
||||||
except ValueError:
|
# Verify the resolved path is actually under the outputs directory
|
||||||
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
|
# (guards against path-traversal even after prefix check)
|
||||||
continue
|
try:
|
||||||
if not actual.is_file():
|
actual.resolve().relative_to(outputs_dir)
|
||||||
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
|
except ValueError:
|
||||||
continue
|
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
|
||||||
mime, _ = mimetypes.guess_type(str(actual))
|
continue
|
||||||
mime = mime or "application/octet-stream"
|
if not actual.is_file():
|
||||||
attachments.append(
|
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
|
||||||
ResolvedAttachment(
|
continue
|
||||||
virtual_path=virtual_path,
|
mime, _ = mimetypes.guess_type(str(actual))
|
||||||
actual_path=actual,
|
mime = mime or "application/octet-stream"
|
||||||
filename=actual.name,
|
attachments.append(
|
||||||
mime_type=mime,
|
ResolvedAttachment(
|
||||||
size=actual.stat().st_size,
|
virtual_path=virtual_path,
|
||||||
is_image=mime.startswith("image/"),
|
actual_path=actual,
|
||||||
|
filename=actual.name,
|
||||||
|
mime_type=mime,
|
||||||
|
size=actual.stat().st_size,
|
||||||
|
is_image=mime.startswith("image/"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except (ValueError, OSError) as exc:
|
||||||
except (ValueError, OSError) as exc:
|
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
|
||||||
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
|
|
||||||
return attachments
|
return attachments
|
||||||
|
|
||||||
|
|
||||||
@@ -392,13 +384,15 @@ def _prepare_artifact_delivery(
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
response_text: str,
|
response_text: str,
|
||||||
artifacts: list[str],
|
artifacts: list[str],
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> tuple[str, list[ResolvedAttachment]]:
|
) -> tuple[str, list[ResolvedAttachment]]:
|
||||||
"""Resolve attachments and append filename fallbacks to the text response."""
|
"""Resolve attachments and append filename fallbacks to the text response."""
|
||||||
attachments: list[ResolvedAttachment] = []
|
attachments: list[ResolvedAttachment] = []
|
||||||
if not artifacts:
|
if not artifacts:
|
||||||
return response_text, attachments
|
return response_text, attachments
|
||||||
|
|
||||||
attachments = _resolve_attachments(thread_id, artifacts)
|
attachments = _resolve_attachments(thread_id, artifacts, user_id=user_id)
|
||||||
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
|
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
|
||||||
unresolved = [path for path in artifacts if path not in resolved_virtuals]
|
unresolved = [path for path in artifacts if path not in resolved_virtuals]
|
||||||
|
|
||||||
@@ -419,15 +413,10 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
if not msg.files:
|
if not msg.files:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||||
UnsafeUploadPathError,
|
|
||||||
claim_unique_filename,
|
|
||||||
ensure_uploads_dir,
|
|
||||||
normalize_filename,
|
|
||||||
write_upload_file_no_symlink,
|
|
||||||
)
|
|
||||||
|
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
with bind_user_actor_context(msg.user_id):
|
||||||
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||||
|
|
||||||
created: list[dict[str, Any]] = []
|
created: list[dict[str, Any]] = []
|
||||||
@@ -476,10 +465,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
|
|
||||||
dest = uploads_dir / safe_name
|
dest = uploads_dir / safe_name
|
||||||
try:
|
try:
|
||||||
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
dest.write_bytes(data)
|
||||||
except UnsafeUploadPathError:
|
|
||||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
|
||||||
continue
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
||||||
continue
|
continue
|
||||||
@@ -527,7 +513,7 @@ class ChannelManager:
|
|||||||
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
|
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
|
||||||
|
|
||||||
It reads from the MessageBus inbound queue, creates/reuses threads on
|
It reads from the MessageBus inbound queue, creates/reuses threads on
|
||||||
Gateway's LangGraph-compatible API, sends messages via ``runs.wait``, and publishes
|
the LangGraph Server, sends messages via ``runs.wait``, and publishes
|
||||||
outbound responses back through the bus.
|
outbound responses back through the bus.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -552,20 +538,12 @@ class ChannelManager:
|
|||||||
self._default_session = _as_dict(default_session)
|
self._default_session = _as_dict(default_session)
|
||||||
self._channel_sessions = dict(channel_sessions or {})
|
self._channel_sessions = dict(channel_sessions or {})
|
||||||
self._client = None # lazy init — langgraph_sdk async client
|
self._client = None # lazy init — langgraph_sdk async client
|
||||||
self._csrf_token = generate_csrf_token()
|
|
||||||
self._semaphore: asyncio.Semaphore | None = None
|
self._semaphore: asyncio.Semaphore | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _channel_supports_streaming(channel_name: str) -> bool:
|
def _channel_supports_streaming(channel_name: str) -> bool:
|
||||||
from .service import get_channel_service
|
|
||||||
|
|
||||||
service = get_channel_service()
|
|
||||||
if service:
|
|
||||||
channel = service.get_channel(channel_name)
|
|
||||||
if channel is not None:
|
|
||||||
return channel.supports_streaming
|
|
||||||
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
|
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
|
||||||
|
|
||||||
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
|
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
@@ -588,17 +566,6 @@ class ChannelManager:
|
|||||||
user_layer.get("config"),
|
user_layer.get("config"),
|
||||||
)
|
)
|
||||||
|
|
||||||
configurable = run_config.get("configurable")
|
|
||||||
if isinstance(configurable, Mapping):
|
|
||||||
configurable = dict(configurable)
|
|
||||||
else:
|
|
||||||
configurable = {}
|
|
||||||
run_config["configurable"] = configurable
|
|
||||||
# Pin channel-triggered runs to the root graph namespace so follow-up
|
|
||||||
# turns continue from the same conversation checkpoint.
|
|
||||||
configurable["checkpoint_ns"] = ""
|
|
||||||
configurable["thread_id"] = thread_id
|
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
DEFAULT_RUN_CONTEXT,
|
DEFAULT_RUN_CONTEXT,
|
||||||
self._default_session.get("context"),
|
self._default_session.get("context"),
|
||||||
@@ -623,14 +590,7 @@ class ChannelManager:
|
|||||||
if self._client is None:
|
if self._client is None:
|
||||||
from langgraph_sdk import get_client
|
from langgraph_sdk import get_client
|
||||||
|
|
||||||
self._client = get_client(
|
self._client = get_client(url=self._langgraph_url)
|
||||||
url=self._langgraph_url,
|
|
||||||
headers={
|
|
||||||
**create_internal_auth_headers(),
|
|
||||||
CSRF_HEADER_NAME: self._csrf_token,
|
|
||||||
"Cookie": f"{CSRF_COOKIE_NAME}={self._csrf_token}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
# -- lifecycle ---------------------------------------------------------
|
# -- lifecycle ---------------------------------------------------------
|
||||||
@@ -713,7 +673,7 @@ class ChannelManager:
|
|||||||
# -- chat handling -----------------------------------------------------
|
# -- chat handling -----------------------------------------------------
|
||||||
|
|
||||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||||
"""Create a new thread through Gateway and store the mapping."""
|
"""Create a new thread on the LangGraph Server and store the mapping."""
|
||||||
thread = await client.threads.create()
|
thread = await client.threads.create()
|
||||||
thread_id = thread["thread_id"]
|
thread_id = thread["thread_id"]
|
||||||
self.store.set_thread_id(
|
self.store.set_thread_id(
|
||||||
@@ -723,7 +683,7 @@ class ChannelManager:
|
|||||||
topic_id=msg.topic_id,
|
topic_id=msg.topic_id,
|
||||||
user_id=msg.user_id,
|
user_id=msg.user_id,
|
||||||
)
|
)
|
||||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
logger.info("[Manager] new thread created on LangGraph Server: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||||
@@ -772,22 +732,13 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
try:
|
result = await client.runs.wait(
|
||||||
result = await client.runs.wait(
|
thread_id,
|
||||||
thread_id,
|
assistant_id,
|
||||||
assistant_id,
|
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
config=run_config,
|
||||||
config=run_config,
|
context=run_context,
|
||||||
context=run_context,
|
)
|
||||||
multitask_strategy="reject",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
if _is_thread_busy_error(exc):
|
|
||||||
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
|
|
||||||
await self._send_error(msg, THREAD_BUSY_MESSAGE)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
@@ -799,7 +750,12 @@ class ChannelManager:
|
|||||||
len(artifacts),
|
len(artifacts),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
response_text, attachments = _prepare_artifact_delivery(
|
||||||
|
thread_id,
|
||||||
|
response_text,
|
||||||
|
artifacts,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not response_text:
|
if not response_text:
|
||||||
if attachments:
|
if attachments:
|
||||||
@@ -815,7 +771,6 @@ class ChannelManager:
|
|||||||
artifacts=artifacts,
|
artifacts=artifacts,
|
||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
@@ -877,7 +832,6 @@ class ChannelManager:
|
|||||||
text=latest_text,
|
text=latest_text,
|
||||||
is_final=False,
|
is_final=False,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
last_published_text = latest_text
|
last_published_text = latest_text
|
||||||
@@ -892,7 +846,12 @@ class ChannelManager:
|
|||||||
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
response_text, attachments = _prepare_artifact_delivery(
|
||||||
|
thread_id,
|
||||||
|
response_text,
|
||||||
|
artifacts,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not response_text:
|
if not response_text:
|
||||||
if attachments:
|
if attachments:
|
||||||
@@ -922,7 +881,6 @@ class ChannelManager:
|
|||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
is_final=True,
|
is_final=True,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -942,7 +900,7 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if command == "new":
|
if command == "new":
|
||||||
# Create a new thread through Gateway
|
# Create a new thread on the LangGraph Server
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
thread = await client.threads.create()
|
thread = await client.threads.create()
|
||||||
new_thread_id = thread["thread_id"]
|
new_thread_id = thread["thread_id"]
|
||||||
@@ -981,7 +939,6 @@ class ChannelManager:
|
|||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||||
text=reply,
|
text=reply,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
|
|
||||||
@@ -991,11 +948,7 @@ class ChannelManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as http:
|
async with httpx.AsyncClient() as http:
|
||||||
resp = await http.get(
|
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
|
||||||
f"{self._gateway_url}{path}",
|
|
||||||
timeout=10,
|
|
||||||
headers=create_internal_auth_headers(),
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -1019,6 +972,5 @@ class ChannelManager:
|
|||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||||
text=error_text,
|
text=error_text,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||||
@@ -13,13 +13,8 @@ from app.channels.store import ChannelStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
# Channel name → import path for lazy loading
|
# Channel name → import path for lazy loading
|
||||||
_CHANNEL_REGISTRY: dict[str, str] = {
|
_CHANNEL_REGISTRY: dict[str, str] = {
|
||||||
"dingtalk": "app.channels.dingtalk:DingTalkChannel",
|
|
||||||
"discord": "app.channels.discord:DiscordChannel",
|
|
||||||
"feishu": "app.channels.feishu:FeishuChannel",
|
"feishu": "app.channels.feishu:FeishuChannel",
|
||||||
"slack": "app.channels.slack:SlackChannel",
|
"slack": "app.channels.slack:SlackChannel",
|
||||||
"telegram": "app.channels.telegram:TelegramChannel",
|
"telegram": "app.channels.telegram:TelegramChannel",
|
||||||
@@ -27,17 +22,6 @@ _CHANNEL_REGISTRY: dict[str, str] = {
|
|||||||
"wecom": "app.channels.wecom:WeComChannel",
|
"wecom": "app.channels.wecom:WeComChannel",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Keys that indicate a user has configured credentials for a channel.
|
|
||||||
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
|
|
||||||
"dingtalk": ["client_id", "client_secret"],
|
|
||||||
"discord": ["bot_token"],
|
|
||||||
"feishu": ["app_id", "app_secret"],
|
|
||||||
"slack": ["bot_token", "app_token"],
|
|
||||||
"telegram": ["bot_token"],
|
|
||||||
"wecom": ["bot_id", "bot_secret"],
|
|
||||||
"wechat": ["bot_token"],
|
|
||||||
}
|
|
||||||
|
|
||||||
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
||||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||||
|
|
||||||
@@ -80,15 +64,14 @@ class ChannelService:
|
|||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
def from_app_config(cls) -> ChannelService:
|
||||||
"""Create a ChannelService from the application config."""
|
"""Create a ChannelService from the application config."""
|
||||||
if app_config is None:
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
|
|
||||||
app_config = get_app_config()
|
config = get_app_config()
|
||||||
channels_config = {}
|
channels_config = {}
|
||||||
# extra fields are allowed by AppConfig (extra="allow")
|
# extra fields are allowed by AppConfig (extra="allow")
|
||||||
extra = app_config.model_extra or {}
|
extra = config.model_extra or {}
|
||||||
if "channels" in extra:
|
if "channels" in extra:
|
||||||
channels_config = extra["channels"]
|
channels_config = extra["channels"]
|
||||||
return cls(channels_config=channels_config)
|
return cls(channels_config=channels_config)
|
||||||
@@ -104,16 +87,7 @@ class ChannelService:
|
|||||||
if not isinstance(channel_config, dict):
|
if not isinstance(channel_config, dict):
|
||||||
continue
|
continue
|
||||||
if not channel_config.get("enabled", False):
|
if not channel_config.get("enabled", False):
|
||||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
logger.info("Channel %s is disabled, skipping", name)
|
||||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
|
||||||
if has_creds:
|
|
||||||
logger.warning(
|
|
||||||
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
|
|
||||||
name,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Channel %s is disabled, skipping", name)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self._start_channel(name, channel_config)
|
await self._start_channel(name, channel_config)
|
||||||
@@ -167,19 +141,12 @@ class ChannelService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = dict(config)
|
|
||||||
config["channel_store"] = self.store
|
|
||||||
channel = channel_cls(bus=self.bus, config=config)
|
channel = channel_cls(bus=self.bus, config=config)
|
||||||
self._channels[name] = channel
|
|
||||||
await channel.start()
|
await channel.start()
|
||||||
if not channel.is_running:
|
self._channels[name] = channel
|
||||||
self._channels.pop(name, None)
|
|
||||||
logger.error("Channel %s did not enter a running state after start()", name)
|
|
||||||
return False
|
|
||||||
logger.info("Channel %s started", name)
|
logger.info("Channel %s started", name)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
self._channels.pop(name, None)
|
|
||||||
logger.exception("Failed to start channel %s", name)
|
logger.exception("Failed to start channel %s", name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -214,12 +181,12 @@ def get_channel_service() -> ChannelService | None:
|
|||||||
return _channel_service
|
return _channel_service
|
||||||
|
|
||||||
|
|
||||||
async def start_channel_service(app_config: AppConfig | None = None) -> ChannelService:
|
async def start_channel_service() -> ChannelService:
|
||||||
"""Create and start the global ChannelService from app config."""
|
"""Create and start the global ChannelService from app config."""
|
||||||
global _channel_service
|
global _channel_service
|
||||||
if _channel_service is not None:
|
if _channel_service is not None:
|
||||||
return _channel_service
|
return _channel_service
|
||||||
_channel_service = ChannelService.from_app_config(app_config)
|
_channel_service = ChannelService.from_app_config()
|
||||||
await _channel_service.start()
|
await _channel_service.start()
|
||||||
return _channel_service
|
return _channel_service
|
||||||
|
|
||||||
|
|||||||
@@ -16,31 +16,13 @@ logger = logging.getLogger(__name__)
|
|||||||
_slack_md_converter = SlackMarkdownConverter()
|
_slack_md_converter = SlackMarkdownConverter()
|
||||||
|
|
||||||
|
|
||||||
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|
||||||
if allowed_users is None:
|
|
||||||
return set()
|
|
||||||
if isinstance(allowed_users, str):
|
|
||||||
values = [allowed_users]
|
|
||||||
elif isinstance(allowed_users, list | tuple | set):
|
|
||||||
values = allowed_users
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Slack allowed_users should be a list of Slack user IDs or a single Slack user ID string; treating %s as one string value",
|
|
||||||
type(allowed_users).__name__,
|
|
||||||
)
|
|
||||||
values = [allowed_users]
|
|
||||||
return {str(user_id) for user_id in values if str(user_id)}
|
|
||||||
|
|
||||||
|
|
||||||
class SlackChannel(Channel):
|
class SlackChannel(Channel):
|
||||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||||
|
|
||||||
Configuration keys (in ``config.yaml`` under ``channels.slack``):
|
Configuration keys (in ``config.yaml`` under ``channels.slack``):
|
||||||
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
|
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
|
||||||
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
|
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
|
||||||
- ``allowed_users``: (optional) List of allowed Slack user IDs, or a
|
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all.
|
||||||
single Slack user ID string as shorthand. Empty = allow all. Other
|
|
||||||
scalar values are treated as a single string with a warning.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
||||||
@@ -48,7 +30,7 @@ class SlackChannel(Channel):
|
|||||||
self._socket_client = None
|
self._socket_client = None
|
||||||
self._web_client = None
|
self._web_client = None
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
|
|||||||
@@ -29,10 +29,6 @@ class WeComChannel(Channel):
|
|||||||
self._ws_stream_ids: dict[str, str] = {}
|
self._ws_stream_ids: dict[str, str] = {}
|
||||||
self._working_message = "Working on it..."
|
self._working_message = "Working on it..."
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _clear_ws_context(self, thread_ts: str | None) -> None:
|
def _clear_ws_context(self, thread_ts: str | None) -> None:
|
||||||
if not thread_ts:
|
if not thread_ts:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,4 +1,23 @@
|
|||||||
from .app import app, create_app
|
from __future__ import annotations
|
||||||
from .config import GatewayConfig, get_gateway_config
|
|
||||||
|
|
||||||
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]
|
__all__ = ["GatewayConfig", "app", "get_gateway_config", "register_app"]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name == "app":
|
||||||
|
from .app import app
|
||||||
|
|
||||||
|
return app
|
||||||
|
if name == "GatewayConfig":
|
||||||
|
from .config import GatewayConfig
|
||||||
|
|
||||||
|
return GatewayConfig
|
||||||
|
if name == "get_gateway_config":
|
||||||
|
from .config import get_gateway_config
|
||||||
|
|
||||||
|
return get_gateway_config
|
||||||
|
if name == "register_app":
|
||||||
|
from .registrar import register_app
|
||||||
|
|
||||||
|
return register_app
|
||||||
|
raise AttributeError(name)
|
||||||
|
|||||||
+4
-387
@@ -1,391 +1,8 @@
|
|||||||
import asyncio
|
from app.gateway.registrar import register_app
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from app.gateway.auth_middleware import AuthMiddleware
|
def create_app():
|
||||||
from app.gateway.config import get_gateway_config
|
return register_app()
|
||||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
|
||||||
from app.gateway.deps import langgraph_runtime
|
|
||||||
from app.gateway.routers import (
|
|
||||||
agents,
|
|
||||||
artifacts,
|
|
||||||
assistants_compat,
|
|
||||||
auth,
|
|
||||||
channels,
|
|
||||||
feedback,
|
|
||||||
mcp,
|
|
||||||
memory,
|
|
||||||
models,
|
|
||||||
runs,
|
|
||||||
skills,
|
|
||||||
suggestions,
|
|
||||||
thread_runs,
|
|
||||||
threads,
|
|
||||||
uploads,
|
|
||||||
)
|
|
||||||
from deerflow.config import app_config as deerflow_app_config
|
|
||||||
from deerflow.config.app_config import apply_logging_level
|
|
||||||
|
|
||||||
AppConfig = deerflow_app_config.AppConfig
|
|
||||||
get_app_config = deerflow_app_config.get_app_config
|
|
||||||
|
|
||||||
# Default logging; lifespan overrides from config.yaml log_level.
|
app = register_app()
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
|
|
||||||
# Bounds worker exit time so uvicorn's reload supervisor does not keep
|
|
||||||
# firing signals into a worker that is stuck waiting for shutdown cleanup.
|
|
||||||
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0
|
|
||||||
|
|
||||||
|
|
||||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
|
||||||
"""Startup hook: handle first boot and migrate orphan threads otherwise.
|
|
||||||
|
|
||||||
After admin creation, migrate orphan threads from the LangGraph
|
|
||||||
store (metadata.user_id unset) to the admin account. This is the
|
|
||||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
|
||||||
authentication have existing LangGraph thread data that needs an
|
|
||||||
owner assigned.
|
|
||||||
First boot (no admin exists):
|
|
||||||
- Does NOT create any user accounts automatically.
|
|
||||||
- The operator must visit ``/setup`` to create the first admin.
|
|
||||||
|
|
||||||
Subsequent boots (admin already exists):
|
|
||||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
|
||||||
existing LangGraph thread metadata that has no user_id.
|
|
||||||
|
|
||||||
No SQL persistence migration is needed: the four user_id columns
|
|
||||||
(threads_meta, runs, run_events, feedback) only come into existence
|
|
||||||
alongside the auth module via create_all, so freshly created tables
|
|
||||||
never contain NULL-owner rows.
|
|
||||||
"""
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.gateway.deps import get_local_provider
|
|
||||||
from deerflow.persistence.engine import get_session_factory
|
|
||||||
from deerflow.persistence.user.model import UserRow
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider = get_local_provider()
|
|
||||||
except RuntimeError:
|
|
||||||
# Auth persistence may not be initialized in some test/boot paths.
|
|
||||||
# Skip admin migration work rather than failing gateway startup.
|
|
||||||
logger.warning("Auth persistence not ready; skipping admin bootstrap check")
|
|
||||||
return
|
|
||||||
|
|
||||||
sf = get_session_factory()
|
|
||||||
if sf is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
admin_count = await provider.count_admin_users()
|
|
||||||
|
|
||||||
if admin_count == 0:
|
|
||||||
logger.info("=" * 60)
|
|
||||||
logger.info(" First boot detected — no admin account exists.")
|
|
||||||
logger.info(" Visit /setup to complete admin account creation.")
|
|
||||||
logger.info("=" * 60)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Admin already exists — run orphan thread migration for any
|
|
||||||
# LangGraph thread metadata that pre-dates the auth module.
|
|
||||||
async with sf() as session:
|
|
||||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
|
||||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
return # Should not happen (admin_count > 0 above), but be safe.
|
|
||||||
|
|
||||||
admin_id = str(row.id)
|
|
||||||
|
|
||||||
# LangGraph store orphan migration — non-fatal.
|
|
||||||
# This covers the "no-auth → with-auth" upgrade path for users
|
|
||||||
# whose existing LangGraph thread metadata has no user_id set.
|
|
||||||
store = getattr(app.state, "store", None)
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
migrated = await _migrate_orphaned_threads(store, admin_id)
|
|
||||||
if migrated:
|
|
||||||
logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("LangGraph thread migration failed (non-fatal)")
|
|
||||||
|
|
||||||
|
|
||||||
async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
|
||||||
"""Paginated async iterator over a LangGraph store namespace.
|
|
||||||
|
|
||||||
Replaces the old hardcoded ``limit=1000`` call with a cursor-style
|
|
||||||
loop so that environments with more than one page of orphans do
|
|
||||||
not silently lose data. Terminates when a page is empty OR when a
|
|
||||||
short page arrives (indicating the last page).
|
|
||||||
"""
|
|
||||||
offset = 0
|
|
||||||
while True:
|
|
||||||
batch = await store.asearch(namespace, limit=page_size, offset=offset)
|
|
||||||
if not batch:
|
|
||||||
return
|
|
||||||
for item in batch:
|
|
||||||
yield item
|
|
||||||
if len(batch) < page_size:
|
|
||||||
return
|
|
||||||
offset += page_size
|
|
||||||
|
|
||||||
|
|
||||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
|
||||||
"""Migrate LangGraph store threads with no user_id to the given admin.
|
|
||||||
|
|
||||||
Uses cursor pagination so all orphans are migrated regardless of
|
|
||||||
count. Returns the number of rows migrated.
|
|
||||||
"""
|
|
||||||
migrated = 0
|
|
||||||
async for item in _iter_store_items(store, ("threads",)):
|
|
||||||
metadata = item.value.get("metadata", {})
|
|
||||||
if not metadata.get("user_id"):
|
|
||||||
metadata["user_id"] = admin_user_id
|
|
||||||
item.value["metadata"] = metadata
|
|
||||||
await store.aput(("threads",), item.key, item.value)
|
|
||||||
migrated += 1
|
|
||||||
return migrated
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan handler.

    Startup: loads the app config and applies its log level, enters the
    ``langgraph_runtime`` context, runs ``_ensure_admin_user`` and then
    tries to start the IM channel service. Shutdown (code after ``yield``):
    stops the channel service with a bounded wait and logs the shutdown.

    Args:
        app: The FastAPI application whose lifetime is being managed.

    Raises:
        RuntimeError: If loading the configuration or applying the logging
            level fails during startup (the original error is chained).
    """

    # Load config and check necessary environment variables at startup.
    # `startup_config` is a local snapshot used only for one-shot bootstrap
    # work (logging level, langgraph_runtime engines, channels). Request-time
    # config resolution always routes through `get_app_config()` in
    # `app/gateway/deps.py::get_config()` so `config.yaml` edits become
    # visible without a process restart. We deliberately do NOT cache this
    # snapshot on `app.state` to keep that contract enforceable.
    try:
        startup_config = get_app_config()
        apply_logging_level(startup_config.log_level)
        logger.info("Configuration loaded successfully")
    except Exception as e:
        # Configuration problems are fatal: log with traceback, then abort startup.
        error_msg = f"Failed to load configuration during gateway startup: {e}"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e
    config = get_gateway_config()
    logger.info(f"Starting API Gateway on {config.host}:{config.port}")

    # Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
    async with langgraph_runtime(app, startup_config):
        logger.info("LangGraph runtime initialised")

        # Check admin bootstrap state and migrate orphan threads after admin exists.
        # Must run AFTER langgraph_runtime so app.state.store is available for thread migration
        await _ensure_admin_user(app)

        # Start IM channel service if any channels are configured
        try:
            from app.channels.service import start_channel_service

            channel_service = await start_channel_service(startup_config)
            logger.info("Channel service started: %s", channel_service.get_status())
        except Exception:
            # Best-effort: a channel failure is logged and startup continues.
            logger.exception("No IM channels configured or channel service failed to start")

        # The application serves requests while suspended here.
        yield

        # Stop channel service on shutdown (bounded to prevent worker hang)
        try:
            from app.channels.service import stop_channel_service

            await asyncio.wait_for(
                stop_channel_service(),
                timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
            )
        except TimeoutError:
            # The stop hook overran its budget; give up waiting rather than block exit.
            logger.warning(
                "Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
                _SHUTDOWN_HOOK_TIMEOUT_SECONDS,
            )
        except Exception:
            logger.exception("Failed to stop channel service")

    logger.info("Shutting down API Gateway")
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
    """Build the DeerFlow gateway application.

    Reads the gateway config, constructs the ``FastAPI`` instance (the
    Swagger, ReDoc and OpenAPI schema routes are only exposed when
    ``enable_docs`` is set), installs the auth, CSRF and optional CORS
    middleware, mounts every API router and registers ``/health``.

    Returns:
        Configured FastAPI application instance.
    """
    gateway_config = get_gateway_config()
    docs_enabled = gateway_config.enable_docs
    docs_url = "/docs" if docs_enabled else None
    redoc_url = "/redoc" if docs_enabled else None
    openapi_url = "/openapi.json" if docs_enabled else None

    # (tag name, description) pairs, in the order they appear in the OpenAPI document.
    tag_descriptions = (
        ("models", "Operations for querying available AI models and their configurations"),
        ("mcp", "Manage Model Context Protocol (MCP) server configurations"),
        ("memory", "Access and manage global memory data for personalized conversations"),
        ("skills", "Manage skills and their configurations"),
        ("artifacts", "Access and download thread artifacts and generated files"),
        ("uploads", "Upload and manage user files for threads"),
        ("threads", "Manage DeerFlow thread-local filesystem data"),
        ("agents", "Create and manage custom agents with per-agent config and prompts"),
        ("suggestions", "Generate follow-up question suggestions for conversations"),
        ("channels", "Manage IM channel integrations (Feishu, Slack, Telegram)"),
        ("assistants-compat", "LangGraph Platform-compatible assistants API (stub)"),
        ("runs", "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)"),
        ("health", "Health check and system status endpoints"),
    )

    app = FastAPI(
        title="DeerFlow API Gateway",
        description="""
## DeerFlow API Gateway

API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.

### Features

- **Models Management**: Query and retrieve available AI models
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
- **Memory Management**: Access and manage global memory data for personalized conversations
- **Skills Management**: Query and manage skills and their enabled status
- **Artifacts**: Access thread artifacts and generated files
- **Health Monitoring**: System health check endpoints

### Architecture

LangGraph-compatible requests are routed through nginx to this gateway.
This gateway provides runtime endpoints for agent runs plus custom endpoints for models, MCP configuration, skills, and artifacts.
        """,
        version="0.1.0",
        lifespan=lifespan,
        docs_url=docs_url,
        redoc_url=redoc_url,
        openapi_url=openapi_url,
        openapi_tags=[{"name": tag, "description": text} for tag, text in tag_descriptions],
    )

    # Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
    app.add_middleware(AuthMiddleware)

    # CSRF: Double Submit Cookie pattern for state-changing requests
    app.add_middleware(CSRFMiddleware)

    # CORS: the unified nginx endpoint is same-origin by default. Split-origin
    # browser clients must opt in with this explicit Gateway allowlist so CORS
    # and CSRF origin checks share the same source of truth.
    cors_origins = sorted(get_configured_cors_origins())
    if cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Routers are mounted in this exact order.
    router_modules = (
        models,  # /api/models
        mcp,  # /api/mcp
        memory,  # /api/memory
        skills,  # /api/skills
        artifacts,  # /api/threads/{thread_id}/artifacts
        uploads,  # /api/threads/{thread_id}/uploads
        threads,  # thread cleanup, /api/threads/{thread_id}
        agents,  # /api/agents
        suggestions,  # /api/threads/{thread_id}/suggestions
        channels,  # /api/channels
        assistants_compat,  # assistants compatibility API (LangGraph Platform stub)
        auth,  # /api/v1/auth
        feedback,  # /api/threads/{thread_id}/runs/{run_id}/feedback
        thread_runs,  # thread runs API (LangGraph Platform-compatible runs lifecycle)
        runs,  # stateless runs API (stream/wait without a pre-existing thread)
    )
    for module in router_modules:
        app.include_router(module.router)

    @app.get("/health", tags=["health"])
    async def health_check() -> dict[str, str]:
        """Health check endpoint.

        Returns:
            Service health status information.
        """
        return {"status": "healthy", "service": "deer-flow-gateway"}

    return app
|
|
||||||
|
|
||||||
|
|
||||||
# Create app instance for uvicorn: built once at import time and exposed
# as the module-level name ``app``.
app = create_app()
|
|
||||||
|
|||||||
@@ -1,42 +0,0 @@
|
|||||||
"""Authentication module for DeerFlow.
|
|
||||||
|
|
||||||
This module provides:
|
|
||||||
- JWT-based authentication
|
|
||||||
- Provider Factory pattern for extensible auth methods
|
|
||||||
- UserRepository interface for storage backends (SQLite)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
|
||||||
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
|
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
|
||||||
from app.gateway.auth.models import User, UserResponse
|
|
||||||
from app.gateway.auth.password import hash_password, verify_password
|
|
||||||
from app.gateway.auth.providers import AuthProvider
|
|
||||||
from app.gateway.auth.repositories.base import UserRepository
|
|
||||||
|
|
||||||
# Public API of the auth package: every name re-exported by the imports above,
# grouped by the submodule area it comes from.
__all__ = [
    # Config
    "AuthConfig",
    "get_auth_config",
    "set_auth_config",
    # Errors
    "AuthErrorCode",
    "AuthErrorResponse",
    "TokenError",
    # JWT
    "TokenPayload",
    "create_access_token",
    "decode_token",
    # Password
    "hash_password",
    "verify_password",
    # Models
    "User",
    "UserResponse",
    # Providers
    "AuthProvider",
    "LocalAuthProvider",
    # Repository
    "UserRepository",
]
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
"""Authentication configuration for DeerFlow."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import secrets
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_SECRET_FILE = ".jwt_secret"
|
|
||||||
|
|
||||||
|
|
||||||
class AuthConfig(BaseModel):
    """JWT and auth-related configuration. Parsed once at startup.

    Note: the ``users`` table now lives in the shared persistence
    database managed by ``deerflow.persistence.engine``. The old
    ``users_db_path`` config key has been removed — user storage is
    configured through ``config.database`` like every other table.
    """

    # Required field (``...`` means no default): the JWT signing secret.
    jwt_secret: str = Field(
        ...,
        description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
    )
    # Token lifetime in days; validated to the inclusive range 1..30, default 7.
    token_expiry_days: int = Field(default=7, ge=1, le=30)
    # Optional GitHub OAuth client credentials; both default to None.
    oauth_github_client_id: str | None = Field(default=None)
    oauth_github_client_secret: str | None = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
_auth_config: AuthConfig | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _load_or_create_secret() -> str:
    """Return the JWT secret stored in ``{base_dir}/.jwt_secret``, creating it if needed.

    A non-empty secret already on disk is returned as-is (whitespace
    stripped). Otherwise a new ``secrets.token_urlsafe(32)`` value is
    generated, written to the file (opened with mode ``0o600``) and returned.

    Raises:
        RuntimeError: If the secret file cannot be read or written.
    """
    from deerflow.config.paths import get_paths

    secret_file = get_paths().base_dir / _SECRET_FILE

    # Step 1: reuse a previously persisted secret when there is one.
    try:
        existing = secret_file.read_text(encoding="utf-8").strip() if secret_file.exists() else ""
    except OSError as exc:
        raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc
    if existing:
        return existing

    # Step 2: nothing usable on disk, so mint a new secret and persist it.
    generated = secrets.token_urlsafe(32)
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    try:
        secret_file.parent.mkdir(parents=True, exist_ok=True)
        fd = os.open(secret_file, open_flags, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(generated)
    except OSError as exc:
        raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
    return generated
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_config() -> AuthConfig:
    """Return the process-wide ``AuthConfig``, building it on first use.

    On the first call the ``.env`` file is loaded and ``AUTH_JWT_SECRET``
    is read from the environment. When it is missing or empty, a secret is
    obtained from ``_load_or_create_secret()``, exported back into
    ``os.environ`` and a warning is logged. Later calls return the cached
    instance.
    """
    global _auth_config
    if _auth_config is not None:
        return _auth_config

    from dotenv import load_dotenv

    load_dotenv()
    secret = os.environ.get("AUTH_JWT_SECRET")
    if not secret:
        secret = _load_or_create_secret()
        os.environ["AUTH_JWT_SECRET"] = secret
        logger.warning(
            "⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
            "persisted to .jwt_secret. Sessions will survive restarts. "
            "For production, add AUTH_JWT_SECRET to your .env file: "
            'python -c "import secrets; print(secrets.token_urlsafe(32))"'
        )
    _auth_config = AuthConfig(jwt_secret=secret)
    return _auth_config
|
|
||||||
|
|
||||||
|
|
||||||
def set_auth_config(config: AuthConfig) -> None:
    """Set the global AuthConfig instance (for testing).

    Overwrites the module-level cache, so later ``get_auth_config()`` calls
    return ``config`` instead of parsing the environment.

    Args:
        config: The instance to install as the global configuration.
    """
    global _auth_config
    _auth_config = config
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""Write initial admin credentials to a restricted file instead of logs.
|
|
||||||
|
|
||||||
Logging secrets to stdout/stderr is a well-known CodeQL finding
|
|
||||||
(py/clear-text-logging-sensitive-data) — in production those logs
|
|
||||||
get collected into ELK/Splunk/etc and become a secret sprawl
|
|
||||||
source. This helper writes the credential to a 0600 file that only
|
|
||||||
the process user can read, and returns the path so the caller can
|
|
||||||
log **the path** (not the password) for the operator to pick up.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
|
|
||||||
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
|
|
||||||
|
|
||||||
|
|
||||||
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
    """Store the admin email and password in ``{base_dir}/admin_initial_credentials.txt``.

    The file is opened through ``os.open`` with create-or-truncate flags
    and mode ``0o600`` (the mode takes effect when the file is created),
    so an existing credentials file is overwritten in place.

    Args:
        email: Admin email written to the file.
        password: Admin password written to the file.
        label: Word placed in the file header; "initial" for a fresh
            creation, "reset" for a password reset.

    Returns:
        The resolved absolute :class:`Path` of the credentials file.
    """
    credentials_path = get_paths().base_dir / _CREDENTIAL_FILENAME
    credentials_path.parent.mkdir(parents=True, exist_ok=True)

    content = f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"

    # O_TRUNC rather than O_EXCL: the reset path must be able to rewrite an
    # already-existing file without unlinking it first.
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    descriptor = os.open(credentials_path, open_flags, 0o600)
    with os.fdopen(descriptor, "w", encoding="utf-8") as handle:
        handle.write(content)

    return credentials_path.resolve()
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
"""Local email/password authentication provider."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
from app.gateway.auth.password import hash_password_async, needs_rehash, verify_password_async
|
|
||||||
from app.gateway.auth.providers import AuthProvider
|
|
||||||
from app.gateway.auth.repositories.base import UserRepository
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAuthProvider(AuthProvider):
    """Auth provider that checks email/password pairs against a ``UserRepository``."""

    def __init__(self, repository: UserRepository):
        """Bind the provider to its storage backend.

        Args:
            repository: UserRepository implementation used for all lookups
                and writes.
        """
        self._repo = repository

    async def authenticate(self, credentials: dict) -> User | None:
        """Check an email/password pair.

        Args:
            credentials: Mapping with ``'email'`` and ``'password'`` keys.

        Returns:
            The matching User when the password verifies, otherwise None
            (missing fields, unknown email, account without a local
            password, or wrong password).
        """
        email = credentials.get("email")
        password = credentials.get("password")
        if not (email and password):
            return None

        user = await self._repo.get_user_by_email(email)
        # Unknown account, or an OAuth user without a local password.
        if user is None or user.password_hash is None:
            return None

        password_ok = await verify_password_async(password, user.password_hash)
        if not password_ok:
            return None

        if needs_rehash(user.password_hash):
            try:
                user.password_hash = await hash_password_async(password)
                await self._repo.update_user(user)
            except Exception:
                # Upgrading the stored hash is opportunistic: a failure here
                # must not turn a valid login into a rejected one.
                logger.warning("Failed to rehash password for user %s; login will still succeed", user.email, exc_info=True)

        return user

    async def get_user(self, user_id: str) -> User | None:
        """Look a user up by ID via the repository."""
        return await self._repo.get_user_by_id(user_id)

    async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
        """Create and persist a new local user.

        Args:
            email: User email address.
            password: Plain text password; hashed before storage. When
                empty or None, the user is stored without a password hash.
            system_role: Role to assign ("admin" or "user").
            needs_setup: If True, user must complete setup on first login.

        Returns:
            The User returned by the repository.
        """
        hashed = None
        if password:
            hashed = await hash_password_async(password)
        new_user = User(
            email=email,
            password_hash=hashed,
            system_role=system_role,
            needs_setup=needs_setup,
        )
        return await self._repo.create_user(new_user)

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Look a user up by OAuth provider name and provider-side ID."""
        return await self._repo.get_user_by_oauth(provider, oauth_id)

    async def count_users(self) -> int:
        """Return the repository's total user count."""
        return await self._repo.count_users()

    async def count_admin_users(self) -> int:
        """Return the repository's admin user count."""
        return await self._repo.count_admin_users()

    async def update_user(self, user: User) -> User:
        """Persist changes to an existing user via the repository."""
        return await self._repo.update_user(user)

    async def get_user_by_email(self, email: str) -> User | None:
        """Look a user up by email via the repository."""
        return await self._repo.get_user_by_email(email)
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
"""Password hashing utilities with versioned hash format.
|
|
||||||
|
|
||||||
Hash format: ``$dfv<N>$<bcrypt_hash>`` where ``<N>`` is the version.
|
|
||||||
|
|
||||||
- **v1** (legacy): ``bcrypt(password)`` — plain bcrypt, susceptible to
|
|
||||||
72-byte silent truncation.
|
|
||||||
- **v2** (current): ``bcrypt(b64(sha256(password)))`` — SHA-256 pre-hash
|
|
||||||
avoids the 72-byte truncation limit so the full password contributes
|
|
||||||
to the hash.
|
|
||||||
|
|
||||||
Verification auto-detects the version and falls back to v1 for hashes
|
|
||||||
without a prefix, so existing deployments upgrade transparently on next
|
|
||||||
login.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
|
|
||||||
_CURRENT_VERSION = 2
|
|
||||||
_PREFIX_V2 = "$dfv2$"
|
|
||||||
_PREFIX_V1 = "$dfv1$"
|
|
||||||
|
|
||||||
|
|
||||||
def _pre_hash_v2(password: str) -> bytes:
    """Return the base64-encoded SHA-256 digest of the UTF-8 password.

    The result always has the same length (44 bytes) regardless of the
    password length, which keeps the bcrypt input under its 72-byte limit.
    """
    digest = hashlib.sha256(password.encode("utf-8")).digest()
    return base64.b64encode(digest)
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
    """Hash a password in the current (v2) format.

    The password is SHA-256 pre-hashed, run through bcrypt with a fresh
    salt, and the result is returned with the ``$dfv2$`` version prefix.
    """
    bcrypt_bytes = bcrypt.hashpw(_pre_hash_v2(password), bcrypt.gensalt())
    return _PREFIX_V2 + bcrypt_bytes.decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a password against a stored hash of any supported version.

    ``$dfv2$…`` hashes are compared against the SHA-256 pre-hash of the
    password. ``$dfv1$…`` hashes and bare bcrypt hashes (pre-versioning
    data, treated as v1) are compared against the raw UTF-8 password.

    Returns:
        True on a match; False on a mismatch or when a ``ValueError`` is
        raised while checking (e.g. a malformed stored hash).
    """
    try:
        if hashed_password.startswith(_PREFIX_V2):
            candidate = _pre_hash_v2(plain_password)
            stored = hashed_password[len(_PREFIX_V2) :]
        else:
            candidate = plain_password.encode("utf-8")
            # Drops "$dfv1$" when present; bare hashes pass through unchanged.
            stored = hashed_password.removeprefix(_PREFIX_V1)
        return bcrypt.checkpw(candidate, stored.encode("utf-8"))
    except ValueError:
        # bcrypt raises ValueError for malformed or corrupt hashes (e.g., invalid salt).
        # Fail closed rather than crashing the request.
        return False
|
|
||||||
|
|
||||||
|
|
||||||
def needs_rehash(hashed_password: str) -> bool:
    """Tell whether a stored hash is in an older format than v2.

    Returns:
        False when the hash carries the ``$dfv2$`` prefix, True otherwise.
    """
    is_current = hashed_password.startswith(_PREFIX_V2)
    return not is_current
|
|
||||||
|
|
||||||
|
|
||||||
async def hash_password_async(password: str) -> str:
    """Hash a password using bcrypt (non-blocking).

    Runs the synchronous ``hash_password`` in a separate thread via
    ``asyncio.to_thread`` so the event loop is not blocked while bcrypt
    works.

    Args:
        password: Plain text password to hash.

    Returns:
        The hash string produced by ``hash_password``.
    """
    return await asyncio.to_thread(hash_password, password)
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash (non-blocking).

    Runs the synchronous ``verify_password`` in a separate thread via
    ``asyncio.to_thread`` so the event loop is not blocked while bcrypt
    works.

    Args:
        plain_password: Password supplied by the caller.
        hashed_password: Stored hash to check against.

    Returns:
        The boolean result of ``verify_password``.
    """
    return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
"""Auth provider abstraction."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider(ABC):
    """Abstract base class for authentication providers.

    Subclasses must implement both coroutines below. The return
    annotations are strings because ``User`` is imported at the bottom of
    this module.
    """

    @abstractmethod
    async def authenticate(self, credentials: dict) -> "User | None":
        """Authenticate user with given credentials.

        Args:
            credentials: Provider-specific credential mapping.

        Returns User if authentication succeeds, None otherwise.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_user(self, user_id: str) -> "User | None":
        """Retrieve user by ID.

        Returns User if found, None otherwise.
        """
        raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
# Import User at runtime to avoid circular imports
|
|
||||||
from app.gateway.auth.models import User # noqa: E402
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
"""User repository interface for abstracting database operations."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(LookupError):
    """Raised when a user repository operation targets a non-existent row.

    Subclass of :class:`LookupError` so callers that already catch
    ``LookupError`` for "missing entity" can keep working unchanged,
    while specific call sites can pin to this class to distinguish
    "concurrent delete during update" from other lookups.

    The class adds no attributes or behavior of its own.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class UserRepository(ABC):
    """Abstract interface for user data storage.

    Implement this interface to support a storage backend (for example
    SQLite). Every method is an abstract coroutine; the bodies here only
    raise ``NotImplementedError``.
    """

    @abstractmethod
    async def create_user(self, user: User) -> User:
        """Create a new user.

        Args:
            user: User object to create

        Returns:
            Created User with ID assigned

        Raises:
            ValueError: If email already exists
        """
        raise NotImplementedError

    @abstractmethod
    async def get_user_by_id(self, user_id: str) -> User | None:
        """Get user by ID.

        Args:
            user_id: User UUID as string

        Returns:
            User if found, None otherwise
        """
        raise NotImplementedError

    @abstractmethod
    async def get_user_by_email(self, email: str) -> User | None:
        """Get user by email.

        Args:
            email: User email address

        Returns:
            User if found, None otherwise
        """
        raise NotImplementedError

    @abstractmethod
    async def update_user(self, user: User) -> User:
        """Update an existing user.

        Args:
            user: User object with updated fields

        Returns:
            Updated User

        Raises:
            UserNotFoundError: If no row exists for ``user.id``. This is
                a hard failure (not a no-op) so callers cannot mistake a
                concurrent-delete race for a successful update.
        """
        raise NotImplementedError

    @abstractmethod
    async def count_users(self) -> int:
        """Return total number of registered users."""
        raise NotImplementedError

    @abstractmethod
    async def count_admin_users(self) -> int:
        """Return number of users with system_role == 'admin'."""
        raise NotImplementedError

    @abstractmethod
    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Get user by OAuth provider and ID.

        Args:
            provider: OAuth provider name (e.g. 'github', 'google')
            oauth_id: User ID from the OAuth provider

        Returns:
            User if found, None otherwise
        """
        raise NotImplementedError
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
"""SQLAlchemy-backed UserRepository implementation.
|
|
||||||
|
|
||||||
Uses the shared async session factory from
|
|
||||||
``deerflow.persistence.engine`` — the ``users`` table lives in the
|
|
||||||
same database as ``threads_meta``, ``runs``, ``run_events``, and
|
|
||||||
``feedback``.
|
|
||||||
|
|
||||||
Constructor takes the session factory directly (same pattern as the
|
|
||||||
other four repositories in ``deerflow.persistence.*``). Callers
|
|
||||||
construct this after ``init_engine_from_config()`` has run.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import UTC
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
|
|
||||||
from deerflow.persistence.user.model import UserRow
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteUserRepository(UserRepository):
|
|
||||||
"""Async user repository backed by the shared SQLAlchemy engine."""
|
|
||||||
|
|
||||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
    """Store the async session factory used to open a session per operation.

    Args:
        session_factory: SQLAlchemy ``async_sessionmaker`` producing
            ``AsyncSession`` objects.
    """
    self._sf = session_factory
|
|
||||||
|
|
||||||
# ── Converters ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@staticmethod
def _row_to_user(row: UserRow) -> User:
    """Convert a ``UserRow`` ORM object into a ``User`` model.

    The string ``id`` becomes a ``UUID`` and a naive ``created_at`` is
    tagged as UTC; all other columns are copied across unchanged.
    """
    created_at = row.created_at
    if not created_at.tzinfo:
        # SQLite loses tzinfo on read; reattach UTC so downstream
        # code can compare timestamps reliably.
        created_at = created_at.replace(tzinfo=UTC)

    return User(
        id=UUID(row.id),
        email=row.email,
        password_hash=row.password_hash,
        system_role=row.system_role,  # type: ignore[arg-type]
        created_at=created_at,
        oauth_provider=row.oauth_provider,
        oauth_id=row.oauth_id,
        needs_setup=row.needs_setup,
        token_version=row.token_version,
    )
|
|
||||||
|
|
||||||
@staticmethod
def _user_to_row(user: User) -> UserRow:
    """Convert a ``User`` model into a new ``UserRow`` ORM object.

    ``user.id`` is stored as its string form; every other field is copied
    across unchanged.
    """
    return UserRow(
        id=str(user.id),
        email=user.email,
        password_hash=user.password_hash,
        system_role=user.system_role,
        created_at=user.created_at,
        oauth_provider=user.oauth_provider,
        oauth_id=user.oauth_id,
        needs_setup=user.needs_setup,
        token_version=user.token_version,
    )
|
|
||||||
|
|
||||||
# ── CRUD ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def create_user(self, user: User) -> User:
|
|
||||||
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
|
|
||||||
row = self._user_to_row(user)
|
|
||||||
async with self._sf() as session:
|
|
||||||
session.add(row)
|
|
||||||
try:
|
|
||||||
await session.commit()
|
|
||||||
except IntegrityError as exc:
|
|
||||||
await session.rollback()
|
|
||||||
raise ValueError(f"Email already registered: {user.email}") from exc
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
|
||||||
async with self._sf() as session:
|
|
||||||
row = await session.get(UserRow, user_id)
|
|
||||||
return self._row_to_user(row) if row is not None else None
|
|
||||||
|
|
||||||
async def get_user_by_email(self, email: str) -> User | None:
|
|
||||||
stmt = select(UserRow).where(UserRow.email == email)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
return self._row_to_user(row) if row is not None else None
|
|
||||||
|
|
||||||
async def update_user(self, user: User) -> User:
|
|
||||||
async with self._sf() as session:
|
|
||||||
row = await session.get(UserRow, str(user.id))
|
|
||||||
if row is None:
|
|
||||||
# Hard fail on concurrent delete: callers (reset_admin,
|
|
||||||
# password change handlers, _ensure_admin_user) all
|
|
||||||
# fetched the user just before this call, so a missing
|
|
||||||
# row here means the row vanished underneath us. Silent
|
|
||||||
# success would let the caller log "password reset" for
|
|
||||||
# a row that no longer exists.
|
|
||||||
raise UserNotFoundError(f"User {user.id} no longer exists")
|
|
||||||
row.email = user.email
|
|
||||||
row.password_hash = user.password_hash
|
|
||||||
row.system_role = user.system_role
|
|
||||||
row.oauth_provider = user.oauth_provider
|
|
||||||
row.oauth_id = user.oauth_id
|
|
||||||
row.needs_setup = user.needs_setup
|
|
||||||
row.token_version = user.token_version
|
|
||||||
await session.commit()
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def count_users(self) -> int:
|
|
||||||
stmt = select(func.count()).select_from(UserRow)
|
|
||||||
async with self._sf() as session:
|
|
||||||
return await session.scalar(stmt) or 0
|
|
||||||
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
|
|
||||||
async with self._sf() as session:
|
|
||||||
return await session.scalar(stmt) or 0
|
|
||||||
|
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
|
||||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
return self._row_to_user(row) if row is not None else None
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""CLI tool to reset an admin password.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m app.gateway.auth.reset_admin
|
|
||||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
|
||||||
|
|
||||||
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
|
|
||||||
(mode 0600) instead of printing it, so CI / log aggregators never see
|
|
||||||
the cleartext secret.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import secrets
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.gateway.auth.credential_file import write_initial_credentials
|
|
||||||
from app.gateway.auth.password import hash_password
|
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
|
||||||
from deerflow.persistence.user.model import UserRow
|
|
||||||
|
|
||||||
|
|
||||||
async def _run(email: str | None) -> int:
    """Reset one user's password and return a process exit code.

    With ``email`` set, that user is targeted; otherwise the first row whose
    ``system_role`` is ``"admin"`` is used. The new password is handed to
    ``write_initial_credentials`` rather than printed.

    Returns:
        0 on success, 1 when the engine is unavailable or no user matched.
    """
    from deerflow.config import get_app_config
    from deerflow.persistence.engine import (
        close_engine,
        get_session_factory,
        init_engine_from_config,
    )

    app_config = get_app_config()
    await init_engine_from_config(app_config.database)
    try:
        sf = get_session_factory()
        if sf is None:
            print("Error: persistence engine not available (check config.database).", file=sys.stderr)
            return 1

        repo = SQLiteUserRepository(sf)

        user = None
        if email:
            user = await repo.get_user_by_email(email)
        else:
            # Find first admin via direct SELECT — repository does not
            # expose a "first admin" helper and we do not want to add
            # one just for this CLI.
            async with sf() as session:
                stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
                admin_row = (await session.execute(stmt)).scalar_one_or_none()
            if admin_row is not None:
                user = await repo.get_user_by_id(admin_row.id)

        if user is None:
            message = f"Error: user '{email}' not found." if email else "Error: no admin user found."
            print(message, file=sys.stderr)
            return 1

        new_password = secrets.token_urlsafe(16)
        user.password_hash = hash_password(new_password)
        # Incrementing the token version is stored together with the new hash.
        user.token_version += 1
        user.needs_setup = True
        await repo.update_user(user)

        cred_path = write_initial_credentials(user.email, new_password, label="reset")
        for line in (
            f"Password reset for: {user.email}",
            f"Credentials written to: {cred_path} (mode 0600)",
            "Next login will require setup (new email + password).",
        ):
            print(line)
        return 0
    finally:
        # Runs on every exit path after the engine was initialised.
        await close_engine()
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
    """Parse the command line, run the reset, and exit with its status code."""
    parser = argparse.ArgumentParser(description="Reset admin password")
    parser.add_argument("--email", help="Admin email (default: first admin found)")
    email = parser.parse_args().email

    sys.exit(asyncio.run(_run(email)))


if __name__ == "__main__":
    main()
|
|
||||||
@@ -1,126 +0,0 @@
|
|||||||
"""Global authentication middleware — fail-closed safety net.
|
|
||||||
|
|
||||||
Rejects unauthenticated requests to non-public paths with 401. When a
|
|
||||||
request passes the cookie check, resolves the JWT payload to a real
|
|
||||||
``User`` object and stamps it into both ``request.state.user`` and the
|
|
||||||
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
|
||||||
owner filtering works automatically via the sentinel pattern.
|
|
||||||
|
|
||||||
Fine-grained permission checks remain in authz.py decorators.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, Response
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
|
||||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
|
||||||
|
|
||||||
# Path prefixes that are served without authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = ("/health", "/docs", "/redoc", "/openapi.json")

# Auth routes that are public only on an exact match (trailing slashes ignored).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
    (
        "/api/v1/auth/login/local",
        "/api/v1/auth/register",
        "/api/v1/auth/logout",
        "/api/v1/auth/setup-status",
        "/api/v1/auth/initialize",
    )
)


def _is_public(path: str) -> bool:
    """Return True when ``path`` may be served without authentication.

    A path is public if, after stripping trailing slashes, it equals one of
    ``_PUBLIC_EXACT_PATHS``, or if the raw path starts with one of
    ``_PUBLIC_PATH_PREFIXES``.
    """
    if path.rstrip("/") in _PUBLIC_EXACT_PATHS:
        return True
    # NOTE(review): this is a plain string-prefix test, so "/healthz" or
    # "/docs-anything" also count as public — confirm that is intended.
    return path.startswith(_PUBLIC_PATH_PREFIXES)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthMiddleware(BaseHTTPMiddleware):
    """Reject requests to non-public paths that carry no valid identity.

    For a non-public path the middleware proceeds as follows:

    1. A valid internal-auth header resolves to the internal user and skips
       the cookie/JWT checks.
    2. Otherwise a missing ``access_token`` cookie yields 401
       ``NOT_AUTHENTICATED``.
    3. Otherwise ``get_current_user_from_request`` resolves the user; any
       ``HTTPException`` it raises is rendered as a JSON response carrying
       that exception's status code and detail.

    On success the user is stored on ``request.state.user``, an
    ``AuthContext`` is stored on ``request.state.auth``, and the
    ``deerflow.runtime.user_context`` contextvar is set for the duration of
    the downstream call.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    @staticmethod
    def _not_authenticated_response() -> JSONResponse:
        """Build the 401 body returned when no session cookie is present."""
        error = AuthErrorResponse(
            code=AuthErrorCode.NOT_AUTHENTICATED,
            message="Authentication required",
        )
        return JSONResponse(status_code=401, content={"detail": error.model_dump()})

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if _is_public(request.url.path):
            return await call_next(request)

        internal_header = request.headers.get(INTERNAL_AUTH_HEADER_NAME)
        internal_user = get_internal_user() if is_valid_internal_auth_token(internal_header) else None

        # Non-public path: require session cookie
        if internal_user is None and not request.cookies.get("access_token"):
            return self._not_authenticated_response()

        # Strict JWT validation: junk or expired tokens are rejected here
        # rather than passed through, so routes without their own checks do
        # not accept any cookie-shaped string as authentication. The strict
        # resolver is used so its specific error detail reaches the client.
        # BaseHTTPMiddleware doesn't let HTTPException bubble up, so it is
        # caught and rendered as a JSONResponse.
        from app.gateway.deps import get_current_user_from_request

        if internal_user is None:
            try:
                user = await get_current_user_from_request(request)
            except HTTPException as exc:
                return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
        else:
            user = internal_user

        # Populate request.state.auth as well as request.state.user so that
        # @require_permission finds an existing context instead of
        # re-authenticating the request.
        request.state.user = user
        request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
        context_token = set_current_user(user)
        try:
            return await call_next(request)
        finally:
            reset_current_user(context_token)
|
|
||||||
@@ -1,301 +0,0 @@
|
|||||||
"""Authorization decorators and context for DeerFlow.
|
|
||||||
|
|
||||||
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
|
|
||||||
1. Use ``@require_auth`` on routes that need authentication
|
|
||||||
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
|
|
||||||
3. Decorators are applied bottom to top, so at request time the topmost decorator's wrapper runs first
|
|
||||||
|
|
||||||
**Example:**
|
|
||||||
|
|
||||||
@router.get("/{thread_id}")
|
|
||||||
@require_auth
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_thread(thread_id: str, request: Request):
|
|
||||||
# User is authenticated and has threads:read permission
|
|
||||||
...
|
|
||||||
|
|
||||||
**Permission Model:**
|
|
||||||
|
|
||||||
- threads:read - View thread
|
|
||||||
- threads:write - Create/update thread
|
|
||||||
- threads:delete - Delete thread
|
|
||||||
- runs:create - Run agent
|
|
||||||
- runs:read - View run
|
|
||||||
- runs:cancel - Cancel run
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import inspect
|
|
||||||
from collections.abc import Callable
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
# Permission constants
class Permissions:
    """Permission constants for resource:action format.

    Each value is the string ``"<resource>:<action>"``, the same form that
    ``AuthContext.has_permission`` builds when checking a permission.
    """

    # Threads
    THREADS_READ = "threads:read"
    THREADS_WRITE = "threads:write"
    THREADS_DELETE = "threads:delete"

    # Runs
    RUNS_CREATE = "runs:create"
    RUNS_READ = "runs:read"
    RUNS_CANCEL = "runs:cancel"
|
|
||||||
|
|
||||||
|
|
||||||
class AuthContext:
    """Authentication context for the current request.

    Stored in ``request.state.auth`` by the auth middleware / decorators.

    Attributes:
        user: The authenticated user, or None if anonymous.
        permissions: Permission strings (e.g., "threads:read"). This is a
            private copy of the list passed to the constructor.
    """

    __slots__ = ("user", "permissions")

    def __init__(self, user: User | None = None, permissions: list[str] | None = None):
        self.user = user
        # Copy instead of aliasing: callers commonly pass a shared
        # module-level list, and mutating one context's permissions must
        # never change that list or any other context.
        self.permissions = list(permissions) if permissions else []

    @property
    def is_authenticated(self) -> bool:
        """Return True when a user is attached to this context."""
        return self.user is not None

    def has_permission(self, resource: str, action: str) -> bool:
        """Check if context has permission for resource:action.

        Args:
            resource: Resource name (e.g., "threads")
            action: Action name (e.g., "read")

        Returns:
            True if ``"<resource>:<action>"`` is in ``permissions``.
        """
        return f"{resource}:{action}" in self.permissions

    def require_user(self) -> User:
        """Return the attached user or raise 401.

        Raises:
            HTTPException: 401 if no user is attached.
        """
        # Same notion of "authenticated" as the ``is_authenticated`` property.
        if self.user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        return self.user
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_context(request: Request) -> AuthContext | None:
    """Return ``request.state.auth`` if it has been set, otherwise ``None``."""
    try:
        return request.state.auth
    except AttributeError:
        return None
|
|
||||||
|
|
||||||
|
|
||||||
# Every permission string defined on ``Permissions``. ``_authenticate``
# passes this list to each ``AuthContext`` it builds for an authenticated user.
_ALL_PERMISSIONS: list[str] = [
    Permissions.THREADS_READ,
    Permissions.THREADS_WRITE,
    Permissions.THREADS_DELETE,
    Permissions.RUNS_CREATE,
    Permissions.RUNS_READ,
    Permissions.RUNS_CANCEL,
]
|
|
||||||
|
|
||||||
|
|
||||||
def _make_test_request_stub() -> Any:
    """Build a minimal stand-in for a request when a handler is called directly.

    The stub has an empty ``state`` namespace, an empty ``cookies`` dict, and
    ``_deerflow_test_bypass_auth`` set to True, which the auth decorators
    check in order to skip authentication.
    """
    stub = SimpleNamespace()
    stub.state = SimpleNamespace()
    stub.cookies = {}
    stub._deerflow_test_bypass_auth = True
    return stub
|
|
||||||
|
|
||||||
|
|
||||||
async def _authenticate(request: Request) -> AuthContext:
    """Resolve the request's user and wrap it in an ``AuthContext``.

    The lookup is delegated to ``deps.get_optional_user_from_request``. When
    it returns ``None`` the context is anonymous with no permissions;
    otherwise the user receives every permission in ``_ALL_PERMISSIONS``.
    """
    from app.gateway.deps import get_optional_user_from_request

    user = await get_optional_user_from_request(request)
    if user is not None:
        # In future, permissions could be stored in user record
        return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
    return AuthContext(user=None, permissions=[])
|
|
||||||
|
|
||||||
|
|
||||||
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
"""Decorator that authenticates the request and enforces authentication.
|
|
||||||
|
|
||||||
Independently raises HTTP 401 for unauthenticated requests, regardless of
|
|
||||||
whether ``AuthMiddleware`` is present in the ASGI stack. Sets the resolved
|
|
||||||
``AuthContext`` on ``request.state.auth`` for downstream handlers.
|
|
||||||
|
|
||||||
Must be placed ABOVE other decorators (executes after them).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
@router.get("/{thread_id}")
|
|
||||||
@require_auth # Bottom decorator (executes first after permission check)
|
|
||||||
@require_permission("threads", "read")
|
|
||||||
async def get_thread(thread_id: str, request: Request):
|
|
||||||
auth: AuthContext = request.state.auth
|
|
||||||
...
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: 401 if the request is unauthenticated.
|
|
||||||
ValueError: If 'request' parameter is missing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
request = kwargs.get("request")
|
|
||||||
if request is None:
|
|
||||||
# Unit tests may call decorated handlers directly without a
|
|
||||||
# FastAPI Request object. Inject a minimal request stub when
|
|
||||||
# the wrapped function declares `request`.
|
|
||||||
if "request" in inspect.signature(func).parameters:
|
|
||||||
kwargs["request"] = _make_test_request_stub()
|
|
||||||
else:
|
|
||||||
raise ValueError("require_auth decorator requires 'request' parameter")
|
|
||||||
request = kwargs["request"]
|
|
||||||
|
|
||||||
if getattr(request, "_deerflow_test_bypass_auth", False):
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
# Authenticate and set context
|
|
||||||
auth_context = await _authenticate(request)
|
|
||||||
request.state.auth = auth_context
|
|
||||||
|
|
||||||
if not auth_context.is_authenticated:
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication required")
|
|
||||||
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def require_permission(
    resource: str,
    action: str,
    owner_check: bool = False,
    require_existing: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorator factory that enforces ``resource:action`` on a route handler.

    If ``request.state.auth`` is not already populated, the wrapper
    authenticates the request itself via ``_authenticate``.

    Args:
        resource: Resource name (e.g., "threads", "runs")
        action: Action name (e.g., "read", "write", "delete")
        owner_check: If True, the thread store's ``check_access`` is asked
            whether the current user may access ``thread_id``; the handler
            must receive a ``thread_id`` keyword argument.
        require_existing: Forwarded to ``check_access``; only meaningful
            with ``owner_check=True``. Intended for destructive / mutating
            routes (DELETE, PATCH, state-update) so that a missing
            ``threads_meta`` row is treated as a denial rather than as an
            untracked legacy thread.

    Usage:
        # Read-style: legacy untracked threads are allowed
        @require_permission("threads", "read", owner_check=True)
        async def get_thread(thread_id: str, request: Request):
            ...

        # Destructive: thread row MUST exist and be owned by caller
        @require_permission("threads", "delete", owner_check=True, require_existing=True)
        async def delete_thread(thread_id: str, request: Request):
            ...

    Raises:
        HTTPException 401: If the user is anonymous
        HTTPException 403: If the user lacks the permission
        HTTPException 404: If owner_check=True and access is denied
        ValueError: If owner_check=True but 'thread_id' parameter is missing
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            if kwargs.get("request") is None:
                # Unit tests may call decorated route handlers directly
                # without a FastAPI Request object. Handlers that do not
                # declare `request` are simply passed through.
                if "request" not in inspect.signature(func).parameters:
                    return await func(*args, **kwargs)
                kwargs["request"] = _make_test_request_stub()
            request = kwargs["request"]

            if getattr(request, "_deerflow_test_bypass_auth", False):
                return await func(*args, **kwargs)

            auth: AuthContext | None = getattr(request.state, "auth", None)
            if auth is None:
                auth = await _authenticate(request)
                request.state.auth = auth

            if not auth.is_authenticated:
                raise HTTPException(status_code=401, detail="Authentication required")

            # Check permission
            if not auth.has_permission(resource, action):
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied: {resource}:{action}",
                )

            if not owner_check:
                return await func(*args, **kwargs)

            # Owner check for thread-specific resources. Ownership is
            # verified through the thread store's ``check_access``.
            # NOTE(review): per the original comment, check_access allows
            # missing rows and rows with a NULL user_id unless
            # require_existing is set — confirm against the store.
            thread_id = kwargs.get("thread_id")
            if thread_id is None:
                raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")

            from app.gateway.deps import get_thread_store

            store = get_thread_store(request)
            allowed = await store.check_access(
                thread_id,
                str(auth.user.id),
                require_existing=require_existing,
            )
            if not allowed:
                raise HTTPException(
                    status_code=404,
                    detail=f"Thread {thread_id} not found",
                )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .lifespan import lifespan_manager
|
||||||
|
|
||||||
|
__all__ = ["lifespan_manager"]
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]]
|
||||||
|
|
||||||
|
|
||||||
|
class LifespanManager:
    """Collects lifespan hooks and combines them into one FastAPI lifespan."""

    def __init__(self) -> None:
        self._lifespans: list[LifespanFunc] = []

    def register(self, func: LifespanFunc) -> LifespanFunc:
        """
        Register a lifespan hook; registering the same hook twice is a no-op.

        :param func: lifespan hook
        :return: ``func`` itself, so this method can be used as a decorator
        """
        if func in self._lifespans:
            return func
        self._lifespans.append(func)
        return func

    def build(self) -> LifespanFunc:
        """
        Build the combined lifespan hook.

        :return: an async context manager factory that enters every
            registered hook in registration order and exits them in reverse
        """

        @asynccontextmanager
        async def combined_lifespan(app: FastAPI):  # noqa: ANN202
            merged: dict[str, Any] = {}
            async with AsyncExitStack() as stack:
                for hook in self._lifespans:
                    yielded = await stack.enter_async_context(hook(app))
                    # Hooks may yield a dict of state; anything else is ignored.
                    if isinstance(yielded, dict):
                        merged.update(yielded)

                # Mirror the merged state onto app.state as attributes.
                for name, value in merged.items():
                    setattr(app.state, name, value)

                yield merged or None

        return combined_lifespan
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton lifespan_manager instance
|
||||||
|
lifespan_manager = LifespanManager()
|
||||||
@@ -8,7 +8,7 @@ class GatewayConfig(BaseModel):
|
|||||||
|
|
||||||
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
|
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
|
||||||
port: int = Field(default=8001, description="Port to bind the gateway server")
|
port: int = Field(default=8001, description="Port to bind the gateway server")
|
||||||
enable_docs: bool = Field(default=True, description="Enable Swagger/ReDoc/OpenAPI endpoints")
|
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
|
||||||
|
|
||||||
|
|
||||||
_gateway_config: GatewayConfig | None = None
|
_gateway_config: GatewayConfig | None = None
|
||||||
@@ -18,9 +18,10 @@ def get_gateway_config() -> GatewayConfig:
|
|||||||
"""Get gateway config, loading from environment if available."""
|
"""Get gateway config, loading from environment if available."""
|
||||||
global _gateway_config
|
global _gateway_config
|
||||||
if _gateway_config is None:
|
if _gateway_config is None:
|
||||||
|
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
_gateway_config = GatewayConfig(
|
_gateway_config = GatewayConfig(
|
||||||
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
|
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
|
||||||
port=int(os.getenv("GATEWAY_PORT", "8001")),
|
port=int(os.getenv("GATEWAY_PORT", "8001")),
|
||||||
enable_docs=os.getenv("GATEWAY_ENABLE_DOCS", "true").lower() == "true",
|
cors_origins=cors_origins_str.split(","),
|
||||||
)
|
)
|
||||||
return _gateway_config
|
return _gateway_config
|
||||||
|
|||||||
@@ -1,229 +0,0 @@
|
|||||||
"""CSRF protection middleware for FastAPI.
|
|
||||||
|
|
||||||
Per RFC-001:
|
|
||||||
State-changing operations require CSRF protection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import secrets
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from urllib.parse import urlsplit
|
|
||||||
|
|
||||||
from fastapi import Request, Response
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
|
|
||||||
CSRF_COOKIE_NAME = "csrf_token"
|
|
||||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
|
||||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
|
||||||
|
|
||||||
|
|
||||||
def is_secure_request(request: Request) -> bool:
    """Return True when the resolved request scheme is ``https``."""
    scheme = _request_scheme(request)
    return scheme == "https"
|
|
||||||
|
|
||||||
|
|
||||||
def generate_csrf_token() -> str:
    """Return a fresh URL-safe token built from ``CSRF_TOKEN_LENGTH`` random bytes."""
    token = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
    return token
|
|
||||||
|
|
||||||
|
|
||||||
def should_check_csrf(request: Request) -> bool:
    """Decide whether ``request`` must pass CSRF validation.

    Only POST, PUT, DELETE and PATCH are checked; every other method is
    exempt. Among the checked methods, ``/api/v1/auth/me`` (trailing slashes
    ignored) is also exempt.
    """
    if request.method in ("POST", "PUT", "DELETE", "PATCH"):
        # Exempt /api/v1/auth/me endpoint
        return request.url.path.rstrip("/") != "/api/v1/auth/me"
    return False
|
|
||||||
|
|
||||||
|
|
||||||
# Auth routes that ``is_auth_endpoint`` recognises (exact match only).
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
    (
        "/api/v1/auth/login/local",
        "/api/v1/auth/logout",
        "/api/v1/auth/register",
        "/api/v1/auth/initialize",
    )
)


def is_auth_endpoint(request: Request) -> bool:
    """Return True when the request path is one of ``_AUTH_EXEMPT_PATHS``.

    Trailing slashes are ignored. Auth endpoints don't need CSRF validation
    on first call (no token).
    """
    normalized_path = request.url.path.rstrip("/")
    return normalized_path in _AUTH_EXEMPT_PATHS
|
|
||||||
|
|
||||||
|
|
||||||
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
    """Return the lower-cased ``host[:port]``, dropping the scheme's default port.

    A hostname containing ``:`` that is not already bracketed is wrapped in
    ``[...]``. The port is omitted when it is ``None``, 80 for ``http``, or
    443 for ``https``.
    """
    default_ports = {"http": 80, "https": 443}

    host = hostname.lower()
    needs_brackets = ":" in host and not host.startswith("[")
    if needs_brackets:
        host = f"[{host}]"

    if port is None or default_ports.get(scheme) == port:
        return host
    return f"{host}:{port}"
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_origin(origin: str) -> str | None:
    """Normalize ``origin`` to ``scheme://host[:port]``, or return ``None``.

    ``None`` is returned when the value cannot be parsed, has an invalid
    port, uses a scheme other than http/https, has no hostname, or carries
    anything beyond scheme/host/port (credentials, path, query, fragment).
    """
    try:
        parts = urlsplit(origin.strip())
        # Reading .port raises ValueError for a malformed or out-of-range port.
        port = parts.port
    except ValueError:
        return None

    scheme = parts.scheme.lower()
    if scheme not in ("http", "https"):
        return None
    if not parts.hostname:
        return None

    # Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
    extras = (parts.username, parts.password, parts.path, parts.query, parts.fragment)
    if any(extras):
        return None

    host = _host_with_optional_port(parts.hostname, port, scheme)
    return f"{scheme}://{host}"
|
|
||||||
|
|
||||||
|
|
||||||
def _configured_cors_origins() -> set[str]:
    """Return the normalized origins listed in ``GATEWAY_CORS_ORIGINS``.

    The variable is read as a comma-separated list. Blank entries, the
    wildcard ``*``, and entries that ``_normalize_origin`` rejects are
    dropped.
    """
    raw_value = os.environ.get("GATEWAY_CORS_ORIGINS", "")
    entries = (piece.strip() for piece in raw_value.split(","))
    candidates = (_normalize_origin(entry) for entry in entries if entry and entry != "*")
    return {candidate for candidate in candidates if candidate}
|
|
||||||
|
|
||||||
|
|
||||||
def get_configured_cors_origins() -> set[str]:
    """Public accessor for the normalized ``GATEWAY_CORS_ORIGINS`` allowlist."""
    configured = _configured_cors_origins()
    return configured
|
|
||||||
|
|
||||||
|
|
||||||
def _first_header_value(value: str | None) -> str | None:
    """Pick the first element of a comma-separated header value.

    Returns ``None`` when the header is missing/empty or when the first
    element is blank after stripping whitespace.
    """
    if not value:
        return None
    head, _, _ = value.partition(",")
    head = head.strip()
    return head if head else None
|
|
||||||
|
|
||||||
|
|
||||||
def _forwarded_param(request: Request, name: str) -> str | None:
    """Look up one parameter in the first entry of the ``Forwarded`` header.

    The entry is split on ``;`` into ``key=value`` pairs. Keys are compared
    lowercased against *name*; the matching value is returned with surrounding
    whitespace and double quotes removed, or ``None`` when it is empty or no
    pair matches.
    """
    entry = _first_header_value(request.headers.get("forwarded"))
    if not entry:
        return None

    for segment in entry.split(";"):
        key, sep, raw = segment.strip().partition("=")
        # Skip malformed pairs (no "=") and parameters we were not asked for.
        if not sep or key.lower() != name:
            continue
        return raw.strip().strip('"') or None
    return None
|
|
||||||
|
|
||||||
|
|
||||||
def _request_scheme(request: Request) -> str:
    """Resolve the scheme the client used, lowercased.

    Precedence: ``proto`` from the first ``Forwarded`` entry, then the first
    ``X-Forwarded-Proto`` value, then the scheme of ``request.url``. The
    headers are taken at face value; nothing here checks which proxy set them.
    """
    scheme = _forwarded_param(request, "proto")
    if not scheme:
        scheme = _first_header_value(request.headers.get("x-forwarded-proto"))
    if not scheme:
        scheme = request.url.scheme
    return scheme.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def _request_origin(request: Request) -> str | None:
    """Reconstruct the origin of the URL the browser addressed.

    Host precedence: ``host`` from the first ``Forwarded`` entry, then
    ``X-Forwarded-Host``, then ``Host``, then the netloc of ``request.url``.
    ``X-Forwarded-Port`` is appended only when the chosen host carries no port
    of its own. The result is passed through ``_normalize_origin``.
    """
    scheme = _request_scheme(request)
    headers = request.headers
    host = (
        _forwarded_param(request, "host")
        or _first_header_value(headers.get("x-forwarded-host"))
        or headers.get("host")
        or request.url.netloc
    )

    port_hint = _first_header_value(headers.get("x-forwarded-port"))
    if port_hint:
        # Only look after the last "]" so the colons inside a bracketed IPv6
        # literal are not mistaken for a port separator.
        tail = host.rsplit("]", 1)[-1]
        if ":" not in tail:
            host = f"{host}:{port_hint}"

    return _normalize_origin(f"{scheme}://{host}")
|
|
||||||
|
|
||||||
|
|
||||||
def is_allowed_auth_origin(request: Request) -> bool:
    """Decide whether an auth request's ``Origin`` header is acceptable.

    Login/register/initialize skip the double-submit token because a
    first-time browser client has no CSRF token yet. They still create a
    session cookie, so a request carrying a foreign ``Origin`` must be refused
    to prevent login CSRF / session fixation.

    Returns ``True`` when there is no ``Origin`` header (non-browser clients
    such as curl or mobile integrations), when the origin is listed in the
    configured allowlist, or when it equals the origin the request itself
    targets. Returns ``False`` for an unparseable origin or any other origin.
    """
    raw_origin = request.headers.get("origin")
    if not raw_origin:
        return True

    candidate = _normalize_origin(raw_origin)
    if candidate is None:
        return False

    target = _request_origin(request)
    if candidate in _configured_cors_origins():
        return True
    return target is not None and candidate == target
|
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware that implements CSRF protection using Double Submit Cookie pattern.

    Requests selected by ``should_check_csrf`` are handled in one of two ways:

    * auth endpoints skip the token comparison but must pass
      ``is_allowed_auth_origin``;
    * every other checked request must carry the same token in the CSRF
      cookie and in the CSRF header.

    Rejections are returned as 403 JSON responses. After a POST to an auth
    endpoint, a freshly generated CSRF cookie is attached to the response.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """Validate the request, forward it, then refresh the CSRF cookie if needed."""
        _is_auth = is_auth_endpoint(request)
        csrf_required = should_check_csrf(request)

        if csrf_required and _is_auth and not is_allowed_auth_origin(request):
            return JSONResponse(
                status_code=403,
                content={"detail": "Cross-site auth request denied."},
            )

        if csrf_required and not _is_auth:
            cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
            header_token = request.headers.get(CSRF_HEADER_NAME)

            if not cookie_token or not header_token:
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
                )

            # Compare as bytes: compare_digest raises TypeError for str
            # arguments containing non-ASCII characters, which would turn a
            # malformed token into a 500 instead of the intended 403.
            if not secrets.compare_digest(cookie_token.encode("utf-8"), header_token.encode("utf-8")):
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token mismatch."},
                )

        response = await call_next(request)

        # For auth endpoints that set up session, also set CSRF cookie
        if _is_auth and request.method == "POST":
            # Generate a new CSRF token for the session
            csrf_token = generate_csrf_token()
            is_https = is_secure_request(request)
            response.set_cookie(
                key=CSRF_COOKIE_NAME,
                value=csrf_token,
                httponly=False,  # Must be JS-readable for Double Submit Cookie pattern
                secure=is_https,
                samesite="strict",
            )

        return response
|
|
||||||
|
|
||||||
|
|
||||||
def get_csrf_token(request: Request) -> str | None:
    """Read the CSRF token from the cookies of the current request.

    Handy for server-side rendering, where the token has to be embedded in a
    form or header. Returns ``None`` when the cookie is absent.
    """
    cookies = request.cookies
    return cookies.get(CSRF_COOKIE_NAME)
|
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
from app.gateway.dependencies.checkpointer import (
|
||||||
|
CurrentCheckpointer,
|
||||||
|
get_checkpointer,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.security.dependencies import (
|
||||||
|
CurrentAuthService,
|
||||||
|
CurrentUserRepository,
|
||||||
|
get_auth_service,
|
||||||
|
get_current_user_from_request,
|
||||||
|
get_current_user_id,
|
||||||
|
get_optional_user_from_request,
|
||||||
|
get_user_repository,
|
||||||
|
)
|
||||||
|
from app.gateway.dependencies.db import (
|
||||||
|
CurrentSession,
|
||||||
|
CurrentSessionTransaction,
|
||||||
|
get_db_session,
|
||||||
|
get_db_session_transaction,
|
||||||
|
)
|
||||||
|
from app.gateway.dependencies.repositories import (
|
||||||
|
CurrentFeedbackRepository,
|
||||||
|
CurrentRunRepository,
|
||||||
|
CurrentThreadMetaRepository,
|
||||||
|
CurrentThreadMetaStorage,
|
||||||
|
get_feedback_repository,
|
||||||
|
get_run_repository,
|
||||||
|
get_thread_meta_repository,
|
||||||
|
get_thread_meta_storage,
|
||||||
|
)
|
||||||
|
from app.gateway.dependencies.stream_bridge import (
|
||||||
|
CurrentStreamBridge,
|
||||||
|
get_stream_bridge,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CurrentCheckpointer",
|
||||||
|
"CurrentAuthService",
|
||||||
|
"CurrentFeedbackRepository",
|
||||||
|
"CurrentRunRepository",
|
||||||
|
"CurrentSession",
|
||||||
|
"CurrentSessionTransaction",
|
||||||
|
"CurrentStreamBridge",
|
||||||
|
"CurrentThreadMetaRepository",
|
||||||
|
"CurrentThreadMetaStorage",
|
||||||
|
"CurrentUserRepository",
|
||||||
|
"get_auth_service",
|
||||||
|
"get_checkpointer",
|
||||||
|
"get_current_user_from_request",
|
||||||
|
"get_current_user_id",
|
||||||
|
"get_db_session",
|
||||||
|
"get_db_session_transaction",
|
||||||
|
"get_feedback_repository",
|
||||||
|
"get_optional_user_from_request",
|
||||||
|
"get_run_repository",
|
||||||
|
"get_stream_bridge",
|
||||||
|
"get_thread_meta_repository",
|
||||||
|
"get_thread_meta_storage",
|
||||||
|
"get_user_repository",
|
||||||
|
]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpointer(request: Request) -> Checkpointer:
    """Return ``app.state.persistence.checkpointer`` for this request's app.

    Raises:
        HTTPException: 503 when ``app.state`` has no persistence object, or
            when that object has no checkpointer.
    """
    state = request.app.state

    persistence = getattr(state, "persistence", None)
    if persistence is None:
        raise HTTPException(status_code=503, detail="Persistence not available")

    found = getattr(persistence, "checkpointer", None)
    if found is not None:
        return found
    raise HTTPException(status_code=503, detail="Checkpointer not available")
|
||||||
|
|
||||||
|
|
||||||
|
# Annotated alias: a parameter typed with it is a Checkpointer resolved through Depends(get_checkpointer).
CurrentCheckpointer = Annotated[Checkpointer, Depends(get_checkpointer)]
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession]:
    """Return the session factory stored on ``app.state.persistence``.

    Raises:
        HTTPException: 503 when the persistence object or its
            ``session_factory`` is missing.
    """
    # Read ``persistence`` defensively: when it was never set, the attribute
    # lookup would otherwise fail and surface as a 500 instead of a 503.
    persistence = getattr(request.app.state, "persistence", None)
    factory = getattr(persistence, "session_factory", None)
    if factory is None:
        raise HTTPException(status_code=503, detail="Database session factory not available")
    return factory
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """Yield a session that is never committed here; meant for read-only endpoints.

    The session comes from the factory returned by ``_get_session_factory``
    and is released when the ``async with`` block exits.
    """
    make_session = _get_session_factory(request)
    async with make_session() as db:
        yield db
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_session_transaction(request: Request) -> AsyncIterator[AsyncSession]:
    """Yield a session, committing afterwards or rolling back on failure.

    The commit runs once the consumer of the yielded session finishes without
    raising. Any ``Exception`` (including one raised by the commit itself)
    triggers a rollback and is re-raised.
    """
    make_session = _get_session_factory(request)
    async with make_session() as db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            raise
|
||||||
|
|
||||||
|
|
||||||
|
# Annotated alias: an AsyncSession resolved through Depends(get_db_session) (no commit is issued).
CurrentSession = Annotated[AsyncSession, Depends(get_db_session)]
|
||||||
|
# Annotated alias: an AsyncSession resolved through Depends(get_db_session_transaction) (commit/rollback handled by the dependency).
CurrentSessionTransaction = Annotated[AsyncSession, Depends(get_db_session_transaction)]
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from store.repositories.contracts import (
|
||||||
|
FeedbackRepositoryProtocol,
|
||||||
|
RunRepositoryProtocol,
|
||||||
|
ThreadMetaRepositoryProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_state(request: Request, attr: str, label: str):
    """Return ``app.state.<attr>``, or fail with 503 when it is unset or ``None``.

    *label* is the human-readable name used in the error detail
    (``"<label> not available"``).
    """
    resolved = getattr(request.app.state, attr, None)
    if resolved is not None:
        return resolved
    raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||||
|
|
||||||
|
|
||||||
|
def get_run_repository(request: Request) -> RunRepositoryProtocol:
    """Return ``app.state.run_store``; 503 ("Run store not available") when unset."""
    return _require_state(request, attr="run_store", label="Run store")
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_meta_repository(request: Request) -> ThreadMetaRepositoryProtocol:
    """Return ``app.state.thread_meta_repo``; 503 ("Thread metadata store not available") when unset."""
    return _require_state(request, attr="thread_meta_repo", label="Thread metadata store")
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_meta_storage(request: Request) -> ThreadMetaStorage:
    """Return ``app.state.thread_meta_storage``; 503 ("Thread metadata storage not available") when unset."""
    return _require_state(request, attr="thread_meta_storage", label="Thread metadata storage")
|
||||||
|
|
||||||
|
|
||||||
|
def get_feedback_repository(request: Request) -> FeedbackRepositoryProtocol:
    """Return ``app.state.feedback_repo``; 503 ("Feedback not available") when unset."""
    return _require_state(request, attr="feedback_repo", label="Feedback")
|
||||||
|
|
||||||
|
|
||||||
|
# Annotated alias: a run repository resolved through Depends(get_run_repository).
CurrentRunRepository = Annotated[RunRepositoryProtocol, Depends(get_run_repository)]
|
||||||
|
# Annotated alias: a thread-metadata repository resolved through Depends(get_thread_meta_repository).
CurrentThreadMetaRepository = Annotated[ThreadMetaRepositoryProtocol, Depends(get_thread_meta_repository)]
|
||||||
|
# Annotated alias: the thread-metadata storage resolved through Depends(get_thread_meta_storage).
CurrentThreadMetaStorage = Annotated[ThreadMetaStorage, Depends(get_thread_meta_storage)]
|
||||||
|
# Annotated alias: a feedback repository resolved through Depends(get_feedback_repository).
CurrentFeedbackRepository = Annotated[FeedbackRepositoryProtocol, Depends(get_feedback_repository)]
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from deerflow.runtime import StreamBridge
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_bridge(request: Request) -> StreamBridge:
    """Return ``app.state.stream_bridge`` for this request's app.

    Raises:
        HTTPException: 503 when the attribute is unset or ``None``.
    """
    found = getattr(request.app.state, "stream_bridge", None)
    if found is not None:
        return found
    raise HTTPException(status_code=503, detail="Stream bridge not available")
|
||||||
|
|
||||||
|
|
||||||
|
# Annotated alias: a StreamBridge resolved through Depends(get_stream_bridge).
CurrentStreamBridge = Annotated[StreamBridge, Depends(get_stream_bridge)]
|
||||||
@@ -1,297 +0,0 @@
|
|||||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
|
||||||
|
|
||||||
**Getters** (used by routers): raise 503 when a required dependency is
|
|
||||||
missing, except ``get_store`` which returns ``None``.
|
|
||||||
|
|
||||||
``AppConfig`` is intentionally *not* cached on ``app.state``. Routers and the
|
|
||||||
run path resolve it through :func:`deerflow.config.app_config.get_app_config`,
|
|
||||||
which performs mtime-based hot reload, so edits to ``config.yaml`` take
|
|
||||||
effect on the next request without a process restart. The engines created in
|
|
||||||
:func:`langgraph_runtime` (stream bridge, persistence, checkpointer, store,
|
|
||||||
run-event store) accept a ``startup_config`` snapshot — they are
|
|
||||||
restart-required by design and stay bound to that snapshot to keep the live
|
|
||||||
process consistent with itself.
|
|
||||||
|
|
||||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
|
||||||
from typing import TYPE_CHECKING, TypeVar, cast
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
|
||||||
from langgraph.types import Checkpointer
|
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig, get_app_config
|
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
|
||||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
|
||||||
|
|
||||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
|
||||||
|
|
||||||
|
|
||||||
# Type variable for the value type returned by the dependencies that ``_require`` builds.
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> AppConfig:
    """Load the current ``AppConfig`` for a request, mapping failures to 503.

    The config is fetched through :func:`deerflow.config.app_config.get_app_config`
    on every call rather than read from ``app.state``, so request-time callers
    see whatever that function currently returns (the module docstring
    describes it as mtime-based hot reload of ``config.yaml``).

    Raises:
        HTTPException: 503 when loading raises for any reason. The original
            exception is logged with its traceback and chained as the cause.
    """
    try:
        config = get_app_config()
    except Exception as exc:  # noqa: BLE001 - request boundary: log and degrade gracefully
        logger.exception("Failed to load AppConfig at request time")
        raise HTTPException(status_code=503, detail="Configuration not available") from exc
    return config
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGenerator[None, None]:
    """Bootstrap and tear down all LangGraph runtime singletons.

    ``startup_config`` is the ``AppConfig`` snapshot taken once during
    ``lifespan()`` for one-shot infrastructure bootstrap. The engines and
    stores constructed here (stream bridge, persistence engine, checkpointer,
    store, run-event store) are restart-required by design, so they bind to
    this snapshot and survive across ``config.yaml`` edits. Request-time
    consumers must still go through :func:`get_config` for any field that
    should be hot-reloadable.

    The matching ``run_events_config`` is frozen onto ``app.state`` so
    :func:`get_run_context` pairs a freshly-loaded ``AppConfig`` with the
    *startup-time* run-events configuration the ``event_store`` was built
    from, instead of combining a new config with a store still bound to the
    previous backend.

    The persistence engine is closed on every exit path once it has been
    initialized, including a failure while the remaining singletons are still
    being built.

    Usage in ``app.py``::

        async with langgraph_runtime(app, startup_config):
            yield
    """
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
    from deerflow.runtime import make_store, make_stream_bridge
    from deerflow.runtime.checkpointer.async_provider import make_checkpointer
    from deerflow.runtime.events.store import make_run_event_store

    async with AsyncExitStack() as stack:
        config = startup_config

        app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))

        # Initialize persistence engine BEFORE checkpointer so that
        # auto-create-database logic runs first (postgres backend).
        await init_engine_from_config(config.database)

        # The engine exists from here on. Guard the rest of the bootstrap as
        # well as the serving phase, so a failing checkpointer/store/event
        # store setup does not leave the engine open.
        try:
            app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
            app.state.store = await stack.enter_async_context(make_store(config))

            # Initialize repositories — one get_session_factory() call for all.
            sf = get_session_factory()
            if sf is not None:
                from deerflow.persistence.feedback import FeedbackRepository
                from deerflow.persistence.run import RunRepository

                app.state.run_store = RunRepository(sf)
                app.state.feedback_repo = FeedbackRepository(sf)
            else:
                from deerflow.runtime.runs.store.memory import MemoryRunStore

                # No SQL session factory: runs live in memory and feedback is
                # unavailable.
                app.state.run_store = MemoryRunStore()
                app.state.feedback_repo = None

            from deerflow.persistence.thread_meta import make_thread_store

            app.state.thread_store = make_thread_store(sf, app.state.store)

            # Run event store. The store and the matching ``run_events_config``
            # are both frozen at startup so ``get_run_context`` does not
            # combine a freshly-reloaded ``AppConfig.run_events`` with a store
            # still bound to the previous backend.
            run_events_config = getattr(config, "run_events", None)
            app.state.run_events_config = run_events_config
            app.state.run_event_store = make_run_event_store(run_events_config)

            # RunManager with store backing for persistence
            app.state.run_manager = RunManager(store=app.state.run_store)

            yield
        finally:
            await close_engine()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Getters – called by routers per-request
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _require(attr: str, label: str) -> Callable[[Request], T]:
    """Build a FastAPI dependency that yields ``app.state.<attr>`` or a 503.

    The generated function is renamed to ``get_<attr>``. *label* appears in
    the 503 detail (``"<label> not available"``).
    """

    def dep(request: Request) -> T:
        found = getattr(request.app.state, attr, None)
        if found is not None:
            return cast(T, found)
        raise HTTPException(status_code=503, detail=f"{label} not available")

    generated_name = f"get_{attr}"
    dep.__name__ = generated_name
    dep.__qualname__ = generated_name
    return dep
|
|
||||||
|
|
||||||
|
|
||||||
# Dependency returning app.state.stream_bridge; 503 "Stream bridge not available" when unset.
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
|
|
||||||
# Dependency returning app.state.run_manager; 503 "Run manager not available" when unset.
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
|
|
||||||
# Dependency returning app.state.checkpointer; 503 "Checkpointer not available" when unset.
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
|
|
||||||
# Dependency returning app.state.run_event_store; 503 "Run event store not available" when unset.
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
|
|
||||||
# Dependency returning app.state.feedback_repo; 503 "Feedback not available" when it is None
# (langgraph_runtime stores None there when no session factory exists).
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
|
|
||||||
# Dependency returning app.state.run_store; 503 "Run store not available" when unset.
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
|
|
||||||
|
|
||||||
|
|
||||||
def get_store(request: Request):
    """Return ``app.state.store``, or ``None`` when no store is set.

    Unlike the other getters in this module, a missing store does not raise.
    """
    state = request.app.state
    return getattr(state, "store", None)
|
|
||||||
|
|
||||||
|
|
||||||
def get_thread_store(request: Request) -> ThreadMetaStore:
    """Return ``app.state.thread_store`` (the thread metadata store).

    Raises:
        HTTPException: 503 when the attribute is unset or ``None``.
    """
    found = getattr(request.app.state, "thread_store", None)
    if found is not None:
        return found
    raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
|
||||||
|
|
||||||
|
|
||||||
def get_run_context(request: Request) -> RunContext:
    """Assemble a base :class:`RunContext` from the ``app.state`` singletons.

    ``app_config`` is resolved through :func:`get_config` on each call, while
    ``event_store`` and ``run_events_config`` are read from ``app.state`` as
    stored by :func:`langgraph_runtime`, so the two always come from the same
    startup snapshot.
    """
    # Resolve in the same order as the keyword arguments below, so the first
    # missing dependency is the one that raises.
    checkpointer = get_checkpointer(request)
    store = get_store(request)
    event_store = get_run_event_store(request)
    run_events_config = getattr(request.app.state, "run_events_config", None)
    thread_store = get_thread_store(request)
    app_config = get_config()

    return RunContext(
        checkpointer=checkpointer,
        store=store,
        event_store=event_store,
        run_events_config=run_events_config,
        thread_store=thread_store,
        app_config=app_config,
    )
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Auth helpers (used by authz.py and auth middleware)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Cached singletons to avoid repeated instantiation per request
|
|
||||||
# Process-wide LocalAuthProvider cache; filled on first use by get_local_provider().
_cached_local_provider: LocalAuthProvider | None = None
|
|
||||||
# Process-wide SQLiteUserRepository cache; filled on first use by get_local_provider().
_cached_repo: SQLiteUserRepository | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_provider() -> LocalAuthProvider:
    """Return the process-wide ``LocalAuthProvider``, building it on first use.

    The user repository and the provider are cached in module globals. The
    repository needs the shared session factory, so this must run after
    ``init_engine_from_config()``.

    Raises:
        RuntimeError: when no session factory is available yet.
    """
    global _cached_local_provider, _cached_repo

    if _cached_repo is None:
        from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
        from deerflow.persistence.engine import get_session_factory

        session_factory = get_session_factory()
        if session_factory is None:
            raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
        _cached_repo = SQLiteUserRepository(session_factory)

    if _cached_local_provider is None:
        from app.gateway.auth.local_provider import LocalAuthProvider

        _cached_local_provider = LocalAuthProvider(repository=_cached_repo)

    return _cached_local_provider
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user_from_request(request: Request):
    """Resolve the authenticated user from the ``access_token`` cookie.

    Validation chain: cookie present, then token decodes, then user exists,
    then the user's ``token_version`` equals the token's ``ver`` claim.

    Raises:
        HTTPException: 401 with an ``AuthErrorResponse`` payload when any
            step of the chain fails.
    """
    from app.gateway.auth import decode_token
    from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code

    def unauthorized(code, message: str) -> HTTPException:
        # Every failure uses the same 401 envelope; only code/message vary.
        return HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=code, message=message).model_dump(),
        )

    access_token = request.cookies.get("access_token")
    if not access_token:
        raise unauthorized(AuthErrorCode.NOT_AUTHENTICATED, "Not authenticated")

    claims = decode_token(access_token)
    if isinstance(claims, TokenError):
        raise unauthorized(token_error_to_code(claims), f"Token error: {claims.value}")

    provider = get_local_provider()
    user = await provider.get_user(claims.sub)
    if user is None:
        raise unauthorized(AuthErrorCode.USER_NOT_FOUND, "User not found")

    # Token version mismatch → password was changed, token is stale
    if user.token_version != claims.ver:
        raise unauthorized(AuthErrorCode.TOKEN_INVALID, "Token revoked (password changed)")

    return user
|
|
||||||
|
|
||||||
|
|
||||||
async def get_optional_user_from_request(request: Request):
    """Return the authenticated user, or ``None`` when authentication fails.

    Any ``HTTPException`` raised by ``get_current_user_from_request`` is
    treated as "not authenticated".
    """
    try:
        user = await get_current_user_from_request(request)
    except HTTPException:
        return None
    return user
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> str | None:
    """Return the authenticated user's id as a string, or ``None``.

    Thin adapter for callers that only need identification (e.g.
    ``feedback.py``). Callers that need the full user object should use
    ``get_current_user_from_request`` or ``get_optional_user_from_request``.
    """
    user = await get_optional_user_from_request(request)
    if not user:
        return None
    return str(user.id)
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""Process-local authentication for Gateway internal callers."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
|
||||||
|
|
||||||
# Header that carries the internal token on same-process Gateway calls.
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
# Random secret generated once at import time, so it is local to this process.
_INTERNAL_AUTH_TOKEN = secrets.token_urlsafe(32)


def create_internal_auth_headers() -> dict[str, str]:
    """Return headers that authenticate same-process Gateway internal calls."""
    return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}


def is_valid_internal_auth_token(token: str | None) -> bool:
    """Return True when *token* matches the process-local internal token.

    ``None`` and the empty string are rejected. A token containing non-ASCII
    characters is rejected as well rather than raising.
    """
    if not token:
        return False
    # Compare as bytes: compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, so a garbage token would otherwise
    # raise instead of simply failing validation.
    return secrets.compare_digest(token.encode("utf-8"), _INTERNAL_AUTH_TOKEN.encode("utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def get_internal_user():
    """Build the synthetic user object used for trusted internal channel calls.

    The object carries ``id=DEFAULT_USER_ID`` and ``system_role="internal"``.
    """
    internal_user = SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
    return internal_user
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
"""LangGraph compatibility auth handler — shares JWT logic with Gateway.
|
|
||||||
|
|
||||||
The default DeerFlow runtime is embedded in the FastAPI Gateway; scripts and
|
|
||||||
Docker deployments do not load this module. It is retained for LangGraph
|
|
||||||
tooling, Studio, or direct LangGraph Server compatibility through
|
|
||||||
``langgraph.json``'s ``auth.path``.
|
|
||||||
|
|
||||||
When that compatibility path is used, this module reuses the same JWT and CSRF
|
|
||||||
rules as Gateway so both modes validate sessions consistently.
|
|
||||||
|
|
||||||
Two layers:
|
|
||||||
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
|
||||||
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
|
|
||||||
2. @auth.on — returns metadata filter so each user only sees own threads
|
|
||||||
"""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
|
|
||||||
from langgraph_sdk import Auth
|
|
||||||
|
|
||||||
from app.gateway.auth.errors import TokenError
|
|
||||||
from app.gateway.auth.jwt import decode_token
|
|
||||||
from app.gateway.deps import get_local_provider
|
|
||||||
|
|
||||||
# Auth instance whose decorators (@auth.authenticate, @auth.on) register the handlers defined below.
auth = Auth()
|
|
||||||
|
|
||||||
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})


def _check_csrf(request) -> None:
    """Enforce Double Submit Cookie CSRF check for state-changing requests.

    Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes proxied
    directly by nginx have the same CSRF protection. Requests whose method is
    not in ``_CSRF_METHODS`` pass through untouched.

    Raises:
        Auth.exceptions.HTTPException: 403 when the ``csrf_token`` cookie or
            the ``x-csrf-token`` header is missing, or when they differ.
    """
    method = getattr(request, "method", "") or ""
    if method.upper() not in _CSRF_METHODS:
        return

    cookie_token = request.cookies.get("csrf_token")
    header_token = request.headers.get("x-csrf-token")

    if not cookie_token or not header_token:
        raise Auth.exceptions.HTTPException(
            status_code=403,
            detail="CSRF token missing. Include X-CSRF-Token header.",
        )

    # Compare as bytes: compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, which would surface as a server error
    # instead of a clean accept/reject decision.
    if not secrets.compare_digest(cookie_token.encode("utf-8"), header_token.encode("utf-8")):
        raise Auth.exceptions.HTTPException(
            status_code=403,
            detail="CSRF token mismatch.",
        )
|
|
||||||
|
|
||||||
|
|
||||||
@auth.authenticate
async def authenticate(request):
    """Authenticate a LangGraph request from its session cookie.

    Runs the CSRF check first, then the same chain as Gateway's
    ``get_current_user_from_request``: ``access_token`` cookie, then JWT
    decode, then user lookup, then ``token_version`` match. Returns the
    token's ``sub`` claim on success.

    Raises:
        Auth.exceptions.HTTPException: 403 from the CSRF check, or 401 when
            any authentication step fails.
    """
    # CSRF check before authentication so forged cross-site requests
    # are rejected early, even if the cookie carries a valid JWT.
    _check_csrf(request)

    def deny(detail: str):
        # All authentication failures share the 401 status.
        return Auth.exceptions.HTTPException(status_code=401, detail=detail)

    token = request.cookies.get("access_token")
    if not token:
        raise deny("Not authenticated")

    claims = decode_token(token)
    if isinstance(claims, TokenError):
        raise deny("Invalid token")

    user = await get_local_provider().get_user(claims.sub)
    if user is None:
        raise deny("User not found")
    if user.token_version != claims.ver:
        raise deny("Token revoked (password changed)")

    return claims.sub
|
|
||||||
|
|
||||||
|
|
||||||
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
    """Stamp the caller's identity into ``value`` and return an ownership filter.

    Gateway stores thread ownership as ``metadata.user_id``; this handler
    makes LangGraph Server enforce the same isolation.

    Args:
        ctx: Auth context; ``ctx.user.identity`` is used as the owner id.
        value: Request payload. Its ``metadata`` dict is created if absent
            and its ``user_id`` entry is overwritten with the owner id.

    Returns:
        ``{"user_id": <identity>}``, the filter LangGraph applies to
        search/read/delete operations.
    """
    # Writes: record the owner on the resource's metadata.
    owner_metadata = value.setdefault("metadata", {})
    owner_metadata["user_id"] = ctx.user.identity

    # Reads: restrict results to resources owned by the caller.
    ownership_filter = {"user_id": ctx.user.identity}
    return ownership_filter
|
|
||||||
@@ -5,16 +5,17 @@ from pathlib import Path
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
|
|
||||||
|
|
||||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||||
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
virtual_path: The virtual path as seen inside the sandbox
|
virtual_path: The virtual path as seen inside the sandbox
|
||||||
(e.g., /mnt/user-data/outputs/file.txt).
|
(e.g., /mnt/user-data/outputs/file.txt).
|
||||||
|
user_id: Explicit user id override. Falls back to the current actor context.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The resolved filesystem path.
|
The resolved filesystem path.
|
||||||
@@ -23,7 +24,8 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
|||||||
HTTPException: If the path is invalid or outside allowed directories.
|
HTTPException: If the path is invalid or outside allowed directories.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
|
resolved_user_id = get_effective_user_id() if user_id is None else user_id
|
||||||
|
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=resolved_user_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
status = 403 if "traversal" in str(e) else 400
|
status = 403 if "traversal" in str(e) else 400
|
||||||
raise HTTPException(status_code=status, detail=str(e))
|
raise HTTPException(status_code=status, detail=str(e))
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from scalar_fastapi import AgentScalarConfig, get_scalar_api_reference
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from store.persistence import create_persistence
|
||||||
|
|
||||||
|
from app.gateway.common import lifespan_manager
|
||||||
|
from app.gateway.router import router as gateway_router
|
||||||
|
from app.infra.run_events import build_run_event_store
|
||||||
|
from app.infra.storage import FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||||
|
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||||
|
from app.plugins.auth.injection import install_route_guards, load_route_policy_registry, validate_route_policy_registry
|
||||||
|
from app.plugins.auth.security import AuthMiddleware, CSRFMiddleware
|
||||||
|
|
||||||
|
# Directory of bundled static assets: "static" two levels above this file.
STATIC_DIR = Path(__file__).resolve().parent.parent.joinpath("static")
# URL prefix under which the static directory is mounted.
STATIC_MOUNT = "/api/static"
# URL of the Scalar script served from the static mount.
SCALAR_JS_URL = STATIC_MOUNT + "/scalar.js"
|
||||||
|
|
||||||
|
|
||||||
|
@lifespan_manager.register
@asynccontextmanager
async def init_persistence(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
    """Initialize the persistence layer and expose its stores as lifespan state.

    Creates the persistence object, runs its setup, wraps its session factory
    in the run / thread-meta / feedback adapters, and yields them in a dict.

    Args:
        app: The FastAPI application (not used directly by this function).

    Yields:
        A mapping of state keys to the persistence object, its checkpointer
        and session factory, and the store adapters built on top of it.
        ``"store"`` is always ``None``.
    """
    app_persistence = await create_persistence()

    # Everything after creation runs inside the try block so the persistence
    # object is closed even when setup or adapter construction fails, not only
    # on normal shutdown.
    try:
        await app_persistence.setup()
        run_store = RunStoreAdapter(app_persistence.session_factory)
        thread_meta_store = ThreadMetaStoreAdapter(app_persistence.session_factory)
        feedback_store = FeedbackStoreAdapter(app_persistence.session_factory)

        yield {
            "persistence": app_persistence,
            "checkpointer": app_persistence.checkpointer,
            "store": None,
            "session_factory": app_persistence.session_factory,
            # The same adapter instance backs all run-related repositories.
            "run_store": run_store,
            "run_read_repo": run_store,
            "run_write_repo": run_store,
            "run_delete_repo": run_store,
            "feedback_repo": feedback_store,
            "thread_meta_repo": thread_meta_store,
            "thread_meta_storage": ThreadMetaStorage(thread_meta_store),
            "run_event_store": build_run_event_store(app_persistence.session_factory),
        }
    finally:
        await app_persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@lifespan_manager.register
@asynccontextmanager
async def init_runtime(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
    """Open the StreamBridge used by the LangGraph-compatible runtime endpoints.

    Args:
        app: The FastAPI application (not used directly by this function).

    Yields:
        A single-entry mapping exposing the bridge under ``"stream_bridge"``.
        The bridge's context manager is exited when the lifespan ends.
    """
    # Imported at call time rather than at module import.
    from app.infra.stream_bridge import build_stream_bridge

    async with build_stream_bridge() as bridge:
        runtime_state: dict[str, Any] = {"stream_bridge": bridge}
        yield runtime_state
|
||||||
|
|
||||||
|
|
||||||
|
def register_app() -> FastAPI:
    """Build the gateway FastAPI application and wire all of its components.

    Returns:
        The configured application, with authz hooks stored on ``app.state``
        and static files, routes, the Scalar docs page, auth route policies
        and middlewares registered, in that order.
    """
    threads_tag = {
        "name": "threads",
        "description": "Endpoints for managing threads, which are conversations between a human and an assistant. A thread can have multiple runs as the conversation progresses.",
    }

    app = FastAPI(
        title="DeerFlow API Gateway",
        version="0.1.0",
        # Built-in docs pages are disabled; _register_scalar serves /docs instead.
        docs_url=None,
        redoc_url=None,
        lifespan=lifespan_manager.build(),
        openapi_tags=[threads_tag],
    )
    app.state.authz_hooks = build_authz_hooks()

    # Registration order is preserved exactly: route policies are validated
    # after the routes exist, and middlewares are installed last.
    registration_steps = (
        _register_static,
        _register_routes,
        _register_scalar,
        _register_auth_route_policies,
        _register_middlewares,
    )
    for step in registration_steps:
        step(app)

    return app
|
||||||
|
|
||||||
|
|
||||||
|
def _register_static(app: FastAPI) -> None:
    """Mount the ``STATIC_DIR`` directory at ``STATIC_MOUNT`` under the name "static"."""
    static_files = StaticFiles(directory=STATIC_DIR)
    app.mount(STATIC_MOUNT, static_files, name="static")
|
||||||
|
|
||||||
|
|
||||||
|
def _register_routes(app: FastAPI) -> None:
    """Attach the aggregated gateway router to the application."""
    app.include_router(router=gateway_router)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_auth_route_policies(app: FastAPI) -> None:
    """Load, validate and install the auth route-policy registry.

    The registry is validated against the application before it is stored on
    ``app.state.auth_route_policy_registry``; route guards are installed last.
    """
    policy_registry = load_route_policy_registry()
    validate_route_policy_registry(app, policy_registry)

    app.state.auth_route_policy_registry = policy_registry
    install_route_guards(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_middlewares(app: FastAPI) -> None:
    """Install the CORS, CSRF and auth middlewares on the application.

    The CORS allowlist is read from the ``CORS_ORIGINS`` environment variable
    (comma-separated origins, e.g. ``http://localhost:3000,http://localhost:3001``).
    When the variable is unset or empty, the previous behavior of allowing
    every origin is kept.

    Args:
        app: The FastAPI application to add the middlewares to.
    """
    import os

    configured_origins = os.environ.get("CORS_ORIGINS", "")
    allowed_origins = [origin.strip() for origin in configured_origins.split(",") if origin.strip()]

    # NOTE(review): the wildcard fallback combined with allow_credentials=True
    # lets any origin issue credentialed requests; set CORS_ORIGINS in any
    # deployment where the frontend is served from a different origin.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins or ["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["*"],
    )
    # NOTE(review): middlewares added later wrap those added earlier, so
    # AuthMiddleware sees a request before CSRFMiddleware and CORSMiddleware.
    # Confirm that CORS preflight requests are let through by the outer two.
    app.add_middleware(CSRFMiddleware)
    app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_scalar(app: FastAPI) -> None:
    """Register the ``/docs`` route that serves the Scalar API reference page.

    The route is excluded from the OpenAPI schema.
    """

    @app.get("/docs", include_in_schema=False)
    def scalar_docs() -> HTMLResponse:
        # The OpenAPI URL and title are read from the app on each request.
        reference_options = {
            "openapi_url": app.openapi_url,
            "title": app.title,
            "scalar_js_url": SCALAR_JS_URL,
            "agent": AgentScalarConfig(disabled=True),
            "hide_client_button": True,
            "overrides": {"mcp": {"disabled": True}},
        }
        return get_scalar_api_reference(**reference_options)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.plugins.auth.api.router import router as auth_router
|
||||||
|
|
||||||
|
from .routers import artifacts, channels, mcp, models, skills, uploads
|
||||||
|
from .routers.agents import router as agents_router
|
||||||
|
from .routers.langgraph import feedback_router, runs_router, suggestion_router, threads_router
|
||||||
|
|
||||||
|
# Aggregated gateway router: every sub-router is included here so the
# application only has to include this one.
router = APIRouter()

# (sub-router, include_router keyword arguments), in registration order.
for _child_router, _include_options in (
    (auth_router, {}),
    (threads_router, {"prefix": "/api/threads"}),
    (runs_router, {"prefix": "/api/threads"}),
    (feedback_router, {"prefix": "/api/threads"}),
    (suggestion_router, {}),
    (agents_router, {}),
    (channels.router, {}),
    (artifacts.router, {}),
    (mcp.router, {}),
    (models.router, {}),
    (skills.router, {}),
    (uploads.router, {}),
):
    router.include_router(_child_router, **_include_options)
del _child_router, _include_options
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
from . import artifacts, mcp, models, skills, suggestions, uploads
|
||||||
|
|
||||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ import yaml
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from deerflow.config.agents_api_config import get_agents_api_config
|
|
||||||
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api", tags=["agents"])
|
router = APIRouter(prefix="/api", tags=["agents"])
|
||||||
@@ -26,7 +24,6 @@ class AgentResponse(BaseModel):
|
|||||||
description: str = Field(default="", description="Agent description")
|
description: str = Field(default="", description="Agent description")
|
||||||
model: str | None = Field(default=None, description="Optional model override")
|
model: str | None = Field(default=None, description="Optional model override")
|
||||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||||
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all, []=none)")
|
|
||||||
soul: str | None = Field(default=None, description="SOUL.md content")
|
soul: str | None = Field(default=None, description="SOUL.md content")
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +40,6 @@ class AgentCreateRequest(BaseModel):
|
|||||||
description: str = Field(default="", description="Agent description")
|
description: str = Field(default="", description="Agent description")
|
||||||
model: str | None = Field(default=None, description="Optional model override")
|
model: str | None = Field(default=None, description="Optional model override")
|
||||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||||
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all enabled, []=none)")
|
|
||||||
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
|
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +49,6 @@ class AgentUpdateRequest(BaseModel):
|
|||||||
description: str | None = Field(default=None, description="Updated description")
|
description: str | None = Field(default=None, description="Updated description")
|
||||||
model: str | None = Field(default=None, description="Updated model override")
|
model: str | None = Field(default=None, description="Updated model override")
|
||||||
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
|
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
|
||||||
skills: list[str] | None = Field(default=None, description="Updated skill whitelist (None=all, []=none)")
|
|
||||||
soul: str | None = Field(default=None, description="Updated SOUL.md content")
|
soul: str | None = Field(default=None, description="Updated SOUL.md content")
|
||||||
|
|
||||||
|
|
||||||
@@ -78,27 +73,17 @@ def _normalize_agent_name(name: str) -> str:
|
|||||||
return name.lower()
|
return name.lower()
|
||||||
|
|
||||||
|
|
||||||
def _require_agents_api_enabled() -> None:
|
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
|
||||||
"""Reject access unless the custom-agent management API is explicitly enabled."""
|
|
||||||
if not get_agents_api_config().enabled:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
|
|
||||||
"""Convert AgentConfig to AgentResponse."""
|
"""Convert AgentConfig to AgentResponse."""
|
||||||
soul: str | None = None
|
soul: str | None = None
|
||||||
if include_soul:
|
if include_soul:
|
||||||
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
|
soul = load_agent_soul(agent_cfg.name) or ""
|
||||||
|
|
||||||
return AgentResponse(
|
return AgentResponse(
|
||||||
name=agent_cfg.name,
|
name=agent_cfg.name,
|
||||||
description=agent_cfg.description,
|
description=agent_cfg.description,
|
||||||
model=agent_cfg.model,
|
model=agent_cfg.model,
|
||||||
tool_groups=agent_cfg.tool_groups,
|
tool_groups=agent_cfg.tool_groups,
|
||||||
skills=agent_cfg.skills,
|
|
||||||
soul=soul,
|
soul=soul,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,12 +100,9 @@ async def list_agents() -> AgentsListResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
List of all custom agents with their metadata and soul content.
|
List of all custom agents with their metadata and soul content.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
|
|
||||||
user_id = get_effective_user_id()
|
|
||||||
try:
|
try:
|
||||||
agents = list_custom_agents(user_id=user_id)
|
agents = list_custom_agents()
|
||||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
|
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
||||||
@@ -143,15 +125,9 @@ async def check_agent_name(name: str) -> dict:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 422 if the name is invalid.
|
HTTPException: 422 if the name is invalid.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
normalized = _normalize_agent_name(name)
|
normalized = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
available = not get_paths().agent_dir(normalized).exists()
|
||||||
paths = get_paths()
|
|
||||||
# Treat the name as taken if either the per-user path or the legacy shared
|
|
||||||
# path holds an agent — picking a name that collides with an unmigrated
|
|
||||||
# legacy agent would shadow the legacy entry once migration runs.
|
|
||||||
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
|
|
||||||
return {"available": available, "name": normalized}
|
return {"available": available, "name": normalized}
|
||||||
|
|
||||||
|
|
||||||
@@ -173,14 +149,12 @@ async def get_agent(name: str) -> AgentResponse:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 404 if agent not found.
|
HTTPException: 404 if agent not found.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
agent_cfg = load_agent_config(name)
|
||||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -207,16 +181,12 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 409 if agent already exists, 422 if name is invalid.
|
HTTPException: 409 if agent already exists, 422 if name is invalid.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(request.name)
|
_validate_agent_name(request.name)
|
||||||
normalized_name = _normalize_agent_name(request.name)
|
normalized_name = _normalize_agent_name(request.name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
paths = get_paths()
|
|
||||||
|
|
||||||
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
agent_dir = get_paths().agent_dir(normalized_name)
|
||||||
legacy_dir = paths.agent_dir(normalized_name)
|
|
||||||
|
|
||||||
if agent_dir.exists() or legacy_dir.exists():
|
if agent_dir.exists():
|
||||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -230,8 +200,6 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
config_data["model"] = request.model
|
config_data["model"] = request.model
|
||||||
if request.tool_groups is not None:
|
if request.tool_groups is not None:
|
||||||
config_data["tool_groups"] = request.tool_groups
|
config_data["tool_groups"] = request.tool_groups
|
||||||
if request.skills is not None:
|
|
||||||
config_data["skills"] = request.skills
|
|
||||||
|
|
||||||
config_file = agent_dir / "config.yaml"
|
config_file = agent_dir / "config.yaml"
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
@@ -243,8 +211,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
|
|
||||||
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
||||||
|
|
||||||
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
agent_cfg = load_agent_config(normalized_name)
|
||||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -275,52 +243,33 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 404 if agent not found.
|
HTTPException: 404 if agent not found.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
agent_cfg = load_agent_config(name)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
|
|
||||||
paths = get_paths()
|
agent_dir = get_paths().agent_dir(name)
|
||||||
agent_dir = paths.user_agent_dir(user_id, name)
|
|
||||||
if not agent_dir.exists() and paths.agent_dir(name).exists():
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update config if any config fields changed
|
# Update config if any config fields changed
|
||||||
# Use model_fields_set to distinguish "field omitted" from "explicitly set to null".
|
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
|
||||||
# This is critical for skills where None means "inherit all" (not "don't change").
|
|
||||||
fields_set = request.model_fields_set
|
|
||||||
config_changed = bool(fields_set & {"description", "model", "tool_groups", "skills"})
|
|
||||||
|
|
||||||
if config_changed:
|
if config_changed:
|
||||||
updated: dict = {
|
updated: dict = {
|
||||||
"name": agent_cfg.name,
|
"name": agent_cfg.name,
|
||||||
"description": request.description if "description" in fields_set else agent_cfg.description,
|
"description": request.description if request.description is not None else agent_cfg.description,
|
||||||
}
|
}
|
||||||
new_model = request.model if "model" in fields_set else agent_cfg.model
|
new_model = request.model if request.model is not None else agent_cfg.model
|
||||||
if new_model is not None:
|
if new_model is not None:
|
||||||
updated["model"] = new_model
|
updated["model"] = new_model
|
||||||
|
|
||||||
new_tool_groups = request.tool_groups if "tool_groups" in fields_set else agent_cfg.tool_groups
|
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
|
||||||
if new_tool_groups is not None:
|
if new_tool_groups is not None:
|
||||||
updated["tool_groups"] = new_tool_groups
|
updated["tool_groups"] = new_tool_groups
|
||||||
|
|
||||||
# skills: None = inherit all, [] = no skills, ["a","b"] = whitelist
|
|
||||||
if "skills" in fields_set:
|
|
||||||
new_skills = request.skills
|
|
||||||
else:
|
|
||||||
new_skills = agent_cfg.skills
|
|
||||||
if new_skills is not None:
|
|
||||||
updated["skills"] = new_skills
|
|
||||||
|
|
||||||
config_file = agent_dir / "config.yaml"
|
config_file = agent_dir / "config.yaml"
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
|
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
|
||||||
@@ -332,8 +281,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
|||||||
|
|
||||||
logger.info(f"Updated agent '{name}'")
|
logger.info(f"Updated agent '{name}'")
|
||||||
|
|
||||||
refreshed_cfg = load_agent_config(name, user_id=user_id)
|
refreshed_cfg = load_agent_config(name)
|
||||||
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(refreshed_cfg, include_soul=True)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -366,8 +315,6 @@ async def get_user_profile() -> UserProfileResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
UserProfileResponse with content=None if USER.md does not exist yet.
|
UserProfileResponse with content=None if USER.md does not exist yet.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_md_path = get_paths().user_md_file
|
user_md_path = get_paths().user_md_file
|
||||||
if not user_md_path.exists():
|
if not user_md_path.exists():
|
||||||
@@ -394,8 +341,6 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
|
|||||||
Returns:
|
Returns:
|
||||||
UserProfileResponse with the saved content.
|
UserProfileResponse with the saved content.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.base_dir.mkdir(parents=True, exist_ok=True)
|
paths.base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -420,22 +365,14 @@ async def delete_agent(name: str) -> None:
|
|||||||
name: The agent name.
|
name: The agent name.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
|
HTTPException: 404 if agent not found.
|
||||||
shared copy exists (suggesting the migration script).
|
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
paths = get_paths()
|
agent_dir = get_paths().agent_dir(name)
|
||||||
agent_dir = paths.user_agent_dir(user_id, name)
|
|
||||||
|
|
||||||
if not agent_dir.exists():
|
if not agent_dir.exists():
|
||||||
if paths.agent_dir(name).exists():
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
|
||||||
)
|
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from urllib.parse import quote
|
|||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from fastapi.responses import FileResponse, PlainTextResponse, Response
|
from fastapi.responses import FileResponse, PlainTextResponse, Response
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,9 +19,6 @@ ACTIVE_CONTENT_MIME_TYPES = {
|
|||||||
"image/svg+xml",
|
"image/svg+xml",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
|
|
||||||
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
||||||
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
||||||
@@ -47,22 +43,6 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
    """Return the contents of one .skill archive member, capped in size.

    Args:
        zip_ref: The open archive.
        info: The member to read.

    Returns:
        The member's uncompressed bytes.

    Raises:
        HTTPException: 413 when the declared size, or the number of bytes
            actually read, exceeds ``MAX_SKILL_ARCHIVE_MEMBER_BYTES``.
    """
    # Reject up front based on the size recorded in the archive.
    if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
        raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")

    payload = bytearray()
    with zip_ref.open(info, "r") as member:
        # The declared size is re-checked against the bytes actually read.
        for piece in iter(lambda: member.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE), b""):
            if len(payload) + len(piece) > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
                raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
            payload += piece
    return bytes(payload)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
||||||
"""Extract a file from a .skill ZIP archive.
|
"""Extract a file from a .skill ZIP archive.
|
||||||
|
|
||||||
@@ -79,16 +59,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
|||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
# List all files in the archive
|
# List all files in the archive
|
||||||
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
|
namelist = zip_ref.namelist()
|
||||||
|
|
||||||
# Try direct path first
|
# Try direct path first
|
||||||
if internal_path in infos_by_name:
|
if internal_path in namelist:
|
||||||
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
|
return zip_ref.read(internal_path)
|
||||||
|
|
||||||
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
||||||
for name, info in infos_by_name.items():
|
for name in namelist:
|
||||||
if name.endswith("/" + internal_path) or name == internal_path:
|
if name.endswith("/" + internal_path) or name == internal_path:
|
||||||
return _read_skill_archive_member(zip_ref, info)
|
return zip_ref.read(name)
|
||||||
|
|
||||||
# Not found
|
# Not found
|
||||||
return None
|
return None
|
||||||
@@ -101,7 +81,6 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
|||||||
summary="Get Artifact File",
|
summary="Get Artifact File",
|
||||||
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
|
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
|
||||||
)
|
)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
|
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
|
||||||
"""Get an artifact file by its path.
|
"""Get an artifact file by its path.
|
||||||
|
|
||||||
|
|||||||
@@ -1,149 +0,0 @@
|
|||||||
"""Assistants compatibility endpoints.
|
|
||||||
|
|
||||||
Provides LangGraph Platform-compatible assistants API backed by the
|
|
||||||
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
|
|
||||||
|
|
||||||
This is a minimal stub that satisfies the ``useStream`` React hook's
|
|
||||||
initialization requirements (``assistants.search()`` and ``assistants.get()``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantResponse(BaseModel):
|
|
||||||
assistant_id: str
|
|
||||||
graph_id: str
|
|
||||||
name: str
|
|
||||||
config: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
description: str | None = None
|
|
||||||
created_at: str = ""
|
|
||||||
updated_at: str = ""
|
|
||||||
version: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantSearchRequest(BaseModel):
|
|
||||||
graph_id: str | None = None
|
|
||||||
name: str | None = None
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
limit: int = 10
|
|
||||||
offset: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_assistant() -> AssistantResponse:
|
|
||||||
"""Return the default lead_agent assistant."""
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
return AssistantResponse(
|
|
||||||
assistant_id="lead_agent",
|
|
||||||
graph_id="lead_agent",
|
|
||||||
name="lead_agent",
|
|
||||||
config={},
|
|
||||||
metadata={"created_by": "system"},
|
|
||||||
description="DeerFlow lead agent",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
version=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _list_assistants() -> list[AssistantResponse]:
|
|
||||||
"""List all available assistants from config."""
|
|
||||||
assistants = [_get_default_assistant()]
|
|
||||||
|
|
||||||
# Also include custom agents from config.yaml agents directory
|
|
||||||
try:
|
|
||||||
from deerflow.config.agents_config import list_custom_agents
|
|
||||||
|
|
||||||
for agent_cfg in list_custom_agents():
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
assistants.append(
|
|
||||||
AssistantResponse(
|
|
||||||
assistant_id=agent_cfg.name,
|
|
||||||
graph_id="lead_agent", # All agents use the same graph
|
|
||||||
name=agent_cfg.name,
|
|
||||||
config={},
|
|
||||||
metadata={"created_by": "user"},
|
|
||||||
description=agent_cfg.description or "",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
version=1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not load custom agents for assistants list")
|
|
||||||
|
|
||||||
return assistants
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_model=list[AssistantResponse])
|
|
||||||
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
|
|
||||||
"""Search assistants.
|
|
||||||
|
|
||||||
Returns all registered assistants (lead_agent + custom agents from config).
|
|
||||||
"""
|
|
||||||
assistants = _list_assistants()
|
|
||||||
|
|
||||||
if body and body.graph_id:
|
|
||||||
assistants = [a for a in assistants if a.graph_id == body.graph_id]
|
|
||||||
if body and body.name:
|
|
||||||
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
|
|
||||||
|
|
||||||
offset = body.offset if body else 0
|
|
||||||
limit = body.limit if body else 10
|
|
||||||
return assistants[offset : offset + limit]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}", response_model=AssistantResponse)
|
|
||||||
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
|
|
||||||
"""Get an assistant by ID."""
|
|
||||||
for a in _list_assistants():
|
|
||||||
if a.assistant_id == assistant_id:
|
|
||||||
return a
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}/graph")
|
|
||||||
async def get_assistant_graph(assistant_id: str) -> dict:
|
|
||||||
"""Get the graph structure for an assistant.
|
|
||||||
|
|
||||||
Returns a minimal graph description. Full graph introspection is
|
|
||||||
not supported in the Gateway — this stub satisfies SDK validation.
|
|
||||||
"""
|
|
||||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
|
||||||
if not found:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"graph_id": "lead_agent",
|
|
||||||
"nodes": [],
|
|
||||||
"edges": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}/schemas")
|
|
||||||
async def get_assistant_schemas(assistant_id: str) -> dict:
|
|
||||||
"""Get JSON schemas for an assistant's input/output/state.
|
|
||||||
|
|
||||||
Returns empty schemas — full introspection not supported in Gateway.
|
|
||||||
"""
|
|
||||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
|
||||||
if not found:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"graph_id": "lead_agent",
|
|
||||||
"input_schema": {},
|
|
||||||
"output_schema": {},
|
|
||||||
"state_schema": {},
|
|
||||||
"config_schema": {},
|
|
||||||
}
|
|
||||||
@@ -1,527 +0,0 @@
|
|||||||
"""Authentication endpoints."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from ipaddress import ip_address, ip_network
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
|
||||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
|
||||||
|
|
||||||
from app.gateway.auth import (
|
|
||||||
UserResponse,
|
|
||||||
create_access_token,
|
|
||||||
)
|
|
||||||
from app.gateway.auth.config import get_auth_config
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
|
||||||
from app.gateway.csrf_middleware import is_secure_request
|
|
||||||
from app.gateway.deps import get_current_user_from_request, get_local_provider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request/Response Models ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class LoginResponse(BaseModel):
|
|
||||||
"""Response model for login — token only lives in HttpOnly cookie."""
|
|
||||||
|
|
||||||
expires_in: int # seconds
|
|
||||||
needs_setup: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# Top common-password blocklist. Drawn from the public SecLists "10k worst
|
|
||||||
# passwords" set, lowercased + length>=8 only (shorter ones already fail
|
|
||||||
# the min_length check). Kept tight on purpose: this is the **lower bound**
|
|
||||||
# defense, not a full HIBP / passlib check, and runs in-process per request.
|
|
||||||
_COMMON_PASSWORDS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"password",
|
|
||||||
"password1",
|
|
||||||
"password12",
|
|
||||||
"password123",
|
|
||||||
"password1234",
|
|
||||||
"12345678",
|
|
||||||
"123456789",
|
|
||||||
"1234567890",
|
|
||||||
"qwerty12",
|
|
||||||
"qwertyui",
|
|
||||||
"qwerty123",
|
|
||||||
"abc12345",
|
|
||||||
"abcd1234",
|
|
||||||
"iloveyou",
|
|
||||||
"letmein1",
|
|
||||||
"welcome1",
|
|
||||||
"welcome123",
|
|
||||||
"admin123",
|
|
||||||
"administrator",
|
|
||||||
"passw0rd",
|
|
||||||
"p@ssw0rd",
|
|
||||||
"monkey12",
|
|
||||||
"trustno1",
|
|
||||||
"sunshine",
|
|
||||||
"princess",
|
|
||||||
"football",
|
|
||||||
"baseball",
|
|
||||||
"superman",
|
|
||||||
"batman123",
|
|
||||||
"starwars",
|
|
||||||
"dragon123",
|
|
||||||
"master123",
|
|
||||||
"shadow12",
|
|
||||||
"michael1",
|
|
||||||
"jennifer",
|
|
||||||
"computer",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _password_is_common(password: str) -> bool:
|
|
||||||
"""Case-insensitive blocklist check.
|
|
||||||
|
|
||||||
Lowercases the input so trivial mutations like ``Password`` /
|
|
||||||
``PASSWORD`` are also rejected. Does not normalize digit substitutions
|
|
||||||
(``p@ssw0rd`` is included as a literal entry instead) — keeping the
|
|
||||||
rule cheap and predictable.
|
|
||||||
"""
|
|
||||||
return password.lower() in _COMMON_PASSWORDS
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_strong_password(value: str) -> str:
|
|
||||||
"""Pydantic field-validator body shared by Register + ChangePassword.
|
|
||||||
|
|
||||||
Constraint = function, not type-level mixin. The two request models
|
|
||||||
have no "is-a" relationship; they only share the password-strength
|
|
||||||
rule. Lifting it into a free function lets each model bind it via
|
|
||||||
``@field_validator(field_name)`` without inheritance gymnastics.
|
|
||||||
"""
|
|
||||||
if _password_is_common(value):
|
|
||||||
raise ValueError("Password is too common; choose a stronger password.")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequest(BaseModel):
|
|
||||||
"""Request model for user registration."""
|
|
||||||
|
|
||||||
email: EmailStr
|
|
||||||
password: str = Field(..., min_length=8)
|
|
||||||
|
|
||||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
class ChangePasswordRequest(BaseModel):
|
|
||||||
"""Request model for password change (also handles setup flow)."""
|
|
||||||
|
|
||||||
current_password: str
|
|
||||||
new_password: str = Field(..., min_length=8)
|
|
||||||
new_email: EmailStr | None = None
|
|
||||||
|
|
||||||
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
class MessageResponse(BaseModel):
|
|
||||||
"""Generic message response."""
|
|
||||||
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
|
||||||
"""Set the access_token HttpOnly cookie on the response."""
|
|
||||||
config = get_auth_config()
|
|
||||||
is_https = is_secure_request(request)
|
|
||||||
response.set_cookie(
|
|
||||||
key="access_token",
|
|
||||||
value=token,
|
|
||||||
httponly=True,
|
|
||||||
secure=is_https,
|
|
||||||
samesite="lax",
|
|
||||||
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Rate Limiting ────────────────────────────────────────────────────────
|
|
||||||
# In-process dict — not shared across workers.
|
|
||||||
#
|
|
||||||
# **Limitation**: with multi-worker deployments (e.g., gunicorn -w N), each
|
|
||||||
# worker maintains its own lockout table, so an attacker effectively gets
|
|
||||||
# N × _MAX_LOGIN_ATTEMPTS guesses before being locked out everywhere. For
|
|
||||||
# production multi-worker setups, replace this with a shared store (Redis,
|
|
||||||
# database-backed counter) to enforce a true per-IP limit.
|
|
||||||
|
|
||||||
_MAX_LOGIN_ATTEMPTS = 5
|
|
||||||
_LOCKOUT_SECONDS = 300 # 5 minutes
|
|
||||||
|
|
||||||
# ip → (fail_count, lock_until_timestamp)
|
|
||||||
_login_attempts: dict[str, tuple[int, float]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def _trusted_proxies() -> list:
|
|
||||||
"""Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.
|
|
||||||
|
|
||||||
Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
|
|
||||||
trusted (direct mode). Invalid entries are skipped with a logger warning.
|
|
||||||
Read live so env-var overrides take effect immediately and tests can
|
|
||||||
``monkeypatch.setenv`` without poking a module-level cache.
|
|
||||||
"""
|
|
||||||
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
|
|
||||||
if not raw:
|
|
||||||
return []
|
|
||||||
nets = []
|
|
||||||
for entry in raw.split(","):
|
|
||||||
entry = entry.strip()
|
|
||||||
if not entry:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
nets.append(ip_network(entry, strict=False))
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
|
|
||||||
return nets
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client_ip(request: Request) -> str:
|
|
||||||
"""Extract the real client IP for rate limiting.
|
|
||||||
|
|
||||||
Trust model:
|
|
||||||
|
|
||||||
- The TCP peer (``request.client.host``) is always the baseline. It is
|
|
||||||
whatever the kernel reports as the connecting socket — unforgeable
|
|
||||||
by the client itself.
|
|
||||||
- ``X-Real-IP`` is **only** honored if the TCP peer is in the
|
|
||||||
``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
|
|
||||||
CIDR or single IPs). When set, the gateway is assumed to be behind a
|
|
||||||
reverse proxy (nginx, Cloudflare, ALB, …) that overwrites
|
|
||||||
``X-Real-IP`` with the original client address.
|
|
||||||
- With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
|
|
||||||
ignored — closing the bypass where any client could rotate the
|
|
||||||
header to dodge per-IP rate limits in dev / direct-gateway mode.
|
|
||||||
|
|
||||||
``X-Forwarded-For`` is intentionally NOT used because it is naturally
|
|
||||||
client-controlled at the *first* hop and the trust chain is harder to
|
|
||||||
audit per-request.
|
|
||||||
"""
|
|
||||||
peer_host = request.client.host if request.client else None
|
|
||||||
|
|
||||||
trusted = _trusted_proxies()
|
|
||||||
if trusted and peer_host:
|
|
||||||
try:
|
|
||||||
peer_ip = ip_address(peer_host)
|
|
||||||
if any(peer_ip in net for net in trusted):
|
|
||||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
|
||||||
if real_ip:
|
|
||||||
return real_ip
|
|
||||||
except ValueError:
|
|
||||||
# peer_host wasn't a parseable IP (e.g. "unknown") — fall through
|
|
||||||
pass
|
|
||||||
|
|
||||||
return peer_host or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_rate_limit(ip: str) -> None:
|
|
||||||
"""Raise 429 if the IP is currently locked out."""
|
|
||||||
record = _login_attempts.get(ip)
|
|
||||||
if record is None:
|
|
||||||
return
|
|
||||||
fail_count, lock_until = record
|
|
||||||
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
|
||||||
if time.time() < lock_until:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429,
|
|
||||||
detail="Too many login attempts. Try again later.",
|
|
||||||
)
|
|
||||||
del _login_attempts[ip]
|
|
||||||
|
|
||||||
|
|
||||||
_MAX_TRACKED_IPS = 10000
|
|
||||||
|
|
||||||
|
|
||||||
def _record_login_failure(ip: str) -> None:
|
|
||||||
"""Record a failed login attempt for the given IP."""
|
|
||||||
# Evict expired lockouts when dict grows too large
|
|
||||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
|
||||||
now = time.time()
|
|
||||||
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
|
||||||
for k in expired:
|
|
||||||
del _login_attempts[k]
|
|
||||||
# If still too large, evict cheapest-to-lose half: below-threshold
|
|
||||||
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
|
|
||||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
|
||||||
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
|
||||||
for k, _ in by_time[: len(by_time) // 2]:
|
|
||||||
del _login_attempts[k]
|
|
||||||
|
|
||||||
record = _login_attempts.get(ip)
|
|
||||||
if record is None:
|
|
||||||
_login_attempts[ip] = (1, 0.0)
|
|
||||||
else:
|
|
||||||
new_count = record[0] + 1
|
|
||||||
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
|
||||||
_login_attempts[ip] = (new_count, lock_until)
|
|
||||||
|
|
||||||
|
|
||||||
def _record_login_success(ip: str) -> None:
|
|
||||||
"""Clear failure counter for the given IP on successful login."""
|
|
||||||
_login_attempts.pop(ip, None)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login/local", response_model=LoginResponse)
|
|
||||||
async def login_local(
|
|
||||||
request: Request,
|
|
||||||
response: Response,
|
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
|
||||||
):
|
|
||||||
"""Local email/password login."""
|
|
||||||
client_ip = _get_client_ip(request)
|
|
||||||
_check_rate_limit(client_ip)
|
|
||||||
|
|
||||||
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
_record_login_failure(client_ip)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
_record_login_success(client_ip)
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return LoginResponse(
|
|
||||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
|
||||||
needs_setup=user.needs_setup,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
|
||||||
"""Register a new user account (always 'user' role).
|
|
||||||
|
|
||||||
The first admin is created explicitly through /initialize. This endpoint creates regular users.
|
|
||||||
Auto-login by setting the session cookie.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout", response_model=MessageResponse)
|
|
||||||
async def logout(request: Request, response: Response):
|
|
||||||
"""Logout current user by clearing the cookie."""
|
|
||||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
|
||||||
return MessageResponse(message="Successfully logged out")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/change-password", response_model=MessageResponse)
|
|
||||||
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
|
||||||
"""Change password for the currently authenticated user.
|
|
||||||
|
|
||||||
Also handles the first-boot setup flow:
|
|
||||||
- If new_email is provided, updates email (checks uniqueness)
|
|
||||||
- If user.needs_setup is True and new_email is given, clears needs_setup
|
|
||||||
- Always increments token_version to invalidate old sessions
|
|
||||||
- Re-issues session cookie with new token_version
|
|
||||||
"""
|
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
|
||||||
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
|
|
||||||
if user.password_hash is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
|
||||||
|
|
||||||
if not await verify_password_async(body.current_password, user.password_hash):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
|
||||||
|
|
||||||
provider = get_local_provider()
|
|
||||||
|
|
||||||
# Update email if provided
|
|
||||||
if body.new_email is not None:
|
|
||||||
existing = await provider.get_user_by_email(body.new_email)
|
|
||||||
if existing and str(existing.id) != str(user.id):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
|
||||||
user.email = body.new_email
|
|
||||||
|
|
||||||
# Update password + bump version
|
|
||||||
user.password_hash = await hash_password_async(body.new_password)
|
|
||||||
user.token_version += 1
|
|
||||||
|
|
||||||
# Clear setup flag if this is the setup flow
|
|
||||||
if user.needs_setup and body.new_email is not None:
|
|
||||||
user.needs_setup = False
|
|
||||||
|
|
||||||
await provider.update_user(user)
|
|
||||||
|
|
||||||
# Re-issue cookie with new token_version
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return MessageResponse(message="Password changed successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
|
||||||
async def get_me(request: Request):
|
|
||||||
"""Get current authenticated user info."""
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
|
||||||
|
|
||||||
|
|
||||||
# Per-IP cache: ip → (timestamp, result_dict).
|
|
||||||
# Returns the cached result within the TTL instead of 429, because
|
|
||||||
# the answer (whether an admin exists) rarely changes and returning
|
|
||||||
# 429 breaks multi-tab / post-restart reconnection storms.
|
|
||||||
_SETUP_STATUS_CACHE: dict[str, tuple[float, dict]] = {}
|
|
||||||
_SETUP_STATUS_CACHE_TTL_SECONDS = 60
|
|
||||||
_MAX_TRACKED_SETUP_STATUS_IPS = 10000
|
|
||||||
_SETUP_STATUS_INFLIGHT: dict[str, asyncio.Task[dict]] = {}
|
|
||||||
_SETUP_STATUS_INFLIGHT_GUARD = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/setup-status")
|
|
||||||
async def setup_status(request: Request):
|
|
||||||
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
|
||||||
client_ip = _get_client_ip(request)
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
# Return cached result when within TTL — avoids 429 on multi-tab reconnection.
|
|
||||||
cached = _SETUP_STATUS_CACHE.get(client_ip)
|
|
||||||
if cached is not None:
|
|
||||||
cached_time, cached_result = cached
|
|
||||||
if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS:
|
|
||||||
return cached_result
|
|
||||||
|
|
||||||
async with _SETUP_STATUS_INFLIGHT_GUARD:
|
|
||||||
# Recheck cache after waiting for the inflight guard.
|
|
||||||
now = time.time()
|
|
||||||
cached = _SETUP_STATUS_CACHE.get(client_ip)
|
|
||||||
if cached is not None:
|
|
||||||
cached_time, cached_result = cached
|
|
||||||
if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS:
|
|
||||||
return cached_result
|
|
||||||
|
|
||||||
task = _SETUP_STATUS_INFLIGHT.get(client_ip)
|
|
||||||
if task is None:
|
|
||||||
# Evict stale entries when dict grows too large to bound memory usage.
|
|
||||||
if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS:
|
|
||||||
cutoff = now - _SETUP_STATUS_CACHE_TTL_SECONDS
|
|
||||||
stale = [k for k, (t, _) in _SETUP_STATUS_CACHE.items() if t < cutoff]
|
|
||||||
for k in stale:
|
|
||||||
del _SETUP_STATUS_CACHE[k]
|
|
||||||
if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS:
|
|
||||||
by_time = sorted(_SETUP_STATUS_CACHE.items(), key=lambda entry: entry[1][0])
|
|
||||||
for k, _ in by_time[: len(by_time) // 2]:
|
|
||||||
del _SETUP_STATUS_CACHE[k]
|
|
||||||
|
|
||||||
async def _compute_setup_status() -> dict:
|
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
|
||||||
return {"needs_setup": admin_count == 0}
|
|
||||||
|
|
||||||
task = asyncio.create_task(_compute_setup_status())
|
|
||||||
_SETUP_STATUS_INFLIGHT[client_ip] = task
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await task
|
|
||||||
finally:
|
|
||||||
async with _SETUP_STATUS_INFLIGHT_GUARD:
|
|
||||||
if _SETUP_STATUS_INFLIGHT.get(client_ip) is task:
|
|
||||||
del _SETUP_STATUS_INFLIGHT[client_ip]
|
|
||||||
|
|
||||||
# Cache only the stable "initialized" result to avoid stale setup redirects.
|
|
||||||
if result["needs_setup"] is False:
|
|
||||||
_SETUP_STATUS_CACHE[client_ip] = (time.time(), result)
|
|
||||||
else:
|
|
||||||
_SETUP_STATUS_CACHE.pop(client_ip, None)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class InitializeAdminRequest(BaseModel):
|
|
||||||
"""Request model for first-boot admin account creation."""
|
|
||||||
|
|
||||||
email: EmailStr
|
|
||||||
password: str = Field(..., min_length=8)
|
|
||||||
|
|
||||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
|
|
||||||
"""Create the first admin account on initial system setup.
|
|
||||||
|
|
||||||
Only callable when no admin exists. Returns 409 Conflict if an admin
|
|
||||||
already exists.
|
|
||||||
|
|
||||||
On success, the admin account is created with ``needs_setup=False`` and
|
|
||||||
the session cookie is set.
|
|
||||||
"""
|
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
|
||||||
if admin_count > 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
|
|
||||||
except ValueError:
|
|
||||||
# DB unique-constraint race: another concurrent request beat us.
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/oauth/{provider}")
|
|
||||||
async def oauth_login(provider: str):
|
|
||||||
"""Initiate OAuth login flow.
|
|
||||||
|
|
||||||
Redirects to the OAuth provider's authorization URL.
|
|
||||||
Currently a placeholder - requires OAuth provider implementation.
|
|
||||||
"""
|
|
||||||
if provider not in ["github", "google"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unsupported OAuth provider: {provider}",
|
|
||||||
)
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
||||||
detail="OAuth login not yet implemented",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/callback/{provider}")
|
|
||||||
async def oauth_callback(provider: str, code: str, state: str):
|
|
||||||
"""OAuth callback endpoint.
|
|
||||||
|
|
||||||
Handles the OAuth provider's callback after user authorization.
|
|
||||||
Currently a placeholder.
|
|
||||||
"""
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
||||||
detail="OAuth callback not yet implemented",
|
|
||||||
)
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from .feedback import router as feedback_router
|
||||||
|
from .runs import router as runs_router
|
||||||
|
from .suggestions import router as suggestion_router
|
||||||
|
from .threads import router as threads_router
|
||||||
|
|
||||||
|
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]
|
||||||
+84
-93
@@ -1,8 +1,4 @@
|
|||||||
"""Feedback endpoints — create, list, stats, delete.
|
"""LangGraph-compatible run feedback endpoints."""
|
||||||
|
|
||||||
Allows users to submit thumbs-up/down feedback on runs,
|
|
||||||
optionally scoped to a specific message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -12,16 +8,12 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.dependencies import get_feedback_repository, get_run_repository
|
||||||
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
|
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||||
|
from app.plugins.auth.security.dependencies import get_current_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["feedback"])
|
router = APIRouter(tags=["feedback"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Request / response models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackCreateRequest(BaseModel):
|
class FeedbackCreateRequest(BaseModel):
|
||||||
@@ -30,16 +22,11 @@ class FeedbackCreateRequest(BaseModel):
|
|||||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||||
|
|
||||||
|
|
||||||
class FeedbackUpsertRequest(BaseModel):
|
|
||||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
|
||||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackResponse(BaseModel):
|
class FeedbackResponse(BaseModel):
|
||||||
feedback_id: str
|
feedback_id: str
|
||||||
run_id: str
|
run_id: str
|
||||||
thread_id: str
|
thread_id: str
|
||||||
user_id: str | None = None
|
owner_id: str | None = None
|
||||||
message_id: str | None = None
|
message_id: str | None = None
|
||||||
rating: int
|
rating: int
|
||||||
comment: str | None = None
|
comment: str | None = None
|
||||||
@@ -53,85 +40,36 @@ class FeedbackStatsResponse(BaseModel):
|
|||||||
negative: int = 0
|
negative: int = 0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
|
||||||
# Endpoints
|
run_store = get_run_repository(request)
|
||||||
# ---------------------------------------------------------------------------
|
if resolve_request_user_id(request) is None:
|
||||||
|
run = await run_store.get(run_id, user_id=None)
|
||||||
|
else:
|
||||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
with bind_request_actor_context(request):
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
run = await run_store.get(run_id)
|
||||||
async def upsert_feedback(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
body: FeedbackUpsertRequest,
|
|
||||||
request: Request,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create or update feedback for a run (idempotent)."""
|
|
||||||
if body.rating not in (1, -1):
|
|
||||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
|
||||||
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
run = await run_store.get(run_id)
|
|
||||||
if run is None:
|
if run is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
if run.get("thread_id") != thread_id:
|
if run.get("thread_id") != thread_id:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||||
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.upsert(
|
async def _get_current_user(request: Request) -> str | None:
|
||||||
run_id=run_id,
|
"""Extract current user id from auth dependencies when available."""
|
||||||
thread_id=thread_id,
|
return await get_current_user_id(request)
|
||||||
rating=body.rating,
|
|
||||||
user_id=user_id,
|
|
||||||
comment=body.comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
async def _create_feedback(
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_run_feedback(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete the current user's feedback for a run."""
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
deleted = await feedback_repo.delete_by_run(
|
|
||||||
thread_id=thread_id,
|
|
||||||
run_id=run_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if not deleted:
|
|
||||||
raise HTTPException(status_code=404, detail="No feedback found for this run")
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
|
||||||
async def create_feedback(
|
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
body: FeedbackCreateRequest,
|
body: FeedbackCreateRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Submit feedback (thumbs-up/down) for a run."""
|
|
||||||
if body.rating not in (1, -1):
|
if body.rating not in (1, -1):
|
||||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||||
|
|
||||||
user_id = await get_current_user(request)
|
await _validate_run_scope(thread_id, run_id, request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
# Validate run exists and belongs to thread
|
feedback_repo = get_feedback_repository(request)
|
||||||
run_store = get_run_store(request)
|
|
||||||
run = await run_store.get(run_id)
|
|
||||||
if run is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
if run.get("thread_id") != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
|
||||||
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.create(
|
return await feedback_repo.create(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -142,41 +80,94 @@ async def create_feedback(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def upsert_feedback(
    thread_id: str,
    run_id: str,
    body: FeedbackCreateRequest,
    request: Request,
) -> dict[str, Any]:
    """Create or replace the run-level feedback record."""
    # Validate before touching the repository. Previously the authenticated
    # branch skipped both checks, and the anonymous branch deleted existing
    # feedback before `_create_feedback` had a chance to reject a bad rating.
    if body.rating not in (1, -1):
        raise HTTPException(status_code=400, detail="rating must be +1 or -1")
    await _validate_run_scope(thread_id, run_id, request)

    feedback_repo = get_feedback_repository(request)
    user_id = await _get_current_user(request)
    if user_id is not None:
        # Authenticated caller: the repository performs the upsert for this user.
        return await feedback_repo.upsert(
            run_id=run_id,
            thread_id=thread_id,
            rating=body.rating,
            user_id=user_id,
            comment=body.comment,
        )

    # No authenticated caller: emulate "replace" by deleting the listed records
    # (at most 100 are fetched) and then creating a fresh one.
    existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
    for item in existing:
        feedback_id = item.get("feedback_id")
        if isinstance(feedback_id, str):
            await feedback_repo.delete(feedback_id)
    return await _create_feedback(thread_id, run_id, body, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def create_feedback(thread_id: str, run_id: str, body: FeedbackCreateRequest, request: Request) -> dict[str, Any]:
    """Submit feedback for a run."""
    # Thin HTTP entry point: the work is delegated to `_create_feedback`.
    created = await _create_feedback(thread_id, run_id, body, request)
    return created
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def list_feedback(
|
async def list_feedback(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""List all feedback for a run."""
|
"""List all feedback for a run."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repository(request)
|
||||||
return await feedback_repo.list_by_run(thread_id, run_id)
|
user_id = await _get_current_user(request)
|
||||||
|
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def feedback_stats(
|
async def feedback_stats(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Get aggregated feedback stats (positive/negative counts) for a run."""
|
"""Get aggregated feedback stats for a run."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repository(request)
|
||||||
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}/feedback")
async def delete_run_feedback(thread_id: str, run_id: str, request: Request) -> dict[str, bool]:
    """Delete all feedback records for a run."""
    repo = get_feedback_repository(request)
    caller_id = await _get_current_user(request)

    if caller_id is None:
        # No authenticated caller: list up to 100 records for the run and
        # delete each one that carries a string feedback id.
        records = await repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
        for record in records:
            record_id = record.get("feedback_id")
            if isinstance(record_id, str):
                await repo.delete(record_id)
        return {"success": True}

    # Authenticated caller: the repository reports whether anything was removed.
    removed = await repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=caller_id)
    return {"success": removed}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_feedback(
|
async def delete_feedback(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Delete a feedback record."""
|
"""Delete a single feedback record."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repository(request)
|
||||||
# Verify feedback belongs to the specified thread/run before deleting
|
|
||||||
existing = await feedback_repo.get(feedback_id)
|
existing = await feedback_repo.get(feedback_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||||
@@ -0,0 +1,501 @@
|
|||||||
|
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import Response, StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
|
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
|
||||||
|
from app.gateway.services.runs.input import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
RunSpecBuilder,
|
||||||
|
UnsupportedRunFeatureError,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.runs.types import RunRecord, RunSpec
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
|
||||||
|
|
||||||
|
router = APIRouter(tags=["runs"])
|
||||||
|
|
||||||
|
|
||||||
|
class RunCreateRequest(BaseModel):
    # Request body accepted by the create / stream / wait run endpoints, both
    # thread-scoped and stateless. The handlers pass `model_dump()` of this
    # model to the request adapters. Documented with comments rather than a
    # docstring so the generated schema description stays as it is.

    # --- what to run -------------------------------------------------------
    assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
    follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
    input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
    command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
    metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
    config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
    context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
    webhook: str | None = Field(default=None, description="Completion callback URL")
    # --- checkpoints and interrupts ------------------------------------------
    checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
    checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
    interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
    interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
    # --- streaming -----------------------------------------------------------
    stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
    stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
    stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
    # --- lifecycle policies --------------------------------------------------
    on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
    on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
    multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
    after_seconds: float | None = Field(default=None, description="Delayed execution")
    if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
    feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||||
|
|
||||||
|
|
||||||
|
class RunResponse(BaseModel):
    # Public representation of a run; `_record_to_response` fills every field
    # from the attribute of the same name on a `RunRecord`.
    run_id: str
    thread_id: str
    assistant_id: str | None = None
    status: str
    metadata: dict[str, JSONValue] = Field(default_factory=dict)
    multitask_strategy: str = "reject"
    # Timestamps are plain strings and default to the empty string.
    created_at: str = ""
    updated_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RunDeleteResponse(BaseModel):
    # Body returned by the delete-run endpoint; carries the value the runs
    # facade returned from `delete_run`.
    deleted: bool
|
||||||
|
|
||||||
|
|
||||||
|
class RunMessageResponse(BaseModel):
    # One message event of a run, as built by `_event_to_run_message` from a
    # stored event row.
    run_id: str
    content: JSONValue
    metadata: dict[str, JSONValue] = Field(default_factory=dict)
    created_at: str
    # Sequence number of the event; the messages endpoint paginates on it via
    # `before_seq` / `after_seq`.
    seq: int
|
||||||
|
|
||||||
|
|
||||||
|
class RunMessagesResponse(BaseModel):
    # One page of run messages.
    data: list[RunMessageResponse]
    # True when the store returned more rows than the requested limit
    # (camelCase is the wire name).
    hasMore: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
    """Render one Server-Sent Events frame as text.

    The frame consists of an ``event:`` line, a ``data:`` line holding *data*
    serialised as JSON (non-ASCII characters kept verbatim, otherwise
    unserialisable objects passed through ``str``), an ``id:`` line when
    *event_id* is truthy, and the blank line that terminates the frame.
    """
    encoded = json.dumps(data, default=str, ensure_ascii=False)
    frame = f"event: {event}\ndata: {encoded}\n"
    if event_id:
        frame += f"id: {event_id}\n"
    # The trailing empty line marks the end of the frame.
    return frame + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _record_to_response(record: RunRecord) -> RunResponse:
    """Project a runtime ``RunRecord`` onto the public ``RunResponse`` schema."""
    # Every response field is copied from the record attribute of the same name.
    copied_fields = (
        "run_id",
        "thread_id",
        "assistant_id",
        "status",
        "metadata",
        "multitask_strategy",
        "created_at",
        "updated_at",
    )
    return RunResponse(**{name: getattr(record, name) for name in copied_fields})
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_paginated_rows(
|
||||||
|
rows: list[dict],
|
||||||
|
*,
|
||||||
|
limit: int,
|
||||||
|
after_seq: int | None,
|
||||||
|
) -> tuple[list[dict], bool]:
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
if not has_more:
|
||||||
|
return rows, False
|
||||||
|
if after_seq is not None:
|
||||||
|
return rows[:limit], True
|
||||||
|
return rows[-limit:], True
|
||||||
|
|
||||||
|
|
||||||
|
def _event_to_run_message(event: dict) -> RunMessageResponse:
    """Convert a stored event row into a ``RunMessageResponse``.

    ``run_id`` and ``seq`` are required keys (a missing one raises
    ``KeyError``). A missing or falsy ``metadata`` becomes an empty dict and a
    missing or falsy ``created_at`` becomes the empty string.
    """
    run_id = str(event["run_id"])
    content = event.get("content")
    metadata = dict(event.get("metadata") or {})
    created_at = str(event.get("created_at") or "")
    seq = int(event["seq"])
    return RunMessageResponse(
        run_id=run_id,
        content=content,
        metadata=metadata,
        created_at=created_at,
        seq=seq,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def _sse_consumer(
    stream: AsyncIterator[StreamEvent],
    request: Request,
    *,
    cancel_on_disconnect: bool,
    cancel_run,
    run_id: str,
) -> AsyncIterator[str]:
    """Translate run stream events into SSE text frames.

    Before each event is handled the client connection is checked; a
    disconnected client ends the loop. ``__heartbeat__`` events become SSE
    comment lines, ``__end__`` / ``__cancelled__`` emit a final ``end`` /
    ``cancel`` frame and stop, and every other event is forwarded as-is.

    The ``finally`` clause runs on every exit path, so when
    *cancel_on_disconnect* is true ``cancel_run(run_id)`` is awaited after a
    normal completion as well as after a disconnect.
    """
    try:
        async for item in stream:
            if await request.is_disconnected():
                break

            name = item.event
            if name == "__heartbeat__":
                yield ": heartbeat\n\n"
                continue

            if name == "__end__" or name == "__cancelled__":
                closing = "end" if name == "__end__" else "cancel"
                yield format_sse(closing, None, event_id=item.id or None)
                return

            yield format_sse(name, item.data, event_id=item.id or None)
    finally:
        if cancel_on_disconnect:
            await cancel_run(run_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_event_store(request: Request):
|
||||||
|
event_store = getattr(request.app.state, "run_event_store", None)
|
||||||
|
if event_store is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Run event store not available")
|
||||||
|
return event_store
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
async def list_runs(
    thread_id: str,
    request: Request,
    limit: int = 100,
    offset: int = 0,
    status: str | None = None,
) -> list[RunResponse]:
    # Lists the thread's runs through the facade. The optional status filter
    # and the offset/limit window are applied in memory afterwards.
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        found = await runs_facade.list_runs(thread_id)
    matching = found if status is None else [run for run in found if run.status == status]
    window = matching[offset : offset + limit]
    return [_record_to_response(run) for run in window]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    # Returns one run. A run that belongs to a different thread is reported as
    # missing (404), exactly like an unknown run id.
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        found = await runs_facade.get_run(run_id)
    if found is not None and found.thread_id == thread_id:
        return _record_to_response(found)
    raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
async def run_messages(
    thread_id: str,
    run_id: str,
    request: Request,
    limit: int = 50,
    before_seq: int | None = None,
    after_seq: int | None = None,
) -> RunMessagesResponse:
    # Returns one page of a run's message events. The run must exist and belong
    # to the thread in the path, otherwise the answer is 404.
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        found = await runs_facade.get_run(run_id)
    if found is None or found.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    store = _get_run_event_store(request)
    # Ask for one row more than requested so a following page can be detected.
    fetch_count = limit + 1
    with bind_request_actor_context(request):
        fetched = await store.list_messages_by_run(
            thread_id,
            run_id,
            limit=fetch_count,
            before_seq=before_seq,
            after_seq=after_seq,
        )
    page, more_available = _trim_paginated_rows(fetched, limit=limit, after_seq=after_seq)
    messages = [_event_to_run_message(row) for row in page]
    return RunMessagesResponse(data=messages, hasMore=more_available)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_spec(*, adapted: AdaptedRunRequest) -> RunSpec:
    """Build a ``RunSpec`` from an adapted request.

    Raises:
        HTTPException: 501 carrying the builder's message when it raises
            ``UnsupportedRunFeatureError``.
    """
    try:
        run_spec = RunSpecBuilder().build(adapted)
    except UnsupportedRunFeatureError as unsupported:
        raise HTTPException(status_code=501, detail=str(unsupported)) from unsupported
    else:
        return run_spec
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs", response_model=RunResponse)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> Response:
    # Creates a background run on the given thread and answers with the run
    # record as JSON; Content-Location points at the new run resource.
    run_spec = _build_spec(
        adapted=adapt_create_run_request(
            thread_id=thread_id,
            body=body.model_dump(),
            headers=dict(request.headers),
            query=dict(request.query_params),
        )
    )
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        created = await runs_facade.create_background(run_spec)

    payload = _record_to_response(created).model_dump_json()
    location = f"/api/threads/{thread_id}/runs/{created.run_id}"
    return Response(content=payload, media_type="application/json", headers={"Content-Location": location})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/stream")
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
    # Creates a run on the given thread and streams its events as SSE. The run
    # is cancelled when the stream ends only if the spec's on_disconnect policy
    # is "cancel".
    run_spec = _build_spec(
        adapted=adapt_create_stream_request(
            thread_id=thread_id,
            body=body.model_dump(),
            headers=dict(request.headers),
            query=dict(request.query_params),
        )
    )
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        created, events = await runs_facade.create_and_stream(run_spec)

    frames = _sse_consumer(
        events,
        request,
        cancel_on_disconnect=run_spec.on_disconnect == "cancel",
        cancel_run=runs_facade.cancel,
        run_id=created.run_id,
    )
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "Content-Location": f"/api/threads/{thread_id}/runs/{created.run_id}",
    }
    return StreamingResponse(frames, media_type="text/event-stream", headers=sse_headers)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/wait")
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> Response:
    # Creates a run on the given thread, waits for the facade to return its
    # result and sends that result back as JSON.
    run_spec = _build_spec(
        adapted=adapt_create_wait_request(
            thread_id=thread_id,
            body=body.model_dump(),
            headers=dict(request.headers),
            query=dict(request.query_params),
        )
    )
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        finished, outcome = await runs_facade.create_and_wait(run_spec)

    payload = json.dumps(outcome, default=str, ensure_ascii=False)
    location = f"/api/threads/{thread_id}/runs/{finished.run_id}"
    return Response(content=payload, media_type="application/json", headers={"Content-Location": location})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs", response_model=RunResponse)
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
    # Same as the thread-scoped create endpoint, but no thread id is supplied;
    # Content-Location is built from the thread id on the returned record.
    run_spec = _build_spec(
        adapted=adapt_create_run_request(
            thread_id=None,
            body=body.model_dump(),
            headers=dict(request.headers),
            query=dict(request.query_params),
        )
    )
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        created = await runs_facade.create_background(run_spec)

    payload = _record_to_response(created).model_dump_json()
    location = f"/api/threads/{created.thread_id}/runs/{created.run_id}"
    return Response(content=payload, media_type="application/json", headers={"Content-Location": location})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs/stream")
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
    # Same as the thread-scoped stream endpoint, but no thread id is supplied;
    # Content-Location is built from the thread id on the returned record.
    run_spec = _build_spec(
        adapted=adapt_create_stream_request(
            thread_id=None,
            body=body.model_dump(),
            headers=dict(request.headers),
            query=dict(request.query_params),
        )
    )
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        created, events = await runs_facade.create_and_stream(run_spec)

    frames = _sse_consumer(
        events,
        request,
        cancel_on_disconnect=run_spec.on_disconnect == "cancel",
        cancel_run=runs_facade.cancel,
        run_id=created.run_id,
    )
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "Content-Location": f"/api/threads/{created.thread_id}/runs/{created.run_id}",
    }
    return StreamingResponse(frames, media_type="text/event-stream", headers=sse_headers)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs/wait")
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
    # Same as the thread-scoped wait endpoint, but no thread id is supplied;
    # Content-Location is built from the thread id on the returned record.
    run_spec = _build_spec(
        adapted=adapt_create_wait_request(
            thread_id=None,
            body=body.model_dump(),
            headers=dict(request.headers),
            query=dict(request.query_params),
        )
    )
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        finished, outcome = await runs_facade.create_and_wait(run_spec)

    payload = json.dumps(outcome, default=str, ensure_ascii=False)
    location = f"/api/threads/{finished.thread_id}/runs/{finished.run_id}"
    return Response(content=payload, media_type="application/json", headers={"Content-Location": location})
|
||||||
|
|
||||||
|
|
||||||
|
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
async def stream_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    action: Literal["interrupt", "rollback"] | None = None,
    wait: bool = False,
    cancel_on_disconnect: bool = False,
    stream_mode: str | None = None,
) -> StreamingResponse | Response:
    # Joins the SSE stream of an existing run, optionally cancelling it first.
    # `stream_mode` is accepted but not read anywhere in this handler.
    facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        record = await facade.get_run(run_id)
    # A run that belongs to another thread is reported as missing.
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    # When an action is given the run is cancelled before anything is streamed;
    # a run that cannot be cancelled yields 409.
    # NOTE(review): the `wait` branch is assumed to be nested under the
    # `action` check (it only makes sense after a cancel) — confirm against
    # the original indentation.
    if action is not None:
        with bind_request_actor_context(request):
            cancelled = await facade.cancel(run_id, action=action)
        if not cancelled:
            raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
        if wait:
            # Wait for the join to return instead of streaming, then answer
            # with an empty 204.
            with bind_request_actor_context(request):
                await facade.join_wait(run_id)
            return Response(status_code=204)

    # Only `last_event_id` of the adapted request is used here, as the point to
    # resume the stream from.
    adapted = adapt_join_stream_request(
        thread_id=thread_id,
        run_id=run_id,
        headers=dict(request.headers),
        query=dict(request.query_params),
    )
    with bind_request_actor_context(request):
        stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)

    return StreamingResponse(
        _sse_consumer(
            stream,
            request,
            cancel_on_disconnect=cancel_on_disconnect,
            cancel_run=facade.cancel,
            run_id=run_id,
        ),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/join")
async def join_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    cancel_on_disconnect: bool = False,
) -> JSONValue:
    # Waits for an existing run and returns what the facade's join_wait yields.
    # `cancel_on_disconnect` is accepted for API compatibility only: nothing in
    # this handler reads it.
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        found = await runs_facade.get_run(run_id)
    if found is None or found.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    join_request = adapt_join_wait_request(
        thread_id=thread_id,
        run_id=run_id,
        headers=dict(request.headers),
        query=dict(request.query_params),
    )
    with bind_request_actor_context(request):
        outcome = await runs_facade.join_wait(run_id, last_event_id=join_request.last_event_id)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/{run_id}/cancel")
async def cancel_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    wait: bool = False,
    action: Literal["interrupt", "rollback"] = "interrupt",
) -> JSONValue:
    # Cancels a run of the given thread: 404 when the run is unknown or belongs
    # to another thread, 409 when the facade reports it cannot be cancelled.
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        found = await runs_facade.get_run(run_id)
    if found is None or found.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    with bind_request_actor_context(request):
        was_cancelled = await runs_facade.cancel(run_id, action=action)
    if not was_cancelled:
        raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")

    if not wait:
        return {}
    # With wait=true the response body is whatever join_wait returns.
    with bind_request_actor_context(request):
        outcome = await runs_facade.join_wait(run_id)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
async def delete_run(thread_id: str, run_id: str, request: Request) -> RunDeleteResponse:
    # Deletes a run after checking that it exists and belongs to the thread in
    # the path; the facade's result is passed through in `deleted`.
    runs_facade = build_runs_facade_from_request(request)
    with bind_request_actor_context(request):
        found = await runs_facade.get_run(run_id)
    if found is None or found.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    with bind_request_actor_context(request):
        was_deleted = await runs_facade.delete_run(run_id)
    return RunDeleteResponse(deleted=was_deleted)
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["suggestions"])
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionMessage(BaseModel):
    # One conversation turn supplied by the client.
    # ``_format_conversation`` treats "user"/"human" and "assistant"/"ai"
    # (case-insensitive, surrounding whitespace ignored) as the two known
    # speakers; any other role is echoed verbatim as the line prefix.
    role: str = Field(..., description="Message role: user|assistant")
    content: str = Field(..., description="Message content as plain text")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsRequest(BaseModel):
    # Request body of the follow-up suggestion endpoint.
    # ``n`` is validated to the range 1..5; ``model_name`` is forwarded to
    # ``create_chat_model`` as ``name`` (``None`` when omitted).
    messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
    n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
    model_name: str | None = Field(default=None, description="Optional model override")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsResponse(BaseModel):
    # Response body of the suggestion endpoint. The list is empty when there
    # is nothing to summarise or when generation/parsing failed.
    suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_markdown_code_fence(text: str) -> str:
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped.startswith("```"):
|
||||||
|
return stripped
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||||
|
return "\n".join(lines[1:-1]).strip()
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_string_list(text: str) -> list[str] | None:
    """Extract a JSON array of strings from a model reply.

    Any Markdown code fence is stripped, then the span from the first ``[`` to
    the last ``]`` is decoded as JSON. Returns ``None`` when no such span
    exists, when decoding fails, or when the decoded value is not a list.
    Otherwise returns the string items, whitespace-trimmed, with non-string
    and blank items dropped (the result may therefore be an empty list).
    """
    payload = _strip_markdown_code_fence(text)
    first_bracket = payload.find("[")
    last_bracket = payload.rfind("]")
    if first_bracket == -1 or last_bracket == -1 or last_bracket <= first_bracket:
        return None

    try:
        decoded = json.loads(payload[first_bracket : last_bracket + 1])
    except Exception:
        return None
    if not isinstance(decoded, list):
        return None

    trimmed = (entry.strip() for entry in decoded if isinstance(entry, str))
    return [entry for entry in trimmed if entry]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_response_text(content: object) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
parts.append(block)
|
||||||
|
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
|
||||||
|
text = block.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "\n".join(parts) if parts else ""
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_conversation(messages: list[SuggestionMessage]) -> str:
    """Render the messages as a ``Speaker: text`` transcript, one per line.

    Roles are matched case-insensitively after trimming: ``user``/``human``
    are labelled ``User`` and ``assistant``/``ai`` are labelled ``Assistant``.
    Any other role is used verbatim (untrimmed) as the label. Message content
    is whitespace-trimmed and the final transcript is trimmed as well.
    """
    known_speakers = {
        "user": "User",
        "human": "User",
        "assistant": "Assistant",
        "ai": "Assistant",
    }
    transcript: list[str] = []
    for message in messages:
        normalized_role = message.role.strip().lower()
        speaker = known_speakers.get(normalized_role, message.role)
        transcript.append(f"{speaker}: {message.content.strip()}")
    return "\n".join(transcript).strip()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/threads/{thread_id}/suggestions",
    response_model=SuggestionsResponse,
    summary="Generate Follow-up Questions",
    description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
    """Ask a chat model for up to ``request.n`` follow-up questions.

    Returns an empty suggestion list when there are no messages, when the
    formatted conversation is empty, or when anything in the model call or
    reply parsing raises (the failure is logged, never propagated).
    ``thread_id`` is only used in the failure log message.
    """
    if not request.messages:
        return SuggestionsResponse(suggestions=[])

    wanted = request.n
    transcript = _format_conversation(request.messages)
    if not transcript:
        return SuggestionsResponse(suggestions=[])

    system_instruction = (
        "You are generating follow-up questions to help the user continue the conversation.\n"
        f"Based on the conversation below, produce EXACTLY {wanted} short questions the user might ask next.\n"
        "Requirements:\n"
        "- Questions must be relevant to the preceding conversation.\n"
        "- Questions must be written in the same language as the user.\n"
        "- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
        "- Do NOT include numbering, markdown, or any extra text.\n"
        "- Output MUST be a JSON array of strings only.\n"
    )
    user_content = f"Conversation Context:\n{transcript}\n\nGenerate {wanted} follow-up questions"
    prompt = [SystemMessage(content=system_instruction), HumanMessage(content=user_content)]

    try:
        model = create_chat_model(name=request.model_name, thinking_enabled=False)
        reply = await model.ainvoke(prompt)
        reply_text = _extract_response_text(reply.content)
        parsed = _parse_json_string_list(reply_text) or []
        # Collapse embedded newlines so each suggestion is a single line, and
        # cap the list in case the model returned more than requested.
        single_line = [item.replace("\n", " ").strip() for item in parsed if item.strip()]
        return SuggestionsResponse(suggestions=single_line[:wanted])
    except Exception as exc:
        logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
        return SuggestionsResponse(suggestions=[])
|
||||||
@@ -0,0 +1,455 @@
|
|||||||
|
"""Thread management endpoints.
|
||||||
|
|
||||||
|
Provides CRUD operations for threads and checkpoint state management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||||
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
from deerflow.runtime import serialize_channel_values
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(tags=["threads"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request / Response Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadCreateRequest(BaseModel):
    # Request body of ``create_thread``. When ``thread_id`` is omitted (or
    # empty) the endpoint generates a UUID4 string instead.
    thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
    assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSearchRequest(BaseModel):
    # Request body of ``search_threads``.
    # An empty ``metadata`` dict is passed to storage as ``None``.
    # ``user_id`` is only forwarded to storage when the request itself does
    # not resolve to a user; otherwise the bound request actor context is
    # used and this field is ignored.
    metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
    limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
    offset: int = Field(default=0, ge=0, description="Pagination offset")
    status: str | None = Field(default=None, description="Filter by thread status")
    user_id: str | None = Field(default=None, description="Filter by user ID")
    assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadResponse(BaseModel):
    # Thread summary returned by ``create_thread`` and ``search_threads``.
    # The endpoints fill ``created_at``/``updated_at`` with ``isoformat()``
    # output, or leave them as "" when the stored timestamp is missing.
    # ``search_threads`` puts the thread's display name under
    # ``values["title"]`` and always returns empty ``interrupts``.
    thread_id: str = Field(description="Unique thread identifier")
    status: str = Field(default="idle", description="Thread status")
    created_at: str = Field(default="", description="ISO timestamp")
    updated_at: str = Field(default="", description="ISO timestamp")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
    values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
    interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDeleteResponse(BaseModel):
    # Result of deleting a thread's local data, built by
    # ``_delete_thread_data``; ``message`` is a human-readable summary.
    success: bool
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadStateUpdateRequest(BaseModel):
    # Request body of ``update_thread_state``.
    # ``values`` is merged key-by-key over the stored channel values;
    # ``checkpoint_id`` selects which checkpoint is read as the base;
    # ``as_node`` (when set) is recorded in the checkpoint metadata writes.
    # NOTE(review): ``checkpoint`` is accepted but ``update_thread_state``
    # does not appear to read it - confirm whether that is intentional.
    values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
    checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
    checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
    as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadStateResponse(BaseModel):
    # State snapshot returned by ``get_thread_state`` and
    # ``update_thread_state``. All fields have defaults, so an all-default
    # instance represents a known thread that has no checkpoint yet.
    # ``created_at`` is the stringified ``created_at`` entry of the
    # checkpoint metadata ("" when that entry is absent).
    values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
    next: list[str] = Field(default_factory=list, description="Next nodes to execute")
    tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
    checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
    checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
    parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
    created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadHistoryRequest(BaseModel):
    # Request body of ``get_thread_history``.
    # ``limit`` is passed to ``checkpointer.alist``; ``before`` (when set) is
    # placed in the lookup config as ``checkpoint_id``.
    limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
    before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryEntry(BaseModel):
    # One checkpoint in the list returned by ``get_thread_history``.
    # ``values`` is a reduced view: it carries ``title`` when the checkpoint
    # has one, and ``messages`` only on the first entry of the listing.
    checkpoint_id: str
    parent_checkpoint_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    values: dict[str, Any] = Field(default_factory=dict)
    created_at: str | None = None
    next: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_log_param(value: str) -> str:
    """Strip control characters to prevent log injection.

    Removes every character in Unicode category ``Cc`` (C0/C1 controls,
    which includes ``\\n``, ``\\r``, ``\\x00``, ESC, VT, FF and NEL) as well
    as the Unicode line and paragraph separators (``Zl``/``Zp``). All other
    characters, including non-ASCII text, are kept unchanged.

    Args:
        value: Untrusted text that is about to be interpolated into a log
            message.

    Returns:
        ``value`` with the characters described above removed.
    """
    # Imported locally so this helper carries its own dependency.
    import unicodedata

    # Cc = control characters; Zl/Zp = U+2028/U+2029, which many log viewers
    # and ``str.splitlines`` treat as line breaks.
    unsafe_categories = ("Cc", "Zl", "Zp")
    return "".join(ch for ch in value if unicodedata.category(ch) not in unsafe_categories)
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
    """Remove the on-disk directory that belongs to ``thread_id``.

    Args:
        thread_id: Identifier of the thread whose local data is removed.
        paths: Path manager to use; ``get_paths()`` is used when omitted.

    Returns:
        A successful ``ThreadDeleteResponse``, both when data was deleted and
        when there was nothing on disk to delete.

    Raises:
        HTTPException: 422 when the path manager raises ``ValueError``,
            500 for any other unexpected failure.
    """
    manager = paths if paths else get_paths()
    try:
        manager.delete_thread_dir(thread_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except FileNotFoundError:
        # Nothing on disk is not an error for a delete operation.
        logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
        return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
    except Exception as exc:
        logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
    else:
        logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
        return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _thread_or_run_exists(
    *,
    request: Request,
    thread_id: str,
    thread_meta_storage: ThreadMetaStorage,
    run_repo,
) -> bool:
    """Report whether ``thread_id`` is known as thread metadata or via a run.

    When the request does not resolve to a user, both lookups are made with an
    explicit ``user_id=None``. Otherwise both lookups run inside the request
    actor context without an explicit ``user_id``. The run repository is only
    consulted when no thread metadata record is found.
    """
    if resolve_request_user_id(request) is None:
        if await thread_meta_storage.get_thread(thread_id, user_id=None) is not None:
            return True
        anonymous_runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
        return bool(anonymous_runs)

    with bind_request_actor_context(request):
        if await thread_meta_storage.get_thread(thread_id) is not None:
            return True
        actor_runs = await run_repo.list_by_thread(thread_id, limit=1)
        return bool(actor_runs)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ThreadResponse)
async def create_thread(
    body: ThreadCreateRequest,
    request: Request,
    thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadResponse:
    """Create a new thread."""
    # Idempotent: when a thread with the requested ID is already visible to
    # the caller, it is returned as is instead of being created again.
    # Requests that do not resolve to a user talk to storage with an explicit
    # ``user_id=None``; all others go through the request actor context.
    # Storage failures while creating are logged and reported as HTTP 500.

    def _as_response(record) -> ThreadResponse:
        # Shared shape for both the "already exists" and "created" replies.
        return ThreadResponse(
            thread_id=thread_id,
            status=record.status,
            created_at=record.created_time.isoformat() if record.created_time else "",
            updated_at=record.updated_time.isoformat() if record.updated_time else "",
            metadata=record.metadata,
        )

    thread_id = body.thread_id or str(uuid.uuid4())
    is_anonymous = resolve_request_user_id(request) is None

    if is_anonymous:
        current = await thread_meta_storage.get_thread(thread_id, user_id=None)
    else:
        with bind_request_actor_context(request):
            current = await thread_meta_storage.get_thread(thread_id)
    if current is not None:
        return _as_response(current)

    try:
        if is_anonymous:
            fresh = await thread_meta_storage.ensure_thread(
                thread_id=thread_id,
                assistant_id=body.assistant_id,
                metadata=body.metadata,
                user_id=None,
            )
        else:
            with bind_request_actor_context(request):
                fresh = await thread_meta_storage.ensure_thread(
                    thread_id=thread_id,
                    assistant_id=body.assistant_id,
                    metadata=body.metadata,
                )
    except Exception:
        logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to create thread")

    logger.info("Thread created: %s", sanitize_log_param(thread_id))
    return _as_response(fresh)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(
    body: ThreadSearchRequest,
    request: Request,
    thread_meta_storage: CurrentThreadMetaStorage,
) -> list[ThreadResponse]:
    """Search threads with filters."""
    # ``body.user_id`` is only honoured when the request does not resolve to
    # a user; otherwise the search runs inside the request actor context and
    # no explicit user filter is passed. Any failure becomes HTTP 500.
    try:
        is_anonymous = resolve_request_user_id(request) is None
        shared_filters = {
            "metadata": body.metadata or None,
            "status": body.status,
            "assistant_id": body.assistant_id,
            "limit": body.limit,
            "offset": body.offset,
        }
        if is_anonymous:
            found = await thread_meta_storage.search_threads(user_id=body.user_id, **shared_filters)
        else:
            with bind_request_actor_context(request):
                found = await thread_meta_storage.search_threads(**shared_filters)
    except Exception:
        logger.exception("Failed to search threads")
        raise HTTPException(status_code=500, detail="Failed to search threads")

    results: list[ThreadResponse] = []
    for record in found:
        # The display name, when present, is surfaced as the thread title.
        title_values = {"title": record.display_name} if record.display_name else {}
        results.append(
            ThreadResponse(
                thread_id=record.thread_id,
                status=record.status,
                created_at=record.created_time.isoformat() if record.created_time else "",
                updated_at=record.updated_time.isoformat() if record.updated_time else "",
                metadata=record.metadata,
                values=title_values,
                interrupts={},
            )
        )
    return results
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread(
    thread_id: str,
    checkpointer: CurrentCheckpointer,
    thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadDeleteResponse:
    """Delete a thread and all associated data."""
    # Order: local filesystem data first (its failures propagate as HTTP
    # errors), then checkpoints and thread metadata on a best-effort basis.
    # The returned response describes only the local-data step.
    #
    # NOTE(review): unlike the other endpoints in this router, no request
    # actor context is bound here - confirm ownership is enforced elsewhere.
    response = _delete_thread_data(thread_id)

    # Remove checkpoints (best-effort). Failures are logged at WARNING with
    # the traceback so leftover checkpoint data is visible to operators; the
    # request itself still succeeds.
    try:
        if hasattr(checkpointer, "adelete_thread"):
            await checkpointer.adelete_thread(thread_id)
    except Exception:
        logger.warning(
            "Could not delete checkpoints for thread %s",
            sanitize_log_param(thread_id),
            exc_info=True,
        )

    # Remove thread_meta (best-effort), with the same logging policy.
    try:
        await thread_meta_storage.delete_thread(thread_id)
    except Exception:
        logger.warning(
            "Could not delete thread_meta for %s",
            sanitize_log_param(thread_id),
            exc_info=True,
        )

    return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(
    thread_id: str,
    request: Request,
    checkpointer: CurrentCheckpointer,
    thread_meta_storage: CurrentThreadMetaStorage,
    run_repo: CurrentRunRepository,
) -> ThreadStateResponse:
    """Get the latest state snapshot for a thread."""
    # Outcomes:
    #   500 - the checkpointer lookup raised.
    #   200 with an all-default body - no checkpoint yet, but the thread is
    #         known through thread metadata or a run.
    #   404 - no checkpoint and the thread is unknown.
    #   200 with the snapshot - a checkpoint was found.
    lookup_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    try:
        snapshot = await checkpointer.aget_tuple(lookup_config)
    except Exception:
        logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread state")

    if snapshot is None:
        thread_is_known = await _thread_or_run_exists(
            request=request,
            thread_id=thread_id,
            thread_meta_storage=thread_meta_storage,
            run_repo=run_repo,
        )
        if not thread_is_known:
            raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
        return ThreadStateResponse()

    # Every attribute is read defensively: a missing or falsy attribute
    # falls back to an empty container.
    checkpoint = getattr(snapshot, "checkpoint", {}) or {}
    metadata = getattr(snapshot, "metadata", {}) or {}
    own_config = getattr(snapshot, "config", {}) or {}
    parent_config = getattr(snapshot, "parent_config", None)
    pending_tasks = getattr(snapshot, "tasks", []) or []

    checkpoint_id = own_config.get("configurable", {}).get("checkpoint_id")
    parent_checkpoint_id = None
    if parent_config:
        parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
    created_at = str(metadata.get("created_at", ""))

    return ThreadStateResponse(
        values=serialize_channel_values(checkpoint.get("channel_values", {})),
        next=[task.name for task in pending_tasks if hasattr(task, "name")],
        tasks=[{"id": getattr(task, "id", ""), "name": getattr(task, "name", "")} for task in pending_tasks],
        checkpoint={"id": checkpoint_id, "ts": created_at},
        checkpoint_id=checkpoint_id,
        parent_checkpoint_id=parent_checkpoint_id,
        metadata=metadata,
        created_at=created_at,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(
    thread_id: str,
    body: ThreadStateUpdateRequest,
    checkpointer: CurrentCheckpointer,
    thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadStateResponse:
    """Update thread state (human-in-the-loop or title rename)."""
    # Flow: read the base checkpoint (optionally pinned by
    # ``body.checkpoint_id``), merge ``body.values`` over its channel values,
    # stamp the metadata, write it back through the checkpointer, then try to
    # mirror a new non-empty title into thread metadata.
    # 404 when no base checkpoint exists; 500 when reading or writing raises.
    base_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    if body.checkpoint_id:
        base_config["configurable"]["checkpoint_id"] = body.checkpoint_id

    try:
        snapshot = await checkpointer.aget_tuple(base_config)
    except Exception:
        logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread state")
    if snapshot is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    # Work on shallow copies so the objects held by the snapshot itself are
    # not rebound or extended at the top level.
    checkpoint: dict[str, Any] = dict(getattr(snapshot, "checkpoint", {}) or {})
    metadata: dict[str, Any] = dict(getattr(snapshot, "metadata", {}) or {})
    merged_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
    if body.values:
        merged_values.update(body.values)
    checkpoint["channel_values"] = merged_values

    metadata["updated_at"] = time.time()
    if body.as_node:
        metadata["source"] = "update"
        metadata["step"] = metadata.get("step", 0) + 1
        metadata["writes"] = {body.as_node: body.values}

    # The write config never carries a checkpoint_id, even when the read did.
    target_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    try:
        stored_config = await checkpointer.aput(target_config, checkpoint, metadata, {})
    except Exception:
        logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to update thread state")

    new_checkpoint_id: str | None = None
    if isinstance(stored_config, dict):
        new_checkpoint_id = stored_config.get("configurable", {}).get("checkpoint_id")

    # Sync title to thread_meta (best-effort; only for a non-empty title).
    requested_title = (body.values or {}).get("title")
    if requested_title:
        try:
            await thread_meta_storage.sync_thread_title(
                thread_id=thread_id,
                title=requested_title,
            )
        except Exception:
            logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))

    return ThreadStateResponse(
        values=serialize_channel_values(merged_values),
        next=[],
        metadata=metadata,
        checkpoint_id=new_checkpoint_id,
        created_at=str(metadata.get("created_at", "")),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(
    thread_id: str,
    body: ThreadHistoryRequest,
    request: Request,
    checkpointer: CurrentCheckpointer,
    thread_meta_storage: CurrentThreadMetaStorage,
    run_repo: CurrentRunRepository,
) -> list[HistoryEntry]:
    """Get checkpoint history for a thread."""
    # Outcomes:
    #   500 - listing checkpoints raised.
    #   200 with entries - at least one checkpoint matched.
    #   200 with [] - nothing matched, but the thread is known through thread
    #         metadata or a run (new thread, or cursor past the last entry).
    #   404 - nothing matched and the thread is unknown, mirroring
    #         ``get_thread_state``. Previously both outcomes of the existence
    #         check returned an empty list, so the check had no effect.
    config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    if body.before:
        config["configurable"]["checkpoint_id"] = body.before

    entries: list[HistoryEntry] = []
    # Messages are only attached to the first listed checkpoint to keep the
    # payload small; later entries carry at most the title.
    is_first = True

    try:
        async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
            ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
            parent_config = getattr(checkpoint_tuple, "parent_config", None)
            metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
            checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}

            checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
            parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
            channel_values = checkpoint.get("channel_values", {})

            values: dict[str, Any] = {}
            if title := channel_values.get("title"):
                values["title"] = title
            if is_first and (messages := channel_values.get("messages")):
                values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
            is_first = False

            tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
            next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]

            entries.append(
                HistoryEntry(
                    checkpoint_id=checkpoint_id,
                    parent_checkpoint_id=parent_id,
                    metadata=metadata,
                    values=values,
                    created_at=str(metadata.get("created_at", "")),
                    next=next_nodes,
                )
            )
    except Exception:
        logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread history")

    if entries:
        return entries

    # No checkpoints matched: distinguish "known thread, nothing to list"
    # from "unknown thread". Raised outside the try so it is not turned
    # into a 500 by the handler above.
    if await _thread_or_run_exists(
        request=request,
        thread_id=thread_id,
        thread_meta_storage=thread_meta_storage,
        run_repo=run_repo,
    ):
        return []
    raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
@@ -63,99 +63,6 @@ class McpConfigUpdateRequest(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_MASKED_VALUE = "***"
|
|
||||||
|
|
||||||
|
|
||||||
def _mask_server_config(server: McpServerConfigResponse) -> McpServerConfigResponse:
    """Build a redacted copy of *server* that is safe to return from GET.

    Every ``env`` and ``headers`` value is replaced by the mask marker (keys
    are kept), and the OAuth ``client_secret`` and ``refresh_token`` are set
    to ``None`` when an OAuth section is present. The input is not modified.
    """
    if server.oauth is None:
        redacted_oauth = None
    else:
        redacted_oauth = server.oauth.model_copy(
            update={"client_secret": None, "refresh_token": None},
        )

    redactions = {
        "env": dict.fromkeys(server.env, _MASKED_VALUE),
        "headers": dict.fromkeys(server.headers, _MASKED_VALUE),
        "oauth": redacted_oauth,
    }
    return server.model_copy(update=redactions)
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_preserving_secrets(
    incoming: McpServerConfigResponse,
    existing: McpServerConfigResponse,
) -> McpServerConfigResponse:
    """Combine *incoming* with *existing* so masked secrets survive a round-trip.

    The frontend reads the masked config, changes a field such as
    ``enabled`` and writes the whole object back, so masked placeholders
    arrive here in place of the real secrets.

    * ``env`` / ``headers``: a value equal to the mask marker is replaced by
      the value stored under the same key in *existing*. Using the marker for
      a key that *existing* does not have is rejected with HTTP 400.
    * OAuth ``client_secret`` / ``refresh_token`` (only when both sides have
      an OAuth section): ``None`` keeps the stored value, an empty string
      clears it (stored as ``None``), anything else replaces it.

    Raises:
        HTTPException: 400 when a new env key or header is set to the mask
            marker.
    """

    def _resolve_oauth_secret(sent, stored):
        # None = preserve (masked round-trip), "" = explicitly clear, else = new value
        if sent is None:
            return stored
        return None if sent == "" else sent

    env_result = {}
    for key, value in incoming.env.items():
        if value != _MASKED_VALUE:
            env_result[key] = value
            continue
        if key not in existing.env:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot set env key '{key}' to masked value '***'; provide a real value.",
            )
        env_result[key] = existing.env[key]

    header_result = {}
    for key, value in incoming.headers.items():
        if value != _MASKED_VALUE:
            header_result[key] = value
            continue
        if key not in existing.headers:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot set header '{key}' to masked value '***'; provide a real value.",
            )
        header_result[key] = existing.headers[key]

    oauth_result = incoming.oauth
    if incoming.oauth is not None and existing.oauth is not None:
        oauth_result = incoming.oauth.model_copy(
            update={
                "client_secret": _resolve_oauth_secret(incoming.oauth.client_secret, existing.oauth.client_secret),
                "refresh_token": _resolve_oauth_secret(incoming.oauth.refresh_token, existing.oauth.refresh_token),
            }
        )

    return incoming.model_copy(
        update={
            "env": env_result,
            "headers": header_result,
            "oauth": oauth_result,
        }
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/mcp/config",
|
"/mcp/config",
|
||||||
response_model=McpConfigResponse,
|
response_model=McpConfigResponse,
|
||||||
@@ -176,7 +83,7 @@ async def get_mcp_configuration() -> McpConfigResponse:
|
|||||||
"enabled": true,
|
"enabled": true,
|
||||||
"command": "npx",
|
"command": "npx",
|
||||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||||
"env": {"GITHUB_TOKEN": "***"},
|
"env": {"GITHUB_TOKEN": "ghp_xxx"},
|
||||||
"description": "GitHub MCP server for repository operations"
|
"description": "GitHub MCP server for repository operations"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -185,8 +92,7 @@ async def get_mcp_configuration() -> McpConfigResponse:
|
|||||||
"""
|
"""
|
||||||
config = get_extensions_config()
|
config = get_extensions_config()
|
||||||
|
|
||||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()}
|
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
|
||||||
return McpConfigResponse(mcp_servers=servers)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
@@ -236,39 +142,14 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
|||||||
config_path = Path.cwd().parent / "extensions_config.json"
|
config_path = Path.cwd().parent / "extensions_config.json"
|
||||||
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
|
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
|
||||||
|
|
||||||
# Load current config to preserve skills
|
# Load current config to preserve skills configuration
|
||||||
current_config = get_extensions_config()
|
current_config = get_extensions_config()
|
||||||
|
|
||||||
# Load raw (un-resolved) JSON from disk to use as the merge source.
|
# Convert request to dict format for JSON serialization
|
||||||
# This preserves $VAR placeholders in env values and top-level keys
|
config_data = {
|
||||||
# like mcpInterceptors that would otherwise be lost.
|
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
|
||||||
raw_servers: dict[str, dict] = {}
|
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
|
||||||
raw_other_keys: dict = {}
|
}
|
||||||
if config_path is not None and config_path.exists():
|
|
||||||
with open(config_path, encoding="utf-8") as f:
|
|
||||||
raw_data = json.load(f)
|
|
||||||
raw_servers = raw_data.get("mcpServers", {})
|
|
||||||
# Preserve any top-level keys beyond mcpServers/skills
|
|
||||||
for key, value in raw_data.items():
|
|
||||||
if key not in ("mcpServers", "skills"):
|
|
||||||
raw_other_keys[key] = value
|
|
||||||
|
|
||||||
# Merge incoming server configs with raw on-disk secrets
|
|
||||||
merged_servers: dict[str, McpServerConfigResponse] = {}
|
|
||||||
for name, incoming in request.mcp_servers.items():
|
|
||||||
raw_server = raw_servers.get(name)
|
|
||||||
if raw_server is not None:
|
|
||||||
merged_servers[name] = _merge_preserving_secrets(
|
|
||||||
incoming,
|
|
||||||
McpServerConfigResponse(**raw_server),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
merged_servers[name] = incoming
|
|
||||||
|
|
||||||
# Build config data preserving all top-level keys from the original file
|
|
||||||
config_data = dict(raw_other_keys)
|
|
||||||
config_data["mcpServers"] = {name: server.model_dump() for name, server in merged_servers.items()}
|
|
||||||
config_data["skills"] = {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}
|
|
||||||
|
|
||||||
# Write the configuration to file
|
# Write the configuration to file
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
@@ -281,8 +162,7 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
|||||||
|
|
||||||
# Reload the configuration and update the global cache
|
# Reload the configuration and update the global cache
|
||||||
reloaded_config = reload_extensions_config()
|
reloaded_config = reload_extensions_config()
|
||||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
|
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
|
||||||
return McpConfigResponse(mcp_servers=servers)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Memory API router for retrieving and managing global memory data."""
|
"""Memory API router for retrieving and managing global memory data."""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
from deerflow.agents.memory.updater import (
|
from deerflow.agents.memory.updater import (
|
||||||
clear_memory_data,
|
clear_memory_data,
|
||||||
create_memory_fact,
|
create_memory_fact,
|
||||||
@@ -13,7 +14,7 @@ from deerflow.agents.memory.updater import (
|
|||||||
update_memory_fact,
|
update_memory_fact,
|
||||||
)
|
)
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["memory"])
|
router = APIRouter(prefix="/api", tags=["memory"])
|
||||||
|
|
||||||
@@ -114,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
|
|||||||
summary="Get Memory Data",
|
summary="Get Memory Data",
|
||||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||||
)
|
)
|
||||||
async def get_memory() -> MemoryResponse:
|
async def get_memory(request: Request) -> MemoryResponse:
|
||||||
"""Get the current global memory data.
|
"""Get the current global memory data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -148,8 +149,9 @@ async def get_memory() -> MemoryResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -159,7 +161,7 @@ async def get_memory() -> MemoryResponse:
|
|||||||
summary="Reload Memory Data",
|
summary="Reload Memory Data",
|
||||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||||
)
|
)
|
||||||
async def reload_memory() -> MemoryResponse:
|
async def reload_memory(request: Request) -> MemoryResponse:
|
||||||
"""Reload memory data from file.
|
"""Reload memory data from file.
|
||||||
|
|
||||||
This forces a reload of the memory data from the storage file,
|
This forces a reload of the memory data from the storage file,
|
||||||
@@ -168,8 +170,9 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
The reloaded memory data.
|
The reloaded memory data.
|
||||||
"""
|
"""
|
||||||
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -179,14 +182,15 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
summary="Clear All Memory Data",
|
summary="Clear All Memory Data",
|
||||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||||
)
|
)
|
||||||
async def clear_memory() -> MemoryResponse:
|
async def clear_memory(request: Request) -> MemoryResponse:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
try:
|
||||||
except OSError as exc:
|
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -196,21 +200,22 @@ async def clear_memory() -> MemoryResponse:
|
|||||||
summary="Create Memory Fact",
|
summary="Create Memory Fact",
|
||||||
description="Create a single saved memory fact manually.",
|
description="Create a single saved memory fact manually.",
|
||||||
)
|
)
|
||||||
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
|
async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse:
|
||||||
"""Create a single fact manually."""
|
"""Create a single fact manually."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = create_memory_fact(
|
try:
|
||||||
content=request.content,
|
memory_data = create_memory_fact(
|
||||||
category=request.category,
|
content=payload.content,
|
||||||
confidence=request.confidence,
|
category=payload.category,
|
||||||
user_id=get_effective_user_id(),
|
confidence=payload.confidence,
|
||||||
)
|
user_id=get_effective_user_id(),
|
||||||
except ValueError as exc:
|
)
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
except ValueError as exc:
|
||||||
except OSError as exc:
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -220,16 +225,17 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
summary="Delete Memory Fact",
|
summary="Delete Memory Fact",
|
||||||
description="Delete a single saved memory fact by its fact id.",
|
description="Delete a single saved memory fact by its fact id.",
|
||||||
)
|
)
|
||||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse:
|
||||||
"""Delete a single fact from memory by fact id."""
|
"""Delete a single fact from memory by fact id."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
try:
|
||||||
except KeyError as exc:
|
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
except KeyError as exc:
|
||||||
except OSError as exc:
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
@@ -239,24 +245,25 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
|||||||
summary="Patch Memory Fact",
|
summary="Patch Memory Fact",
|
||||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||||
)
|
)
|
||||||
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
|
async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse:
|
||||||
"""Partially update a single fact manually."""
|
"""Partially update a single fact manually."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = update_memory_fact(
|
try:
|
||||||
fact_id=fact_id,
|
memory_data = update_memory_fact(
|
||||||
content=request.content,
|
fact_id=fact_id,
|
||||||
category=request.category,
|
content=payload.content,
|
||||||
confidence=request.confidence,
|
category=payload.category,
|
||||||
user_id=get_effective_user_id(),
|
confidence=payload.confidence,
|
||||||
)
|
user_id=get_effective_user_id(),
|
||||||
except ValueError as exc:
|
)
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
except ValueError as exc:
|
||||||
except KeyError as exc:
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
except KeyError as exc:
|
||||||
except OSError as exc:
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -266,10 +273,11 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
summary="Export Memory Data",
|
summary="Export Memory Data",
|
||||||
description="Export the current global memory data as JSON for backup or transfer.",
|
description="Export the current global memory data as JSON for backup or transfer.",
|
||||||
)
|
)
|
||||||
async def export_memory() -> MemoryResponse:
|
async def export_memory(request: Request) -> MemoryResponse:
|
||||||
"""Export the current memory data."""
|
"""Export the current memory data."""
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -279,14 +287,15 @@ async def export_memory() -> MemoryResponse:
|
|||||||
summary="Import Memory Data",
|
summary="Import Memory Data",
|
||||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||||
)
|
)
|
||||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse:
|
||||||
"""Import and persist memory data."""
|
"""Import and persist memory data."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
|
try:
|
||||||
except OSError as exc:
|
memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -333,24 +342,25 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
summary="Get Memory Status",
|
summary="Get Memory Status",
|
||||||
description="Retrieve both memory configuration and current data in a single request.",
|
description="Retrieve both memory configuration and current data in a single request.",
|
||||||
)
|
)
|
||||||
async def get_memory_status() -> MemoryStatusResponse:
|
async def get_memory_status(request: Request) -> MemoryStatusResponse:
|
||||||
"""Get the memory system status including configuration and data.
|
"""Get the memory system status including configuration and data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Combined memory configuration and current data.
|
Combined memory configuration and current data.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
with bind_request_actor_context(request):
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
config = get_memory_config()
|
||||||
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
return MemoryStatusResponse(
|
return MemoryStatusResponse(
|
||||||
config=MemoryConfigResponse(
|
config=MemoryConfigResponse(
|
||||||
enabled=config.enabled,
|
enabled=config.enabled,
|
||||||
storage_path=config.storage_path,
|
storage_path=config.storage_path,
|
||||||
debounce_seconds=config.debounce_seconds,
|
debounce_seconds=config.debounce_seconds,
|
||||||
max_facts=config.max_facts,
|
max_facts=config.max_facts,
|
||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.deps import get_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["models"])
|
router = APIRouter(prefix="/api", tags=["models"])
|
||||||
|
|
||||||
@@ -18,17 +17,10 @@ class ModelResponse(BaseModel):
|
|||||||
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
|
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
|
||||||
|
|
||||||
|
|
||||||
class TokenUsageResponse(BaseModel):
|
|
||||||
"""Token usage display configuration."""
|
|
||||||
|
|
||||||
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsListResponse(BaseModel):
|
class ModelsListResponse(BaseModel):
|
||||||
"""Response model for listing all models."""
|
"""Response model for listing all models."""
|
||||||
|
|
||||||
models: list[ModelResponse]
|
models: list[ModelResponse]
|
||||||
token_usage: TokenUsageResponse
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -37,14 +29,14 @@ class ModelsListResponse(BaseModel):
|
|||||||
summary="List All Models",
|
summary="List All Models",
|
||||||
description="Retrieve a list of all available AI models configured in the system.",
|
description="Retrieve a list of all available AI models configured in the system.",
|
||||||
)
|
)
|
||||||
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
|
async def list_models() -> ModelsListResponse:
|
||||||
"""List all available models from configuration.
|
"""List all available models from configuration.
|
||||||
|
|
||||||
Returns model information suitable for frontend display,
|
Returns model information suitable for frontend display,
|
||||||
excluding sensitive fields like API keys and internal configuration.
|
excluding sensitive fields like API keys and internal configuration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of all configured models with their metadata and token usage display settings.
|
A list of all configured models with their metadata.
|
||||||
|
|
||||||
Example Response:
|
Example Response:
|
||||||
```json
|
```json
|
||||||
@@ -52,27 +44,21 @@ async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResp
|
|||||||
"models": [
|
"models": [
|
||||||
{
|
{
|
||||||
"name": "gpt-4",
|
"name": "gpt-4",
|
||||||
"model": "gpt-4",
|
|
||||||
"display_name": "GPT-4",
|
"display_name": "GPT-4",
|
||||||
"description": "OpenAI GPT-4 model",
|
"description": "OpenAI GPT-4 model",
|
||||||
"supports_thinking": false,
|
"supports_thinking": false
|
||||||
"supports_reasoning_effort": false
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "claude-3-opus",
|
"name": "claude-3-opus",
|
||||||
"model": "claude-3-opus",
|
|
||||||
"display_name": "Claude 3 Opus",
|
"display_name": "Claude 3 Opus",
|
||||||
"description": "Anthropic Claude 3 Opus model",
|
"description": "Anthropic Claude 3 Opus model",
|
||||||
"supports_thinking": true,
|
"supports_thinking": true
|
||||||
"supports_reasoning_effort": false
|
|
||||||
}
|
}
|
||||||
],
|
]
|
||||||
"token_usage": {
|
|
||||||
"enabled": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
config = get_app_config()
|
||||||
models = [
|
models = [
|
||||||
ModelResponse(
|
ModelResponse(
|
||||||
name=model.name,
|
name=model.name,
|
||||||
@@ -84,10 +70,7 @@ async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResp
|
|||||||
)
|
)
|
||||||
for model in config.models
|
for model in config.models
|
||||||
]
|
]
|
||||||
return ModelsListResponse(
|
return ModelsListResponse(models=models)
|
||||||
models=models,
|
|
||||||
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -96,7 +79,7 @@ async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResp
|
|||||||
summary="Get Model Details",
|
summary="Get Model Details",
|
||||||
description="Retrieve detailed information about a specific AI model by its name.",
|
description="Retrieve detailed information about a specific AI model by its name.",
|
||||||
)
|
)
|
||||||
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
|
async def get_model(model_name: str) -> ModelResponse:
|
||||||
"""Get a specific model by name.
|
"""Get a specific model by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -118,6 +101,7 @@ async def get_model(model_name: str, config: AppConfig = Depends(get_config)) ->
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
config = get_app_config()
|
||||||
model = config.get_model_config(model_name)
|
model = config.get_model_config(model_name)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|||||||
@@ -1,143 +0,0 @@
|
|||||||
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
|
|
||||||
|
|
||||||
These endpoints auto-create a temporary thread when no ``thread_id`` is
|
|
||||||
supplied in the request body. When a ``thread_id`` **is** provided, it
|
|
||||||
is reused so that conversation history is preserved across calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
|
||||||
from app.gateway.services import sse_consumer, start_run
|
|
||||||
from deerflow.runtime import serialize_channel_values
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_thread_id(body: RunCreateRequest) -> str:
|
|
||||||
"""Return the thread_id from the request body, or generate a new one."""
|
|
||||||
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
|
|
||||||
if thread_id:
|
|
||||||
return str(thread_id)
|
|
||||||
return str(uuid.uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stream")
|
|
||||||
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
|
||||||
"""Create a run and stream events via SSE.
|
|
||||||
|
|
||||||
If ``config.configurable.thread_id`` is provided, the run is created
|
|
||||||
on the given thread so that conversation history is preserved.
|
|
||||||
Otherwise a new temporary thread is created.
|
|
||||||
"""
|
|
||||||
thread_id = _resolve_thread_id(body)
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/wait", response_model=dict)
|
|
||||||
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|
||||||
"""Create a run and block until completion.
|
|
||||||
|
|
||||||
If ``config.configurable.thread_id`` is provided, the run is created
|
|
||||||
on the given thread so that conversation history is preserved.
|
|
||||||
Otherwise a new temporary thread is created.
|
|
||||||
"""
|
|
||||||
thread_id = _resolve_thread_id(body)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
if record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
if checkpoint_tuple is not None:
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
return serialize_channel_values(channel_values)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run-scoped read endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_run(run_id: str, request: Request) -> dict:
|
|
||||||
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{run_id}/messages")
|
|
||||||
@require_permission("runs", "read")
|
|
||||||
async def run_messages(
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
limit: int = Query(default=50, le=200, ge=1),
|
|
||||||
before_seq: int | None = Query(default=None),
|
|
||||||
after_seq: int | None = Query(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""Return paginated messages for a run (cursor-based).
|
|
||||||
|
|
||||||
Pagination:
|
|
||||||
- after_seq: messages with seq > after_seq (forward)
|
|
||||||
- before_seq: messages with seq < before_seq (backward)
|
|
||||||
- neither: latest messages
|
|
||||||
|
|
||||||
Response: { data: [...], has_more: bool }
|
|
||||||
"""
|
|
||||||
run = await _resolve_run(run_id, request)
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
rows = await event_store.list_messages_by_run(
|
|
||||||
run["thread_id"],
|
|
||||||
run_id,
|
|
||||||
limit=limit + 1,
|
|
||||||
before_seq=before_seq,
|
|
||||||
after_seq=after_seq,
|
|
||||||
)
|
|
||||||
has_more = len(rows) > limit
|
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{run_id}/feedback")
|
|
||||||
@require_permission("runs", "read")
|
|
||||||
async def run_feedback(run_id: str, request: Request) -> list[dict]:
|
|
||||||
"""Return all feedback for a run."""
|
|
||||||
run = await _resolve_run(run_id, request)
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.list_by_run(run["thread_id"], run_id)
|
|
||||||
@@ -1,20 +1,29 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.deps import get_config
|
|
||||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||||
from deerflow.skills import Skill
|
from deerflow.skills import Skill, load_skills
|
||||||
from deerflow.skills.installer import SkillAlreadyExistsError
|
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
|
||||||
|
from deerflow.skills.manager import (
|
||||||
|
append_history,
|
||||||
|
atomic_write,
|
||||||
|
custom_skill_exists,
|
||||||
|
ensure_custom_skill_is_editable,
|
||||||
|
get_custom_skill_dir,
|
||||||
|
get_custom_skill_file,
|
||||||
|
get_skill_history_file,
|
||||||
|
read_custom_skill_content,
|
||||||
|
read_history,
|
||||||
|
validate_skill_markdown_content,
|
||||||
|
)
|
||||||
from deerflow.skills.security_scanner import scan_skill_content
|
from deerflow.skills.security_scanner import scan_skill_content
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
|
||||||
from deerflow.skills.types import SKILL_MD_FILE, SkillCategory
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -27,7 +36,7 @@ class SkillResponse(BaseModel):
|
|||||||
name: str = Field(..., description="Name of the skill")
|
name: str = Field(..., description="Name of the skill")
|
||||||
description: str = Field(..., description="Description of what the skill does")
|
description: str = Field(..., description="Description of what the skill does")
|
||||||
license: str | None = Field(None, description="License information")
|
license: str | None = Field(None, description="License information")
|
||||||
category: SkillCategory = Field(..., description="Category of the skill (public or custom)")
|
category: str = Field(..., description="Category of the skill (public or custom)")
|
||||||
enabled: bool = Field(default=True, description="Whether this skill is enabled")
|
enabled: bool = Field(default=True, description="Whether this skill is enabled")
|
||||||
|
|
||||||
|
|
||||||
@@ -91,9 +100,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
|
|||||||
summary="List All Skills",
|
summary="List All Skills",
|
||||||
description="Retrieve a list of all available skills from both public and custom directories.",
|
description="Retrieve a list of all available skills from both public and custom directories.",
|
||||||
)
|
)
|
||||||
async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
|
async def list_skills() -> SkillsListResponse:
|
||||||
try:
|
try:
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
skills = load_skills(enabled_only=False)
|
||||||
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load skills: {e}", exc_info=True)
|
logger.error(f"Failed to load skills: {e}", exc_info=True)
|
||||||
@@ -106,10 +115,10 @@ async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResp
|
|||||||
summary="Install Skill",
|
summary="Install Skill",
|
||||||
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
|
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
|
||||||
)
|
)
|
||||||
async def install_skill(request: SkillInstallRequest, config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
|
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
|
||||||
try:
|
try:
|
||||||
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
|
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
|
||||||
result = await get_or_new_skill_storage(app_config=config).ainstall_skill_from_archive(skill_file_path)
|
result = install_skill_from_archive(skill_file_path)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return SkillInstallResponse(**result)
|
return SkillInstallResponse(**result)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -126,9 +135,9 @@ async def install_skill(request: SkillInstallRequest, config: AppConfig = Depend
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
|
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
|
||||||
async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
|
async def list_custom_skills() -> SkillsListResponse:
|
||||||
try:
|
try:
|
||||||
skills = [skill for skill in get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False) if skill.category == SkillCategory.CUSTOM]
|
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
|
||||||
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to list custom skills: %s", e, exc_info=True)
|
logger.error("Failed to list custom skills: %s", e, exc_info=True)
|
||||||
@@ -136,14 +145,13 @@ async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsL
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
|
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
|
||||||
async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
|
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
skills = load_skills(enabled_only=False)
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
|
||||||
skill = next((s for s in skills if s.name == skill_name and s.category == SkillCategory.CUSTOM), None)
|
|
||||||
if skill is None:
|
if skill is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||||
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=get_or_new_skill_storage(app_config=config).read_custom_skill(skill_name))
|
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -152,31 +160,30 @@ async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_conf
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
|
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
|
||||||
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
|
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
ensure_custom_skill_is_editable(skill_name)
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
validate_skill_markdown_content(skill_name, request.content)
|
||||||
storage.ensure_custom_skill_is_editable(skill_name)
|
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
|
||||||
storage.validate_skill_markdown_content(skill_name, request.content)
|
|
||||||
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
|
|
||||||
if scan.decision == "block":
|
if scan.decision == "block":
|
||||||
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
|
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
|
||||||
prev_content = storage.read_custom_skill(skill_name)
|
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
|
||||||
storage.write_custom_skill(skill_name, SKILL_MD_FILE, request.content)
|
prev_content = skill_file.read_text(encoding="utf-8")
|
||||||
storage.append_history(
|
atomic_write(skill_file, request.content)
|
||||||
|
append_history(
|
||||||
skill_name,
|
skill_name,
|
||||||
{
|
{
|
||||||
"action": "human_edit",
|
"action": "human_edit",
|
||||||
"author": "human",
|
"author": "human",
|
||||||
"thread_id": None,
|
"thread_id": None,
|
||||||
"file_path": SKILL_MD_FILE,
|
"file_path": "SKILL.md",
|
||||||
"prev_content": prev_content,
|
"prev_content": prev_content,
|
||||||
"new_content": request.content,
|
"new_content": request.content,
|
||||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return await get_custom_skill(skill_name, config)
|
return await get_custom_skill(skill_name)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -189,22 +196,24 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
|
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
|
||||||
async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
|
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
ensure_custom_skill_is_editable(skill_name)
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
skill_dir = get_custom_skill_dir(skill_name)
|
||||||
storage.delete_custom_skill(
|
prev_content = read_custom_skill_content(skill_name)
|
||||||
|
append_history(
|
||||||
skill_name,
|
skill_name,
|
||||||
history_meta={
|
{
|
||||||
"action": "human_delete",
|
"action": "human_delete",
|
||||||
"author": "human",
|
"author": "human",
|
||||||
"thread_id": None,
|
"thread_id": None,
|
||||||
"file_path": SKILL_MD_FILE,
|
"file_path": "SKILL.md",
|
||||||
"prev_content": None,
|
"prev_content": prev_content,
|
||||||
"new_content": None,
|
"new_content": None,
|
||||||
"scanner": {"decision": "allow", "reason": "Deletion requested."},
|
"scanner": {"decision": "allow", "reason": "Deletion requested."},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
shutil.rmtree(skill_dir)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -217,13 +226,11 @@ async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_c
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
|
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
|
||||||
async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
|
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
|
||||||
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
|
|
||||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||||
return CustomSkillHistoryResponse(history=storage.read_history(skill_name))
|
return CustomSkillHistoryResponse(history=read_history(skill_name))
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -232,39 +239,38 @@ async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
|
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
|
||||||
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
|
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
|
||||||
try:
|
try:
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
|
||||||
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
|
|
||||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||||
history = storage.read_history(skill_name)
|
history = read_history(skill_name)
|
||||||
if not history:
|
if not history:
|
||||||
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
|
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
|
||||||
record = history[request.history_index]
|
record = history[request.history_index]
|
||||||
target_content = record.get("prev_content")
|
target_content = record.get("prev_content")
|
||||||
if target_content is None:
|
if target_content is None:
|
||||||
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
|
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
|
||||||
storage.validate_skill_markdown_content(skill_name, target_content)
|
validate_skill_markdown_content(skill_name, target_content)
|
||||||
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
|
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
|
||||||
skill_file = storage.get_custom_skill_file(skill_name)
|
skill_file = get_custom_skill_file(skill_name)
|
||||||
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
|
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
|
||||||
history_entry = {
|
history_entry = {
|
||||||
"action": "rollback",
|
"action": "rollback",
|
||||||
"author": "human",
|
"author": "human",
|
||||||
"thread_id": None,
|
"thread_id": None,
|
||||||
"file_path": SKILL_MD_FILE,
|
"file_path": "SKILL.md",
|
||||||
"prev_content": current_content,
|
"prev_content": current_content,
|
||||||
"new_content": target_content,
|
"new_content": target_content,
|
||||||
"rollback_from_ts": record.get("ts"),
|
"rollback_from_ts": record.get("ts"),
|
||||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||||
}
|
}
|
||||||
if scan.decision == "block":
|
if scan.decision == "block":
|
||||||
storage.append_history(skill_name, history_entry)
|
append_history(skill_name, history_entry)
|
||||||
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
|
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
|
||||||
storage.write_custom_skill(skill_name, SKILL_MD_FILE, target_content)
|
atomic_write(skill_file, target_content)
|
||||||
storage.append_history(skill_name, history_entry)
|
append_history(skill_name, history_entry)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return await get_custom_skill(skill_name, config)
|
return await get_custom_skill(skill_name)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except IndexError:
|
except IndexError:
|
||||||
@@ -284,10 +290,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest,
|
|||||||
summary="Get Skill Details",
|
summary="Get Skill Details",
|
||||||
description="Retrieve detailed information about a specific skill by its name.",
|
description="Retrieve detailed information about a specific skill by its name.",
|
||||||
)
|
)
|
||||||
async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
|
async def get_skill(skill_name: str) -> SkillResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
skills = load_skills(enabled_only=False)
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
|
||||||
skill = next((s for s in skills if s.name == skill_name), None)
|
skill = next((s for s in skills if s.name == skill_name), None)
|
||||||
|
|
||||||
if skill is None:
|
if skill is None:
|
||||||
@@ -307,10 +312,9 @@ async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) ->
|
|||||||
summary="Update Skill",
|
summary="Update Skill",
|
||||||
description="Update a skill's enabled status by modifying the extensions_config.json file.",
|
description="Update a skill's enabled status by modifying the extensions_config.json file.",
|
||||||
)
|
)
|
||||||
async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
|
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
skills = load_skills(enabled_only=False)
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
|
||||||
skill = next((s for s in skills if s.name == skill_name), None)
|
skill = next((s for s in skills if s.name == skill_name), None)
|
||||||
|
|
||||||
if skill is None:
|
if skill is None:
|
||||||
@@ -336,7 +340,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest, config: App
|
|||||||
reload_extensions_config()
|
reload_extensions_config()
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
|
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
skills = load_skills(enabled_only=False)
|
||||||
updated_skill = next((s for s in skills if s.name == skill_name), None)
|
updated_skill = next((s for s in skills if s.name == skill_name), None)
|
||||||
|
|
||||||
if updated_skill is None:
|
if updated_skill is None:
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Request
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_config
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -101,13 +98,7 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
|||||||
summary="Generate Follow-up Questions",
|
summary="Generate Follow-up Questions",
|
||||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||||
)
|
)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
|
||||||
async def generate_suggestions(
|
|
||||||
thread_id: str,
|
|
||||||
body: SuggestionsRequest,
|
|
||||||
request: Request,
|
|
||||||
config: AppConfig = Depends(get_config),
|
|
||||||
) -> SuggestionsResponse:
|
|
||||||
if not body.messages:
|
if not body.messages:
|
||||||
return SuggestionsResponse(suggestions=[])
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
@@ -129,8 +120,8 @@ async def generate_suggestions(
|
|||||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
|
model = create_chat_model(name=body.model_name, thinking_enabled=False)
|
||||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
|
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||||
raw = _extract_response_text(response.content)
|
raw = _extract_response_text(response.content)
|
||||||
suggestions = _parse_json_string_list(raw) or []
|
suggestions = _parse_json_string_list(raw) or []
|
||||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||||
|
|||||||
@@ -1,409 +0,0 @@
|
|||||||
"""Runs endpoints — create, stream, wait, cancel.
|
|
||||||
|
|
||||||
Implements the LangGraph Platform runs API on top of
|
|
||||||
:class:`deerflow.agents.runs.RunManager` and
|
|
||||||
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
|
||||||
|
|
||||||
SSE format is aligned with the LangGraph Platform protocol so that
|
|
||||||
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
|
||||||
works without modification.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
|
||||||
from fastapi.responses import Response, StreamingResponse
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
||||||
from app.gateway.services import sse_consumer, start_run
|
|
||||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Request / response models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class RunCreateRequest(BaseModel):
|
|
||||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
|
||||||
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
|
||||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
|
||||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
|
||||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
|
||||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
|
||||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
|
||||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
|
||||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
|
||||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
|
||||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
|
||||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
|
||||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
|
||||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
|
||||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
|
||||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
|
||||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
|
||||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
|
||||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
|
||||||
|
|
||||||
|
|
||||||
class RunResponse(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
thread_id: str
|
|
||||||
assistant_id: str | None = None
|
|
||||||
status: str
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
kwargs: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
multitask_strategy: str = "reject"
|
|
||||||
created_at: str = ""
|
|
||||||
updated_at: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
|
||||||
tokens: int = 0
|
|
||||||
runs: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageCallerBreakdown(BaseModel):
|
|
||||||
lead_agent: int = 0
|
|
||||||
subagent: int = 0
|
|
||||||
middleware: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageResponse(BaseModel):
|
|
||||||
thread_id: str
|
|
||||||
total_tokens: int = 0
|
|
||||||
total_input_tokens: int = 0
|
|
||||||
total_output_tokens: int = 0
|
|
||||||
total_runs: int = 0
|
|
||||||
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
|
|
||||||
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _cancel_conflict_detail(run_id: str, record: RunRecord) -> str:
|
|
||||||
if record.status in (RunStatus.pending, RunStatus.running):
|
|
||||||
return f"Run {run_id} is not active on this worker and cannot be cancelled"
|
|
||||||
return f"Run {run_id} is not cancellable (status: {record.status.value})"
|
|
||||||
|
|
||||||
|
|
||||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
|
||||||
return RunResponse(
|
|
||||||
run_id=record.run_id,
|
|
||||||
thread_id=record.thread_id,
|
|
||||||
assistant_id=record.assistant_id,
|
|
||||||
status=record.status.value,
|
|
||||||
metadata=record.metadata,
|
|
||||||
kwargs=record.kwargs,
|
|
||||||
multitask_strategy=record.multitask_strategy,
|
|
||||||
created_at=record.created_at,
|
|
||||||
updated_at=record.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
|
||||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
|
||||||
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
|
|
||||||
"""Create a background run (returns immediately)."""
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
return _record_to_response(record)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/stream")
|
|
||||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
|
||||||
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
|
|
||||||
"""Create a run and stream events via SSE.
|
|
||||||
|
|
||||||
The response includes a ``Content-Location`` header with the run's
|
|
||||||
resource URL, matching the LangGraph Platform protocol. The
|
|
||||||
``useStream`` React hook uses this to extract run metadata.
|
|
||||||
"""
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
# LangGraph Platform includes run metadata in this header.
|
|
||||||
# The SDK uses a greedy regex to extract the run id from this path,
|
|
||||||
# so it must point at the canonical run resource without extra suffixes.
|
|
||||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/wait", response_model=dict)
|
|
||||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
|
||||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
|
||||||
"""Create a run and block until it completes, returning the final state."""
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
if record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
if checkpoint_tuple is not None:
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
return serialize_channel_values(channel_values)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
|
||||||
"""List all runs for a thread."""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
records = await run_mgr.list_by_thread(thread_id, user_id=user_id)
|
|
||||||
return [_record_to_response(r) for r in records]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    """Return one run, looked up by id on behalf of the current user.

    Raises:
        HTTPException: 404 when the run is unknown or belongs to a different
            thread than the one named in the URL.
    """
    manager = get_run_manager(request)
    caller = await get_current_user(request)
    run = await manager.get(run_id, user_id=caller)
    if run is not None and run.thread_id == thread_id:
        return _record_to_response(run)
    raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
    thread_id: str,
    run_id: str,
    request: Request,
    wait: bool = Query(default=False, description="Block until run completes after cancel"),
    action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
    """Cancel a running or pending run.

    Query parameters:
        action: ``interrupt`` stops execution and keeps the current checkpoint
            (the run can be resumed); ``rollback`` stops execution and reverts
            to the pre-run checkpoint state.
        wait: when true, block until the run has fully stopped and answer 204;
            otherwise answer 202 straight away.

    Raises:
        HTTPException: 404 when the run does not exist on this thread, 409
            when the run manager refuses the cancellation.
    """
    manager = get_run_manager(request)
    run = await manager.get(run_id)
    if run is None or run.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    if not await manager.cancel(run_id, action=action):
        raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, run))

    # Fire-and-forget path: either the caller did not ask to wait, or there
    # is no local task object to wait on.
    if not wait or run.task is None:
        return Response(status_code=202)

    try:
        await run.task
    except asyncio.CancelledError:
        # Expected outcome of a successful cancel.
        pass
    return Response(status_code=204)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
    """Attach to the server-sent-event stream of an existing run.

    Raises:
        HTTPException: 404 when the run does not exist on this thread, 409
            when the record is ``store_only`` (not active on this worker) and
            therefore has nothing to stream.
    """
    manager = get_run_manager(request)
    run = await manager.get(run_id)
    if run is None or run.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    if run.store_only:
        raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")

    bridge = get_stream_bridge(request)
    # Headers that keep intermediaries from caching or buffering the stream.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        sse_consumer(bridge, run, request, manager),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
    wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
    """Join an existing run's SSE stream (GET), or cancel-then-stream (POST).

    The LangGraph SDK's ``joinStream`` and the ``useStream`` stop button both
    ``POST`` to this endpoint. When ``action`` is ``interrupt`` or
    ``rollback`` the run is cancelled first; unless ``wait`` is set, the
    response then streams whatever buffered events remain so the client sees
    a clean shutdown.

    Raises:
        HTTPException: 404 for an unknown run, 409 when a ``store_only`` run
            is asked to stream without a cancel action, or when the run
            manager refuses the cancellation.
    """
    manager = get_run_manager(request)
    run = await manager.get(run_id)
    if run is None or run.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    cancel_requested = action is not None
    if run.store_only and not cancel_requested:
        raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")

    # Stop-button / interrupt flow.
    if cancel_requested:
        if not await manager.cancel(run_id, action=action):
            raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, run))
        if wait and run.task is not None:
            try:
                await run.task
            except (asyncio.CancelledError, Exception):
                # Any outcome of the task counts as "stopped" here.
                pass
            return Response(status_code=204)

    bridge = get_stream_bridge(request)
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        sse_consumer(bridge, run, request, manager),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Messages / Events / Token usage endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
    thread_id: str,
    request: Request,
    limit: int = Query(default=50, ge=1, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs), with feedback attached.

    Every returned message carries a ``feedback`` key. It is populated only
    on the last ``ai_message`` of each run (within the returned page) and
    only when the current user has feedback stored for that run; everywhere
    else it is ``None``.

    Args:
        thread_id: Thread whose messages are listed.
        request: Incoming request, used to resolve the stores and the user.
        limit: Page size, between 1 and 200. The lower bound matches the
            per-run messages endpoint and keeps zero/negative values from
            reaching the event store.
        before_seq: Passed through to the event store as a sequence cursor.
        after_seq: Passed through to the event store as a sequence cursor.
    """
    event_store = get_run_event_store(request)
    messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)

    feedback_repo = get_feedback_repo(request)
    user_id = await get_current_user(request)
    feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)

    # run_id -> index of that run's last ai_message; later entries overwrite
    # earlier ones, so each run ends up pointing at its final AI message.
    last_ai_per_run: dict[str, int] = {msg["run_id"]: i for i, msg in enumerate(messages) if msg.get("event_type") == "ai_message"}
    last_ai_indices = set(last_ai_per_run.values())

    for i, msg in enumerate(messages):
        fb = feedback_map.get(msg["run_id"]) if i in last_ai_indices else None
        msg["feedback"] = (
            {
                "feedback_id": fb["feedback_id"],
                "rating": fb["rating"],
                "comment": fb.get("comment"),
            }
            if fb
            else None
        )

    return messages
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
    thread_id: str,
    run_id: str,
    request: Request,
    limit: int = Query(default=50, le=200, ge=1),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> dict:
    """Return one page of messages for a specific run.

    Response shape: ``{ data: [...], has_more: bool }``.
    """
    store = get_run_event_store(request)
    # Ask for one row more than the page size; the surplus row only signals
    # that another page exists and is never returned.
    fetched = await store.list_messages_by_run(
        thread_id,
        run_id,
        limit=limit + 1,
        before_seq=before_seq,
        after_seq=after_seq,
    )
    if len(fetched) > limit:
        return {"data": fetched[:limit], "has_more": True}
    return {"data": fetched, "has_more": False}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    limit: int = Query(default=500, ge=1, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit).

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are listed.
        request: Incoming request, used to resolve the event store.
        event_types: Optional comma-separated list of event type names.
            Whitespace around each name is ignored and empty entries are
            dropped, so ``"a, b,"`` filters on ``["a", "b"]``. A value with
            no usable names means "no filter".
        limit: Maximum number of events, between 1 and 2000.
    """
    event_store = get_run_event_store(request)
    types: list[str] | None = None
    if event_types:
        # Normalise user input; fall back to "no filter" when nothing is left
        # rather than passing an empty filter list to the store.
        types = [name.strip() for name in event_types.split(",") if name.strip()] or None
    return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
    """Aggregate token usage for a thread and return it as one response model."""
    totals = await get_run_store(request).aggregate_tokens_by_thread(thread_id)
    # The aggregate's keys are expanded directly into the response fields.
    return ThreadTokenUsageResponse(thread_id=thread_id, **totals)
|
|
||||||
@@ -1,648 +0,0 @@
|
|||||||
"""Thread CRUD, state, and history endpoints.
|
|
||||||
|
|
||||||
Combines the existing thread-local filesystem cleanup with LangGraph
|
|
||||||
Platform-compatible thread management backed by the checkpointer.
|
|
||||||
|
|
||||||
Channel values returned in state responses are serialized through
|
|
||||||
:func:`deerflow.runtime.serialization.serialize_channel_values` to
|
|
||||||
ensure LangChain message objects are converted to JSON-safe dicts
|
|
||||||
matching the LangGraph Platform wire format expected by the
|
|
||||||
``useStream`` React hook.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_checkpointer
|
|
||||||
from app.gateway.utils import sanitize_log_param
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
|
||||||
from deerflow.runtime import serialize_channel_values
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.utils.time import coerce_iso, now_iso
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
|
||||||
|
|
||||||
|
|
||||||
# Metadata keys that the server controls; clients are not allowed to set
|
|
||||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
|
||||||
# inbound model below so a malicious client cannot reflect a forged
|
|
||||||
# owner identity through the API surface. Defense-in-depth — the
|
|
||||||
# row-level invariant is still ``threads_meta.user_id`` populated from
|
|
||||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
|
||||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
|
|
||||||
"""Return ``metadata`` with server-controlled keys removed."""
|
|
||||||
if not metadata:
|
|
||||||
return metadata or {}
|
|
||||||
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Response / request models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadDeleteResponse(BaseModel):
    """Outcome of a thread cleanup request."""

    # Whether the cleanup call completed.
    success: bool
    # Human-readable description of what was (or was not) removed.
    message: str
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadResponse(BaseModel):
    """Wire representation of a single thread."""

    thread_id: str = Field(description="Unique thread identifier")
    status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
    # Timestamps default to the empty string when the backing record has none.
    created_at: str = Field(default="", description="ISO timestamp")
    updated_at: str = Field(default="", description="ISO timestamp")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
    values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
    interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadCreateRequest(BaseModel):
    """Payload accepted when creating a thread."""

    thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
    assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")

    @field_validator("metadata")
    @classmethod
    def _strip_reserved(cls, v: dict[str, Any]) -> dict[str, Any]:
        """Drop server-reserved keys from client-supplied metadata."""
        return _strip_reserved_metadata(v)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadSearchRequest(BaseModel):
    """Payload accepted by the thread search endpoint."""

    metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
    limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
    offset: int = Field(default=0, ge=0, description="Pagination offset")
    status: str | None = Field(default=None, description="Filter by thread status")

    @field_validator("metadata")
    @classmethod
    def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]:
        """Refuse metadata filters that the SQL backend cannot compile.

        Validating here keeps SQL and memory backends behaving alike. The
        checks themselves live in ``deerflow.persistence.json_compat``.

        Raises:
            ValueError: listing every offending entry, when at least one key
                or value is rejected by the shared validators.
        """
        if not v:
            return v

        # Imported lazily, as in the rest of this module's optional paths.
        from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value

        problems: list[str] = []
        for name, candidate in v.items():
            # The key is checked first; a bad key is reported without
            # looking at its value.
            if not validate_metadata_filter_key(name):
                problems.append(f"{name!r} (unsafe key)")
                continue
            if not validate_metadata_filter_value(candidate):
                problems.append(f"{name!r} (unsupported value type {type(candidate).__name__})")

        if problems:
            raise ValueError(f"Invalid metadata filter entries: {', '.join(problems)}")
        return v
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateResponse(BaseModel):
    """Snapshot of a thread's state as read from a checkpoint."""

    values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
    next: list[str] = Field(default_factory=list, description="Next tasks to execute")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
    checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
    # Identifiers locating this checkpoint and the one it descends from.
    checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
    parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
    created_at: str | None = Field(default=None, description="Checkpoint timestamp")
    tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadPatchRequest(BaseModel):
    """Payload accepted when patching a thread's metadata."""

    metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")

    @field_validator("metadata")
    @classmethod
    def _strip_reserved(cls, v: dict[str, Any]) -> dict[str, Any]:
        """Drop server-reserved keys from client-supplied metadata."""
        return _strip_reserved_metadata(v)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateUpdateRequest(BaseModel):
    """Payload for updating thread state (human-in-the-loop resume)."""

    # Every field is optional and defaults to ``None``.
    values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
    checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
    checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
    as_node: str | None = Field(default=None, description="Node identity for the update")
|
|
||||||
|
|
||||||
|
|
||||||
class HistoryEntry(BaseModel):
    """One checkpoint in a thread's history listing."""

    # Only the checkpoint id is mandatory; the rest fall back to empty values.
    checkpoint_id: str
    parent_checkpoint_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    values: dict[str, Any] = Field(default_factory=dict)
    created_at: str | None = None
    next: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadHistoryRequest(BaseModel):
    """Payload accepted by the checkpoint history endpoint."""

    # Page size is bounded to 1..100 by the field constraints.
    limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
    before: str | None = Field(default=None, description="Cursor for pagination")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
    """Remove the on-disk directory that belongs to a thread.

    Args:
        thread_id: Thread whose directory is removed.
        paths: Path manager to use; the process-wide one from ``get_paths()``
            is used when this is omitted.
        user_id: Forwarded to the path manager's ``delete_thread_dir``.

    Returns:
        A successful ``ThreadDeleteResponse`` both when data was deleted and
        when there was nothing on disk to delete.

    Raises:
        HTTPException: 422 when the path manager raises ``ValueError``,
            500 for any other unexpected failure.
    """
    manager = paths if paths else get_paths()
    try:
        manager.delete_thread_dir(thread_id, user_id=user_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except FileNotFoundError:
        # Missing data is fine: the thread may never have written to disk.
        logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
        return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
    except Exception as exc:
        logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
    else:
        logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
        return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
|
||||||
|
|
||||||
|
|
||||||
def _derive_thread_status(checkpoint_tuple) -> str:
    """Work out a thread's status string from its checkpoint tuple.

    Returns ``"error"`` when any pending write targets the ``__error__``
    channel, ``"interrupted"`` when the tuple carries tasks, and ``"idle"``
    otherwise (including when there is no checkpoint at all).
    """
    if checkpoint_tuple is None:
        return "idle"

    writes = getattr(checkpoint_tuple, "pending_writes", None) or []
    # The second element of a pending write is compared against the error
    # channel name; shorter entries are ignored.
    if any(len(entry) >= 2 and entry[1] == "__error__" for entry in writes):
        return "error"

    # Outstanding tasks are taken to mean the thread is paused on an interrupt.
    return "interrupted" if getattr(checkpoint_tuple, "tasks", None) else "idle"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
    """Delete local persisted filesystem data for a thread.

    Cleans DeerFlow-managed thread directories, removes checkpoint data,
    and removes the thread_meta row from the configured ThreadMetaStore
    (sqlite or memory).

    The filesystem cleanup decides the response (and may raise). The
    checkpoint and thread_meta removals are best-effort: a failure there is
    logged at debug level, with the traceback, and does not fail the request.
    """
    from app.gateway.deps import get_thread_store

    # Clean local filesystem
    response = _delete_thread_data(thread_id, user_id=get_effective_user_id())

    # Remove checkpoints (best-effort). Not every checkpointer implements
    # ``adelete_thread``, hence the capability check.
    checkpointer = getattr(request.app.state, "checkpointer", None)
    if checkpointer is not None and hasattr(checkpointer, "adelete_thread"):
        try:
            await checkpointer.adelete_thread(thread_id)
        except Exception:
            logger.debug(
                "Could not delete checkpoints for thread %s (not critical)",
                sanitize_log_param(thread_id),
                exc_info=True,
            )

    # Remove thread_meta row (best-effort) — required for sqlite backend
    # so the deleted thread no longer appears in /threads/search.
    try:
        thread_store = get_thread_store(request)
        await thread_store.delete(thread_id)
    except Exception:
        logger.debug(
            "Could not delete thread_meta for %s (not critical)",
            sanitize_log_param(thread_id),
            exc_info=True,
        )

    return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ThreadResponse)
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
    """Create a new thread.

    Writes a thread_meta record (so the thread appears in /threads/search)
    and an empty checkpoint (so state endpoints work immediately).
    Idempotent: returns the existing record when ``thread_id`` already exists.

    Args:
        body: Creation payload; ``thread_id`` is generated when omitted.
        request: Incoming request, used to resolve the stores.

    Raises:
        HTTPException: 500 when either the thread_meta row or the initial
            checkpoint cannot be written.
    """
    from app.gateway.deps import get_thread_store

    checkpointer = get_checkpointer(request)
    thread_store = get_thread_store(request)
    thread_id = body.thread_id or str(uuid.uuid4())
    now = now_iso()
    # ``body.metadata`` is already stripped of server-reserved keys by
    # ``ThreadCreateRequest._strip_reserved`` — see the model definition.

    # Idempotency: return existing record when already present
    existing_record = await thread_store.get(thread_id)
    if existing_record is not None:
        return ThreadResponse(
            thread_id=thread_id,
            status=existing_record.get("status", "idle"),
            created_at=coerce_iso(existing_record.get("created_at", "")),
            updated_at=coerce_iso(existing_record.get("updated_at", "")),
            metadata=existing_record.get("metadata", {}),
        )

    # Write thread_meta so the thread appears in /threads/search immediately
    try:
        await thread_store.create(
            thread_id,
            assistant_id=body.assistant_id,
            metadata=body.metadata,
        )
    except Exception as exc:
        logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to create thread") from exc

    # Write an empty checkpoint so state endpoints work immediately
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    try:
        # Client metadata goes first so the checkpoint bookkeeping keys below
        # always win; a client-supplied "step" or "source" must not be able
        # to overwrite them.
        ckpt_metadata = {
            **body.metadata,
            "step": -1,
            "source": "input",
            "writes": None,
            "parents": {},
            "created_at": now,
        }
        await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
    except Exception as exc:
        logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to create thread") from exc

    logger.info("Thread created: %s", sanitize_log_param(thread_id))
    return ThreadResponse(
        thread_id=thread_id,
        status="idle",
        created_at=now,
        updated_at=now,
        metadata=body.metadata,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
    """Search and list threads.

    The query is handed to the configured ThreadMetaStore implementation
    (SQL-backed for sqlite/postgres, Store-backed for memory mode).

    Raises:
        HTTPException: 400 when the store rejects the metadata filter.
    """
    from app.gateway.deps import get_thread_store
    from deerflow.persistence.thread_meta import InvalidMetadataFilterError

    store = get_thread_store(request)
    try:
        # An empty metadata filter is passed as ``None`` (no filtering).
        found = await store.search(
            metadata=body.metadata or None,
            status=body.status,
            limit=body.limit,
            offset=body.offset,
        )
    except InvalidMetadataFilterError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    results: list[ThreadResponse] = []
    for row in found:
        # The stored display name, when present, is exposed as the title.
        title_values = {"title": row["display_name"]} if row.get("display_name") else {}
        results.append(
            ThreadResponse(
                thread_id=row["thread_id"],
                status=row.get("status", "idle"),
                # ``coerce_iso`` heals legacy unix-second values that
                # ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
                # SQL-backed rows already arrive as ISO strings and pass through.
                created_at=coerce_iso(row.get("created_at", "")),
                updated_at=coerce_iso(row.get("updated_at", "")),
                metadata=row.get("metadata", {}),
                values=title_values,
                interrupts={},
            )
        )
    return results
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
    """Merge metadata into a thread record.

    Args:
        thread_id: Thread to update.
        body: Patch payload; its metadata has already had server-reserved
            keys removed by ``ThreadPatchRequest._strip_reserved``.
        request: Incoming request, used to resolve the thread store.

    Raises:
        HTTPException: 404 when the thread has no record, 500 when the store
            fails to apply the update (the original error is chained).
    """
    from app.gateway.deps import get_thread_store

    thread_store = get_thread_store(request)
    record = await thread_store.get(thread_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    try:
        await thread_store.update_metadata(thread_id, body.metadata)
    except Exception as exc:
        logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to update thread") from exc

    # Re-read to get the merged metadata + refreshed updated_at; keep the
    # pre-update record if the row is no longer returned by the store.
    record = await thread_store.get(thread_id) or record
    return ThreadResponse(
        thread_id=thread_id,
        status=record.get("status", "idle"),
        created_at=coerce_iso(record.get("created_at", "")),
        updated_at=coerce_iso(record.get("updated_at", "")),
        metadata=record.get("metadata", {}),
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
    """Get thread info.

    Reads metadata from the ThreadMetaStore and derives the accurate
    execution status from the checkpointer. Falls back to the checkpointer
    alone for threads that pre-date ThreadMetaStore adoption (backward compat).

    Raises:
        HTTPException: 404 when neither the thread store nor the checkpointer
            knows the thread, 500 when the checkpoint lookup fails (the
            original error is chained).
    """
    from app.gateway.deps import get_thread_store

    thread_store = get_thread_store(request)
    checkpointer = get_checkpointer(request)

    record: dict | None = await thread_store.get(thread_id)

    # Derive accurate status from the checkpointer
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    try:
        checkpoint_tuple = await checkpointer.aget_tuple(config)
    except Exception as exc:
        logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread") from exc

    if record is None:
        if checkpoint_tuple is None:
            raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

        # The thread exists in the checkpointer but not in thread_meta (e.g.
        # legacy data created before thread_meta adoption): synthesize a
        # minimal record from the checkpoint metadata, leaving out the
        # checkpoint bookkeeping keys.
        ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
        record = {
            "thread_id": thread_id,
            "status": "idle",
            "created_at": coerce_iso(ckpt_meta.get("created_at", "")),
            "updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
            "metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
        }

    if checkpoint_tuple is None:
        # No checkpoint: trust the stored status and report no values.
        status = record.get("status", "idle")
        channel_values = {}
    else:
        status = _derive_thread_status(checkpoint_tuple)
        checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
        channel_values = checkpoint.get("channel_values", {})

    return ThreadResponse(
        thread_id=thread_id,
        status=status,
        created_at=coerce_iso(record.get("created_at", "")),
        updated_at=coerce_iso(record.get("updated_at", "")),
        metadata=record.get("metadata", {}),
        values=serialize_channel_values(channel_values),
    )
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
    """Return the newest state snapshot stored for ``thread_id``.

    The latest checkpoint in the root namespace (``checkpoint_ns=""``) is
    fetched from the checkpointer. Its channel values are passed through
    ``serialize_channel_values`` so that LangChain message objects are
    returned as JSON-safe dicts.

    Raises:
        HTTPException: 500 when the checkpointer lookup fails, 404 when the
            thread has no checkpoint.
    """
    saver = get_checkpointer(request)
    lookup = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

    try:
        checkpoint_tuple = await saver.aget_tuple(lookup)
    except Exception:
        logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread state")

    if checkpoint_tuple is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    def _checkpoint_id_of(cfg):
        # A missing or empty config carries no checkpoint id.
        if not cfg:
            return None
        return cfg.get("configurable", {}).get("checkpoint_id")

    snapshot = getattr(checkpoint_tuple, "checkpoint", {}) or {}
    metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
    checkpoint_id = _checkpoint_id_of(getattr(checkpoint_tuple, "config", {}))
    parent_checkpoint_id = _checkpoint_id_of(getattr(checkpoint_tuple, "parent_config", None))

    # Pending tasks: names feed ``next``; id/name pairs feed ``tasks``.
    pending = getattr(checkpoint_tuple, "tasks", []) or []
    next_names = [task.name for task in pending if hasattr(task, "name")]
    task_summaries = [{"id": getattr(task, "id", ""), "name": getattr(task, "name", "")} for task in pending]

    serialized_values = serialize_channel_values(snapshot.get("channel_values", {}))
    created_at = coerce_iso(metadata.get("created_at", ""))

    return ThreadStateResponse(
        values=serialized_values,
        next=next_names,
        metadata=metadata,
        checkpoint={"id": checkpoint_id, "ts": created_at},
        checkpoint_id=checkpoint_id,
        parent_checkpoint_id=parent_checkpoint_id,
        created_at=created_at,
        tasks=task_summaries,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
    """Update thread state (e.g. for human-in-the-loop resume or title rename).

    Reads the latest checkpoint (or the one named by ``body.checkpoint_id``),
    merges ``body.values`` into its channel values and writes the result back
    through ``checkpointer.aput``. A non-empty ``title`` in ``body.values`` is
    additionally pushed to the thread store via ``update_display_name`` so
    that ``/threads/search`` reflects the rename; a failure of that sync is
    logged and otherwise ignored.

    Raises:
        HTTPException: 500 when reading or writing the checkpoint fails,
            404 when no checkpoint exists for the thread.
    """
    from app.gateway.deps import get_thread_store

    checkpointer = get_checkpointer(request)
    thread_store = get_thread_store(request)

    # checkpoint_ns is always sent ("" = root graph namespace). checkpoint_id
    # is only added when the caller asked for a specific checkpoint; without
    # it the latest checkpoint of the thread is read.
    read_config: dict[str, Any] = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "",
        }
    }
    if body.checkpoint_id:
        read_config["configurable"]["checkpoint_id"] = body.checkpoint_id

    try:
        checkpoint_tuple = await checkpointer.aget_tuple(read_config)
    except Exception:
        logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread state")

    if checkpoint_tuple is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    # Shallow copies, so the top-level dicts returned by the checkpointer are
    # not mutated in place.
    checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
    metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
    channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))

    if body.values:
        channel_values.update(body.values)

    checkpoint["channel_values"] = channel_values
    metadata["updated_at"] = now_iso()

    if body.as_node:
        # Record the write as an update attributed to the given node.
        metadata["source"] = "update"
        metadata["step"] = metadata.get("step", 0) + 1
        metadata["writes"] = {body.as_node: body.values}

    # The write config deliberately carries no checkpoint_id; the id of the
    # stored snapshot is read back from aput's return value below.
    write_config: dict[str, Any] = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "",
        }
    }
    try:
        new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
    except Exception:
        logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to update thread state")

    new_checkpoint_id: str | None = None
    if isinstance(new_config, dict):
        new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")

    # Sync title changes through the thread store so /threads/search reflects
    # them. Empty strings and None are skipped.
    if body.values and "title" in body.values:
        new_title = body.values["title"]
        if new_title:
            try:
                await thread_store.update_display_name(thread_id, new_title)
            except Exception:
                # Best-effort: the checkpoint write already succeeded, so the
                # request must not fail here. Keep the traceback so the cause
                # is still diagnosable when debug logging is enabled.
                logger.debug(
                    "Failed to sync title to thread_meta for %s (non-fatal)",
                    sanitize_log_param(thread_id),
                    exc_info=True,
                )

    return ThreadStateResponse(
        values=serialize_channel_values(channel_values),
        next=[],
        metadata=metadata,
        checkpoint_id=new_checkpoint_id,
        created_at=coerce_iso(metadata.get("created_at", "")),
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
    """Get checkpoint history for a thread.

    Iterates ``checkpointer.alist`` (bounded by ``body.limit`` and, when
    given, starting before ``body.before``) and builds one ``HistoryEntry``
    per checkpoint. Each entry carries ``title`` and ``thread_data`` when
    present in the checkpoint's channel values. Serialized ``messages`` are
    attached only to the first entry yielded, to avoid repeating them on
    every entry.

    Raises:
        HTTPException: 500 when listing or converting checkpoints fails.
    """
    checkpointer = get_checkpointer(request)

    config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
    if body.before:
        config["configurable"]["checkpoint_id"] = body.before

    entries: list[HistoryEntry] = []
    is_latest_checkpoint = True
    try:
        async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
            # ``or {}`` guards against a config that is present but None,
            # matching the guard used when reading a single thread state.
            ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
            parent_config = getattr(checkpoint_tuple, "parent_config", None)
            metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
            checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}

            checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
            parent_id = None
            if parent_config:
                parent_id = parent_config.get("configurable", {}).get("checkpoint_id")

            channel_values = checkpoint.get("channel_values", {})

            # Build values from checkpoint channel_values.
            values: dict[str, Any] = {}
            if title := channel_values.get("title"):
                values["title"] = title
            if thread_data := channel_values.get("thread_data"):
                values["thread_data"] = thread_data

            # Attach messages only to the first (latest) checkpoint entry.
            if is_latest_checkpoint:
                messages = channel_values.get("messages")
                if messages:
                    values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
                is_latest_checkpoint = False

            # Names of the tasks still pending at this checkpoint.
            tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
            next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]

            # Strip bookkeeping keys from metadata, but keep ``step`` for
            # ordering context.
            user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
            if "step" in metadata:
                user_meta["step"] = metadata["step"]

            entries.append(
                HistoryEntry(
                    checkpoint_id=checkpoint_id,
                    parent_checkpoint_id=parent_id,
                    metadata=user_meta,
                    values=values,
                    created_at=coerce_iso(metadata.get("created_at", "")),
                    next=next_tasks,
                )
            )
    except Exception:
        logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread history")

    return entries
|
|
||||||
@@ -4,26 +4,21 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import stat
|
import stat
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
from app.gateway.deps import get_config
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
|
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
UnsafeUploadPathError,
|
|
||||||
claim_unique_filename,
|
|
||||||
delete_file_safe,
|
delete_file_safe,
|
||||||
enrich_file_listing,
|
enrich_file_listing,
|
||||||
ensure_uploads_dir,
|
ensure_uploads_dir,
|
||||||
get_uploads_dir,
|
get_uploads_dir,
|
||||||
list_files_in_dir,
|
list_files_in_dir,
|
||||||
normalize_filename,
|
normalize_filename,
|
||||||
open_upload_file_no_symlink,
|
|
||||||
upload_artifact_url,
|
upload_artifact_url,
|
||||||
upload_virtual_path,
|
upload_virtual_path,
|
||||||
)
|
)
|
||||||
@@ -33,11 +28,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
||||||
|
|
||||||
UPLOAD_CHUNK_SIZE = 8192
|
|
||||||
DEFAULT_MAX_FILES = 10
|
|
||||||
DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
|
|
||||||
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
class UploadResponse(BaseModel):
|
class UploadResponse(BaseModel):
|
||||||
"""Response model for file upload."""
|
"""Response model for file upload."""
|
||||||
@@ -45,15 +35,6 @@ class UploadResponse(BaseModel):
|
|||||||
success: bool
|
success: bool
|
||||||
files: list[dict[str, str]]
|
files: list[dict[str, str]]
|
||||||
message: str
|
message: str
|
||||||
skipped_files: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class UploadLimits(BaseModel):
    """Application-level upload limits exposed to clients."""

    # Maximum number of files accepted in one upload request.
    max_files: int
    # Maximum size of a single uploaded file, in bytes.
    max_file_size: int
    # Maximum combined size of all files in one upload request, in bytes.
    max_total_size: int
|
|
||||||
|
|
||||||
|
|
||||||
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||||
@@ -74,257 +55,107 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
    """Report whether the provider sets a truthy ``uses_thread_data_mounts``.

    A provider without that attribute is treated as not using mounts.
    """
    flag = getattr(sandbox_provider, "uses_thread_data_mounts", False)
    return True if flag else False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
    """Look up ``key`` in ``app_config.uploads``, falling back to ``default``.

    The ``uploads`` section may be a plain dict (read with ``.get``) or any
    other object (read with ``getattr``). A missing section behaves like an
    object without the attribute, so ``default`` is returned.
    """
    section = getattr(app_config, "uploads", None)
    if not isinstance(section, dict):
        return getattr(section, key, default)
    return section.get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int:
    """Resolve one positive integer upload limit from the uploads config.

    ``key`` is consulted first, then ``legacy_key`` when given, and finally
    ``default``. Any value that cannot be converted to a positive ``int``
    (or any error while reading it) is logged as a warning and replaced by
    ``default``.
    """
    candidate_keys = (key,) if legacy_key is None else (key, legacy_key)
    try:
        raw = None
        for name in candidate_keys:
            raw = _get_uploads_config_value(app_config, name, None)
            if raw is not None:
                break
        parsed = int(default if raw is None else raw)
        if parsed <= 0:
            raise ValueError
    except Exception:
        logger.warning("Invalid uploads.%s value; falling back to %d", key, default)
        return default
    return parsed
|
|
||||||
|
|
||||||
|
|
||||||
def _get_upload_limits(app_config: AppConfig) -> UploadLimits:
    """Assemble the effective ``UploadLimits`` from the application config.

    The file-count and per-file-size limits also honor their legacy keys
    (``max_file_count`` and ``max_single_file_size``).
    """
    file_count_cap = _get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count")
    per_file_cap = _get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size")
    request_cap = _get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE)
    return UploadLimits(
        max_files=file_count_cap,
        max_file_size=per_file_cap,
        max_total_size=request_cap,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
    """Best-effort removal of the given files, last-written first.

    Paths that no longer exist are ignored. Any other failure is logged as a
    warning and does not stop the remaining paths from being processed.
    """
    for stale in paths[::-1]:
        try:
            os.unlink(stale)
        except FileNotFoundError:
            continue
        except Exception:
            logger.warning("Failed to clean up upload path after rejected request: %s", stale, exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def _write_upload_file_with_limits(
    file: UploadFile,
    *,
    uploads_dir: os.PathLike[str] | str,
    display_filename: str,
    max_single_file_size: int,
    max_total_size: int,
    total_size: int,
) -> tuple[os.PathLike[str] | str, int, int]:
    """Stream one upload to disk while enforcing per-file and per-request caps.

    Args:
        file: The incoming upload; read in ``UPLOAD_CHUNK_SIZE`` chunks.
        uploads_dir: Directory the file is created in.
        display_filename: Name of the destination file; also used in the
            "too large" error detail.
        max_single_file_size: Maximum number of bytes allowed for this file.
        max_total_size: Maximum number of bytes allowed for the whole request.
        total_size: Bytes already counted for the request before this file.

    Returns:
        ``(file_path, file_size, total_size)`` where ``total_size`` includes
        the bytes of this file.

    Raises:
        HTTPException: 413 when either limit is exceeded. The partial file is
            removed before the error propagates.
    """
    file_size = 0
    # The destination is opened by the project helper; judging by its name it
    # presumably refuses symlinked targets (not visible from here).
    file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
    completed = False
    try:
        while chunk := await file.read(UPLOAD_CHUNK_SIZE):
            file_size += len(chunk)
            total_size += len(chunk)
            # Check limits before writing so an oversized chunk never hits disk.
            if file_size > max_single_file_size:
                raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
            if total_size > max_total_size:
                raise HTTPException(status_code=413, detail="Total upload size too large")
            fh.write(chunk)
        completed = True
    finally:
        # ``finally`` rather than ``except Exception`` so the handle is closed
        # and the partial file removed even when the coroutine is cancelled:
        # asyncio.CancelledError derives from BaseException.
        fh.close()
        if not completed:
            try:
                os.unlink(file_path)
            except FileNotFoundError:
                pass
    return file_path, file_size, total_size
|
|
||||||
|
|
||||||
|
|
||||||
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
    """Tell whether ``uploads.auto_convert_documents`` is switched on.

    Conversion stays disabled unless the setting is present and truthy.
    String values count as enabled only when they read ``1``, ``true``,
    ``yes`` or ``on`` (case-insensitive, surrounding whitespace ignored).
    Any error while reading the setting also yields ``False``.
    """
    try:
        setting = _get_uploads_config_value(app_config, "auto_convert_documents", False)
        if not isinstance(setting, str):
            return bool(setting)
        return setting.strip().lower() in {"1", "true", "yes", "on"}
    except Exception:
        return False
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
files: list[UploadFile] = File(...),
|
files: list[UploadFile] = File(...),
|
||||||
config: AppConfig = Depends(get_config),
|
|
||||||
) -> UploadResponse:
|
) -> UploadResponse:
|
||||||
"""Upload multiple files to a thread's uploads directory."""
|
"""Upload multiple files to a thread's uploads directory."""
|
||||||
if not files:
|
if not files:
|
||||||
raise HTTPException(status_code=400, detail="No files provided")
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
limits = _get_upload_limits(config)
|
with bind_request_actor_context(request):
|
||||||
if len(files) > limits.max_files:
|
try:
|
||||||
raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}")
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
|
uploaded_files = []
|
||||||
|
|
||||||
try:
|
sandbox_provider = get_sandbox_provider()
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
|
||||||
uploaded_files = []
|
|
||||||
written_paths = []
|
|
||||||
sandbox_sync_targets = []
|
|
||||||
skipped_files = []
|
|
||||||
total_size = 0
|
|
||||||
# Track filenames within this request so duplicate form parts do not
|
|
||||||
# silently truncate each other. Existing uploads keep the historical
|
|
||||||
# overwrite behavior for a single replacement upload.
|
|
||||||
seen_filenames: set[str] = set()
|
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
|
||||||
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
|
||||||
sandbox = None
|
|
||||||
if sync_to_sandbox:
|
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
if sandbox is None:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
|
|
||||||
auto_convert_documents = _auto_convert_documents_enabled(config)
|
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
original_filename = normalize_filename(file.filename)
|
safe_filename = normalize_filename(file.filename)
|
||||||
safe_filename = claim_unique_filename(original_filename, seen_filenames)
|
except ValueError:
|
||||||
except ValueError:
|
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
continue
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_path, file_size, total_size = await _write_upload_file_with_limits(
|
content = await file.read()
|
||||||
file,
|
file_path = uploads_dir / safe_filename
|
||||||
uploads_dir=uploads_dir,
|
file_path.write_bytes(content)
|
||||||
display_filename=safe_filename,
|
|
||||||
max_single_file_size=limits.max_file_size,
|
|
||||||
max_total_size=limits.max_total_size,
|
|
||||||
total_size=total_size,
|
|
||||||
)
|
|
||||||
written_paths.append(file_path)
|
|
||||||
|
|
||||||
virtual_path = upload_virtual_path(safe_filename)
|
virtual_path = upload_virtual_path(safe_filename)
|
||||||
|
|
||||||
if sync_to_sandbox:
|
if sandbox_id != "local":
|
||||||
sandbox_sync_targets.append((file_path, virtual_path))
|
_make_file_sandbox_writable(file_path)
|
||||||
|
sandbox.update_file(virtual_path, content)
|
||||||
|
|
||||||
file_info = {
|
file_info = {
|
||||||
"filename": safe_filename,
|
"filename": safe_filename,
|
||||||
"size": str(file_size),
|
"size": str(len(content)),
|
||||||
"path": str(sandbox_uploads / safe_filename),
|
"path": str(sandbox_uploads / safe_filename),
|
||||||
"virtual_path": virtual_path,
|
"virtual_path": virtual_path,
|
||||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||||
}
|
}
|
||||||
if safe_filename != original_filename:
|
|
||||||
file_info["original_filename"] = original_filename
|
|
||||||
|
|
||||||
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
|
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||||
|
|
||||||
file_ext = file_path.suffix.lower()
|
file_ext = file_path.suffix.lower()
|
||||||
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
|
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||||
md_path = await convert_file_to_markdown(file_path)
|
md_path = await convert_file_to_markdown(file_path)
|
||||||
if md_path:
|
if md_path:
|
||||||
written_paths.append(md_path)
|
md_virtual_path = upload_virtual_path(md_path.name)
|
||||||
md_virtual_path = upload_virtual_path(md_path.name)
|
|
||||||
|
|
||||||
if sync_to_sandbox:
|
if sandbox_id != "local":
|
||||||
sandbox_sync_targets.append((md_path, md_virtual_path))
|
_make_file_sandbox_writable(md_path)
|
||||||
|
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||||
|
|
||||||
file_info["markdown_file"] = md_path.name
|
file_info["markdown_file"] = md_path.name
|
||||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||||
file_info["markdown_virtual_path"] = md_virtual_path
|
file_info["markdown_virtual_path"] = md_virtual_path
|
||||||
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
||||||
|
|
||||||
uploaded_files.append(file_info)
|
uploaded_files.append(file_info)
|
||||||
|
|
||||||
except HTTPException as e:
|
except Exception as e:
|
||||||
_cleanup_uploaded_paths(written_paths)
|
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||||
raise e
|
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||||
except UnsafeUploadPathError as e:
|
|
||||||
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
|
|
||||||
skipped_files.append(safe_filename)
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
|
||||||
_cleanup_uploaded_paths(written_paths)
|
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
|
||||||
|
|
||||||
if sync_to_sandbox:
|
|
||||||
for file_path, virtual_path in sandbox_sync_targets:
|
|
||||||
_make_file_sandbox_writable(file_path)
|
|
||||||
sandbox.update_file(virtual_path, file_path.read_bytes())
|
|
||||||
|
|
||||||
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
|
|
||||||
if skipped_files:
|
|
||||||
message += f"; skipped {len(skipped_files)} unsafe file(s)"
|
|
||||||
|
|
||||||
return UploadResponse(
|
return UploadResponse(
|
||||||
success=not skipped_files,
|
success=True,
|
||||||
files=uploaded_files,
|
files=uploaded_files,
|
||||||
message=message,
|
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
|
||||||
skipped_files=skipped_files,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/limits", response_model=UploadLimits)
@require_permission("threads", "read", owner_check=True)
async def get_upload_limits(
    thread_id: str,
    request: Request,
    config: AppConfig = Depends(get_config),
) -> UploadLimits:
    """Expose the upload limits the gateway applies to this thread's uploads."""
    effective_limits = _get_upload_limits(config)
    return effective_limits
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=dict)
|
@router.get("/list", response_model=dict)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||||
"""List all files in a thread's uploads directory."""
|
"""List all files in a thread's uploads directory."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
uploads_dir = get_uploads_dir(thread_id)
|
try:
|
||||||
except ValueError as e:
|
uploads_dir = get_uploads_dir(thread_id)
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
except ValueError as e:
|
||||||
result = list_files_in_dir(uploads_dir)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
enrich_file_listing(result, thread_id)
|
result = list_files_in_dir(uploads_dir)
|
||||||
|
enrich_file_listing(result, thread_id)
|
||||||
|
|
||||||
# Gateway additionally includes the sandbox-relative path.
|
# Gateway additionally includes the sandbox-relative path.
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
for f in result["files"]:
|
for f in result["files"]:
|
||||||
f["path"] = str(sandbox_uploads / f["filename"])
|
f["path"] = str(sandbox_uploads / f["filename"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{filename}")
|
@router.delete("/{filename}")
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
|
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
|
||||||
"""Delete a file from a thread's uploads directory."""
|
"""Delete a file from a thread's uploads directory."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,404 +0,0 @@
|
|||||||
"""Run lifecycle service layer.
|
|
||||||
|
|
||||||
Centralizes the business logic for creating runs, formatting SSE
|
|
||||||
frames, and consuming stream bridge events. Router modules
|
|
||||||
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
from langchain_core.messages.utils import convert_to_messages
|
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
|
||||||
from app.gateway.utils import sanitize_log_param
|
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
from deerflow.runtime import (
|
|
||||||
END_SENTINEL,
|
|
||||||
HEARTBEAT_SENTINEL,
|
|
||||||
ConflictError,
|
|
||||||
DisconnectMode,
|
|
||||||
RunManager,
|
|
||||||
RunRecord,
|
|
||||||
RunStatus,
|
|
||||||
StreamBridge,
|
|
||||||
UnsupportedStrategyError,
|
|
||||||
run_agent,
|
|
||||||
)
|
|
||||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# SSE formatting
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
    """Render one Server-Sent Events frame as text.

    The frame consists of an ``event:`` line, a ``data:`` line holding the
    JSON-encoded payload, an optional ``id:`` line (only for a non-empty
    ``event_id``), and a terminating blank line. Values that JSON cannot
    encode natively are stringified with ``str``; non-ASCII characters are
    emitted unescaped.
    """
    serialized = json.dumps(data, default=str, ensure_ascii=False)
    frame = f"event: {event}\ndata: {serialized}\n"
    if event_id:
        frame += f"id: {event_id}\n"
    # The extra newline produces the blank line that ends the frame.
    return frame + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Input / config helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
    """Coerce the ``stream_mode`` request parameter into a list.

    A single string becomes a one-element list. ``None`` or an empty list
    falls back to ``["values"]``. A non-empty list is returned as-is.
    """
    if isinstance(raw, str):
        return [raw]
    if not raw:
        return ["values"]
    return raw
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
    """Turn a LangGraph Platform input payload into a LangChain state dict.

    Dict entries of ``messages`` are coerced one at a time through
    ``langchain_core.messages.utils.convert_to_messages`` so that
    ``additional_kwargs`` (e.g. uploaded-file metadata — gh #3132), ``id``,
    ``name`` and non-human roles (ai/system/tool) are preserved. Entries that
    are already ``BaseMessage`` instances, and entries that are neither
    messages nor dicts, are passed through untouched.

    Args:
        raw_input: The request's ``input`` payload, or ``None``.

    Returns:
        ``{}`` for ``None``; the original dict itself when it has no
        non-empty ``messages`` list; otherwise a shallow copy whose
        ``messages`` key holds the coerced list.

    Raises:
        HTTPException: 400, naming the offending index, when a message dict
            cannot be converted. The gateway is a system boundary, so a
            per-entry client error is preferred over a 500.
    """
    if raw_input is None:
        return {}

    messages = raw_input.get("messages")
    if not messages or not isinstance(messages, list):
        # Nothing to coerce: hand the payload back untouched.
        return raw_input

    coerced: list[Any] = []
    for index, item in enumerate(messages):
        if isinstance(item, BaseMessage) or not isinstance(item, dict):
            coerced.append(item)
            continue
        try:
            # Convert one entry at a time so a failure can be tied to its index.
            coerced.extend(convert_to_messages([item]))
        except (ValueError, TypeError, NotImplementedError) as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid message at input.messages[{index}]: {exc}",
            ) from exc

    return {**raw_input, "messages": coerced}
|
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
|
||||||
|
|
||||||
|
|
||||||
# Whitelist of run-context keys that the langgraph-compat layer forwards from
|
|
||||||
# ``body.context`` into the run config. ``config["context"]`` exists in
|
|
||||||
# LangGraph >=0.6, but these values must be written to both ``configurable``
|
|
||||||
# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
|
|
||||||
# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
|
|
||||||
# ``configurable`` for consumers like ``setup_agent``.
|
|
||||||
_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"model_name",
|
|
||||||
"mode",
|
|
||||||
"thinking_enabled",
|
|
||||||
"reasoning_effort",
|
|
||||||
"is_plan_mode",
|
|
||||||
"subagent_enabled",
|
|
||||||
"max_concurrent_subagents",
|
|
||||||
"agent_name",
|
|
||||||
"is_bootstrap",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
    """Copy allowlisted ``body.context`` keys into the run config in place.

    Each key of ``_CONTEXT_CONFIGURABLE_KEYS`` that is present in *context* is
    written to both ``config['configurable']`` (legacy configurable readers)
    and ``config['context']`` (LangGraph ``ToolRuntime.context`` consumers such
    as the ``setup_agent`` tool — see issue #2677). Values already present in
    a container win; a container that is not a ``dict`` is left alone.

    Args:
        config: Run config to mutate. Both containers are created as empty
            dicts when missing, unless *context* is empty/``None``.
        context: The request's ``context`` mapping, or ``None``.
    """
    if not context:
        return

    # Create both containers up front, ``configurable`` first.
    containers = (
        config.setdefault("configurable", {}),
        config.setdefault("context", {}),
    )
    overrides = {name: context[name] for name in _CONTEXT_CONFIGURABLE_KEYS if name in context}

    for container in containers:
        if not isinstance(container, dict):
            continue
        for name, value in overrides.items():
            # setdefault keeps any value the caller already supplied.
            container.setdefault(name, value)
|
|
||||||
|
|
||||||
|
|
||||||
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
    """Record the authenticated user's id in the run context, in place.

    Tools may execute after the request handler has returned, so tools that
    persist user-scoped files should not depend solely on ambient
    ContextVars. The id is read from server-side auth state
    (``request.state.user.id``), never from client-supplied context.

    Nothing is written when there is no user, the user has no ``id``, or the
    existing ``config['context']`` is not a ``dict``.

    Args:
        config: Run config to mutate; ``context`` is created when missing.
        request: Incoming request carrying ``state.user``.
    """
    authenticated = getattr(request.state, "user", None)
    uid = getattr(authenticated, "id", None)
    if uid is not None:
        ctx = config.setdefault("context", {})
        if isinstance(ctx, dict):
            # Always overwrite: the server-side identity is authoritative.
            ctx["user_id"] = str(uid)
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_agent_factory(assistant_id: str | None):
    """Return the factory callable used to build the agent for a run.

    Every ``assistant_id`` maps to the same factory: custom agents are
    implemented as ``lead_agent`` plus an ``agent_name`` injected into
    ``configurable`` or ``context`` (see :func:`build_run_config`), and the
    routing happens inside ``make_lead_agent`` when it reads
    ``cfg["agent_name"]``. *assistant_id* is therefore not inspected here.
    """
    # Imported at call time rather than at module import.
    from deerflow.agents.lead_agent.agent import make_lead_agent as lead_agent_factory

    return lead_agent_factory
|
|
||||||
|
|
||||||
|
|
||||||
def build_run_config(
|
|
||||||
thread_id: str,
|
|
||||||
request_config: dict[str, Any] | None,
|
|
||||||
metadata: dict[str, Any] | None,
|
|
||||||
*,
|
|
||||||
assistant_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Build a RunnableConfig dict for the agent.
|
|
||||||
|
|
||||||
When *assistant_id* refers to a custom agent (anything other than
|
|
||||||
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
|
|
||||||
whichever runtime options container is active: ``context`` for
|
|
||||||
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
|
|
||||||
``make_lead_agent`` reads this key to load the matching
|
|
||||||
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
|
|
||||||
silently runs as the default lead agent.
|
|
||||||
|
|
||||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
|
||||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
|
||||||
identically.
|
|
||||||
"""
|
|
||||||
config: dict[str, Any] = {"recursion_limit": 100}
|
|
||||||
if request_config:
|
|
||||||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
|
||||||
# pass thread-level data and rejects requests that include both
|
|
||||||
# ``configurable`` and ``context``. If the caller already sends
|
|
||||||
# ``context``, honour it and skip our own ``configurable`` dict.
|
|
||||||
if "context" in request_config:
|
|
||||||
if "configurable" in request_config:
|
|
||||||
logger.warning(
|
|
||||||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
|
||||||
thread_id,
|
|
||||||
list(request_config.get("configurable", {}).keys()),
|
|
||||||
)
|
|
||||||
context_value = request_config["context"]
|
|
||||||
if context_value is None:
|
|
||||||
context = {}
|
|
||||||
elif isinstance(context_value, Mapping):
|
|
||||||
context = dict(context_value)
|
|
||||||
else:
|
|
||||||
raise ValueError("request config 'context' must be a mapping or null.")
|
|
||||||
config["context"] = context
|
|
||||||
else:
|
|
||||||
configurable = {"thread_id": thread_id}
|
|
||||||
configurable.update(request_config.get("configurable", {}))
|
|
||||||
config["configurable"] = configurable
|
|
||||||
for k, v in request_config.items():
|
|
||||||
if k not in ("configurable", "context"):
|
|
||||||
config[k] = v
|
|
||||||
else:
|
|
||||||
config["configurable"] = {"thread_id": thread_id}
|
|
||||||
|
|
||||||
# Inject custom agent name when the caller specified a non-default assistant.
|
|
||||||
# Honour an explicit agent_name in the active runtime options container.
|
|
||||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
|
|
||||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
|
||||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
|
||||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
|
||||||
if "configurable" in config:
|
|
||||||
target = config["configurable"]
|
|
||||||
elif "context" in config:
|
|
||||||
target = config["context"]
|
|
||||||
else:
|
|
||||||
target = config.setdefault("configurable", {})
|
|
||||||
if target is not None and "agent_name" not in target:
|
|
||||||
target["agent_name"] = normalized
|
|
||||||
config.setdefault("run_name", resolve_root_run_name(config, normalized))
|
|
||||||
if metadata:
|
|
||||||
config.setdefault("metadata", {}).update(metadata)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run lifecycle
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def start_run(
    body: Any,
    thread_id: str,
    request: Request,
) -> RunRecord:
    """Create a RunRecord and launch the background agent task.

    Parameters
    ----------
    body : RunCreateRequest
        The validated request body (typed as Any to avoid circular import
        with the router module that defines the Pydantic model).
    thread_id : str
        Target thread.
    request : Request
        FastAPI request — used to retrieve singletons from ``app.state``.

    Returns
    -------
    RunRecord
        The record created by the run manager, with ``record.task`` set to
        the ``asyncio`` task that runs ``run_agent``.

    Raises
    ------
    HTTPException
        400 when ``context.model_name`` is not in the configured model
        allowlist; 409 when the run manager raises ``ConflictError``; 501
        when it raises ``UnsupportedStrategyError``.
    """
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
    run_ctx = get_run_context(request)

    # Anything other than the literal "cancel" means "continue".
    disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_

    body_context = getattr(body, "context", None) or {}
    model_name = body_context.get("model_name")

    # Coerce non-string model_name values to str so the allowlist lookup and
    # the run record below both receive a string.
    if model_name is not None and not isinstance(model_name, str):
        model_name = str(model_name)

    # Validate model against the allowlist when a model_name is provided.
    if model_name:
        app_config = get_app_config()
        resolved = app_config.get_model_config(model_name)
        if resolved is None:
            raise HTTPException(
                status_code=400,
                detail=f"Model {model_name!r} is not in the configured model allowlist",
            )

    # Map the run manager's domain errors onto HTTP status codes.
    try:
        record = await run_mgr.create_or_reject(
            thread_id,
            body.assistant_id,
            on_disconnect=disconnect,
            metadata=body.metadata or {},
            kwargs={"input": body.input, "config": body.config},
            multitask_strategy=body.multitask_strategy,
            model_name=model_name,
        )
    except ConflictError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except UnsupportedStrategyError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc

    # Upsert thread metadata so the thread appears in /threads/search,
    # even for threads that were never explicitly created via POST /threads
    # (e.g. stateless runs).
    try:
        existing = await run_ctx.thread_store.get(thread_id)
        if existing is None:
            await run_ctx.thread_store.create(
                thread_id,
                assistant_id=body.assistant_id,
                metadata=body.metadata,
            )
        else:
            await run_ctx.thread_store.update_status(thread_id, "running")
    except Exception:
        # Deliberately best-effort: a metadata failure must not abort the run.
        # NOTE(review): the exception itself is not logged; consider adding
        # exc_info so the failure can be diagnosed.
        logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))

    # NOTE(review): normalize_input can raise HTTPException(400) and
    # build_run_config can raise ValueError, both after the record has already
    # been created above — confirm a record that never gets a task is cleaned
    # up and cannot block later runs on this thread.
    agent_factory = resolve_agent_factory(body.assistant_id)
    graph_input = normalize_input(body.input)
    config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)

    # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
    # The ``context`` field is a custom extension for the langgraph-compat layer
    # that carries agent configuration (model_name, thinking_enabled, etc.).
    # Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
    merge_run_context_overrides(config, getattr(body, "context", None))
    # Runs after the client-supplied merge so the server-side user id wins.
    inject_authenticated_user_context(config, request)

    stream_modes = normalize_stream_modes(body.stream_mode)

    # Schedule the agent in the background; this function does not await it.
    task = asyncio.create_task(
        run_agent(
            bridge,
            run_mgr,
            record,
            ctx=run_ctx,
            agent_factory=agent_factory,
            graph_input=graph_input,
            config=config,
            stream_modes=stream_modes,
            stream_subgraphs=body.stream_subgraphs,
            interrupt_before=body.interrupt_before,
            interrupt_after=body.interrupt_after,
        )
    )
    # Keep a handle on the task via the record.
    record.task = task

    # Title sync is handled by worker.py's finally block which reads the
    # title from the checkpoint and calls thread_store.update_display_name
    # after the run completes.

    return record
|
|
||||||
|
|
||||||
|
|
||||||
async def sse_consumer(
    bridge: StreamBridge,
    record: RunRecord,
    request: Request,
    run_mgr: RunManager,
):
    """Yield SSE frames for one run as they arrive on the stream bridge.

    The subscription resumes from the request's ``Last-Event-ID`` header when
    present. Heartbeats become SSE comment frames, the end sentinel becomes a
    final ``end`` event, and the loop stops early once the client has
    disconnected.

    The ``finally`` block implements ``on_disconnect`` semantics for a run
    that is still pending or running when the generator exits:
    - ``cancel``: abort the background task on client disconnect.
    - ``continue``: let the task run; events are discarded.
    """
    resume_from = request.headers.get("Last-Event-ID")
    try:
        async for item in bridge.subscribe(record.run_id, last_event_id=resume_from):
            if await request.is_disconnected():
                break

            if item is HEARTBEAT_SENTINEL:
                yield ": heartbeat\n\n"
            elif item is END_SENTINEL:
                yield format_sse("end", None, event_id=item.id or None)
                return
            else:
                yield format_sse(item.event, item.data, event_id=item.id or None)
    finally:
        still_active = record.status in (RunStatus.pending, RunStatus.running)
        if still_active and record.on_disconnect == DisconnectMode.cancel:
            await run_mgr.cancel(record.run_id)
|
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""Gateway service layer."""
|
||||||
|
|
||||||
|
"""Compatibility package for app service submodules."""
|
||||||
|
|
||||||
|
__all__: list[str] = []
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Runs app layer services."""
|
||||||
|
|
||||||
|
from app.infra.storage import StorageRunObserver
|
||||||
|
from .input import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
RunSpecBuilder,
|
||||||
|
UnsupportedRunFeatureError,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdaptedRunRequest",
|
||||||
|
"AppRunCreateStore",
|
||||||
|
"AppRunDeleteStore",
|
||||||
|
"AppRunQueryStore",
|
||||||
|
"RunSpecBuilder",
|
||||||
|
"StorageRunObserver",
|
||||||
|
"UnsupportedRunFeatureError",
|
||||||
|
"adapt_create_run_request",
|
||||||
|
"adapt_create_stream_request",
|
||||||
|
"adapt_create_wait_request",
|
||||||
|
"adapt_join_stream_request",
|
||||||
|
"adapt_join_wait_request",
|
||||||
|
]
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
"""Facade factory - assembles RunsFacade with dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.gateway.dependencies import get_checkpointer, get_stream_bridge
|
||||||
|
from deerflow.runtime.runs.facade import RunsFacade
|
||||||
|
from deerflow.runtime.runs.facade import RunsRuntime
|
||||||
|
from deerflow.runtime.runs.internal.execution.supervisor import RunSupervisor
|
||||||
|
from deerflow.runtime.runs.internal.planner import ExecutionPlanner
|
||||||
|
from deerflow.runtime.runs.internal.registry import RunRegistry
|
||||||
|
from deerflow.runtime.runs.internal.streams import RunStreamService
|
||||||
|
from deerflow.runtime.runs.internal.wait import RunWaitService
|
||||||
|
|
||||||
|
from app.infra.storage import StorageRunObserver, ThreadMetaStorage
|
||||||
|
from app.infra.storage.runs import RunDeleteRepository, RunReadRepository, RunWriteRepository
|
||||||
|
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deerflow.runtime.stream_bridge import StreamBridge
|
||||||
|
|
||||||
|
|
||||||
|
type AgentFactory = Callable[..., object]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton registry (shared across requests)
|
||||||
|
_registry: RunRegistry | None = None
|
||||||
|
_supervisor: RunSupervisor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_state(request: Request, attr: str, label: str):
|
||||||
|
value = getattr(request.app.state, attr, None)
|
||||||
|
if value is None:
|
||||||
|
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_registry() -> RunRegistry:
    """Return the module-level ``RunRegistry``, creating it on first use."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = RunRegistry()
    return _registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_supervisor() -> RunSupervisor:
    """Return the module-level ``RunSupervisor``, creating it on first use."""
    global _supervisor
    if _supervisor is not None:
        return _supervisor
    _supervisor = RunSupervisor()
    return _supervisor
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_agent_factory(assistant_id: str | None) -> AgentFactory:
    """Return the agent factory callable for a run.

    *assistant_id* is accepted but not inspected: every id resolves to
    ``make_lead_agent``.
    """
    # Imported at call time rather than at module import.
    from deerflow.agents.lead_agent.agent import make_lead_agent as lead_agent_factory

    return lead_agent_factory
|
||||||
|
|
||||||
|
|
||||||
|
def build_runs_facade(
    *,
    stream_bridge: "StreamBridge",
    checkpointer: object,
    store: object | None = None,
    run_read_repo: RunReadRepository | None = None,
    run_write_repo: RunWriteRepository | None = None,
    run_delete_repo: RunDeleteRepository | None = None,
    thread_meta_storage: ThreadMetaStorage | None = None,
    run_event_store: object | None = None,
) -> RunsFacade:
    """
    Assemble a RunsFacade and wire up its collaborators.

    The registry and supervisor are the shared module-level singletons; the
    planner, stream service and wait service are created per call. Each
    app-level store is built only when its repository is supplied.

    Args:
        stream_bridge: StreamBridge instance
        checkpointer: LangGraph checkpointer
        store: Optional LangGraph runtime store
        run_read_repo: Optional run repository for durable reads
        run_write_repo: Optional run repository for durable writes
        run_delete_repo: Optional run repository for durable deletes
        thread_meta_storage: Optional thread metadata storage adapter
        run_event_store: Optional run event store, handed to the runtime as
            ``event_store``

    Returns:
        Configured RunsFacade instance
    """
    registry = get_registry()
    planner = ExecutionPlanner()
    supervisor = get_supervisor()

    stream_service = RunStreamService(stream_bridge)
    wait_service = RunWaitService(stream_service)

    query_store = None
    if run_read_repo:
        query_store = AppRunQueryStore(run_read_repo)

    create_store = None
    if run_write_repo:
        create_store = AppRunCreateStore(run_write_repo, thread_meta_storage=thread_meta_storage)

    delete_store = None
    if run_delete_repo:
        delete_store = AppRunDeleteStore(run_delete_repo)

    # An observer is attached when either durable backend is supplied.
    observer = (
        StorageRunObserver(
            run_write_repo=run_write_repo,
            thread_meta_storage=thread_meta_storage,
        )
        if run_write_repo or thread_meta_storage
        else None
    )

    runtime = RunsRuntime(
        bridge=stream_bridge,
        checkpointer=checkpointer,
        store=store,
        event_store=run_event_store,
        agent_factory_resolver=resolve_agent_factory,
    )

    return RunsFacade(
        registry=registry,
        planner=planner,
        supervisor=supervisor,
        stream_service=stream_service,
        wait_service=wait_service,
        runtime=runtime,
        observer=observer,
        query_store=query_store,
        create_store=create_store,
        delete_store=delete_store,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_runs_facade_from_request(request: "Request") -> RunsFacade:
    """
    Build a RunsFacade for the current FastAPI request.

    The stream bridge and checkpointer come from the gateway dependency
    helpers; the runtime store, repositories, thread metadata storage and run
    event store are read from ``request.app.state`` and default to ``None``
    when the attribute is absent.
    """
    state = request.app.state
    optional_attrs = (
        "run_read_repo",
        "run_write_repo",
        "run_delete_repo",
        "thread_meta_storage",
        "run_event_store",
    )

    return build_runs_facade(
        stream_bridge=get_stream_bridge(request),
        checkpointer=get_checkpointer(request),
        store=getattr(state, "store", None),
        **{name: getattr(state, name, None) for name in optional_attrs},
    )
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""Input adapters for app-owned runs entrypoints."""
|
||||||
|
|
||||||
|
from .request_adapter import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from .spec_builder import RunSpecBuilder, UnsupportedRunFeatureError
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdaptedRunRequest",
|
||||||
|
"RunSpecBuilder",
|
||||||
|
"UnsupportedRunFeatureError",
|
||||||
|
"adapt_create_run_request",
|
||||||
|
"adapt_create_stream_request",
|
||||||
|
"adapt_create_wait_request",
|
||||||
|
"adapt_join_stream_request",
|
||||||
|
"adapt_join_wait_request",
|
||||||
|
]
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
"""App-owned request adapter for runs entrypoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue
|
||||||
|
from deerflow.runtime.runs.types import RunIntent
|
||||||
|
|
||||||
|
type RequestBody = dict[str, JSONValue]
|
||||||
|
type RequestQuery = dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AdaptedRunRequest:
    """
    Unified internal request DTO for the runs entrypoints.

    The routing layer only extracts path/query/body values; the adapter
    functions turn them into this stable internal structure.
    """

    # Which runs operation the request represents.
    intent: RunIntent
    # Target thread; ``None`` marks a stateless request.
    thread_id: str | None
    # Existing run to join; ``None`` for create-style requests.
    run_id: str | None
    # Request body as a JSON mapping (empty for join requests).
    body: RequestBody
    # Request headers as supplied by the caller.
    headers: dict[str, str]
    # Query-string parameters.
    query: RequestQuery

    @property
    def last_event_id(self) -> str | None:
        """Return the ``Last-Event-ID`` header value, if any.

        The two conventional spellings are tried first. Because ``headers``
        is a plain dict while HTTP field names are case-insensitive, any
        other casing (for example ``Last-Event-Id``) is then matched
        case-insensitively; the first non-empty match wins.
        """
        direct = self.headers.get("last-event-id") or self.headers.get("Last-Event-ID")
        if direct:
            return direct
        for name, value in self.headers.items():
            if value and isinstance(name, str) and name.lower() == "last-event-id":
                return value
        # Preserve the original result when nothing non-empty was found.
        return direct

    @property
    def is_stateless(self) -> bool:
        """Return ``True`` when the request carries no ``thread_id``."""
        return self.thread_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_create_run_request(
    *,
    thread_id: str | None,
    body: RequestBody,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt POST /threads/{thread_id}/runs or POST /runs.

    Produces a ``create_background`` request with no ``run_id``; missing or
    empty *headers* / *query* become empty dicts.
    """
    header_map = headers if headers else {}
    query_map = query if query else {}
    return AdaptedRunRequest(
        intent="create_background",
        thread_id=thread_id,
        run_id=None,
        body=body,
        headers=header_map,
        query=query_map,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_create_stream_request(
    *,
    thread_id: str | None,
    body: RequestBody,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt POST /threads/{thread_id}/runs/stream or POST /runs/stream.

    Produces a ``create_and_stream`` request with no ``run_id``; missing or
    empty *headers* / *query* become empty dicts.
    """
    header_map = headers if headers else {}
    query_map = query if query else {}
    return AdaptedRunRequest(
        intent="create_and_stream",
        thread_id=thread_id,
        run_id=None,
        body=body,
        headers=header_map,
        query=query_map,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_create_wait_request(
    *,
    thread_id: str | None,
    body: RequestBody,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt POST /threads/{thread_id}/runs/wait or POST /runs/wait.

    Produces a ``create_and_wait`` request with no ``run_id``; missing or
    empty *headers* / *query* become empty dicts.
    """
    header_map = headers if headers else {}
    query_map = query if query else {}
    return AdaptedRunRequest(
        intent="create_and_wait",
        thread_id=thread_id,
        run_id=None,
        body=body,
        headers=header_map,
        query=query_map,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_join_stream_request(
    *,
    thread_id: str,
    run_id: str,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt GET /threads/{thread_id}/runs/{run_id}/stream.

    Produces a ``join_stream`` request with an empty body; missing or empty
    *headers* / *query* become empty dicts.
    """
    header_map = headers if headers else {}
    query_map = query if query else {}
    return AdaptedRunRequest(
        intent="join_stream",
        thread_id=thread_id,
        run_id=run_id,
        body={},
        headers=header_map,
        query=query_map,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_join_wait_request(
    *,
    thread_id: str,
    run_id: str,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt GET /threads/{thread_id}/runs/{run_id}/join.

    Produces a ``join_wait`` request with an empty body; missing or empty
    *headers* / *query* become empty dicts.
    """
    header_map = headers if headers else {}
    query_map = query if query else {}
    return AdaptedRunRequest(
        intent="join_wait",
        thread_id=thread_id,
        run_id=run_id,
        body={},
        headers=header_map,
        query=query_map,
    )
|
||||||
@@ -0,0 +1,254 @@
|
|||||||
|
"""App-owned RunSpec builder."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.types import CheckpointRequest, RunScope, RunSpec
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue
|
||||||
|
|
||||||
|
from .request_adapter import AdaptedRunRequest
|
||||||
|
|
||||||
|
type JSONMapping = dict[str, JSONValue]
|
||||||
|
type GraphInput = dict[str, object]
|
||||||
|
type RunnableConfigDict = dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedRunFeatureError(ValueError):
    """Signals that a request asked for a feature phase 1 does not support.

    Subclasses ``ValueError`` so existing ``ValueError`` handlers catch it.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class RunSpecBuilder:
|
||||||
|
"""
|
||||||
|
Build RunSpec from AdaptedRunRequest.
|
||||||
|
|
||||||
|
Phase 1 rules:
|
||||||
|
1. messages-tuple normalized to messages
|
||||||
|
2. enqueue not supported
|
||||||
|
3. rollback not supported
|
||||||
|
4. after_seconds not supported
|
||||||
|
5. stream_resumable accepted
|
||||||
|
6. stateless auto-generates temporary thread
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Phase 1 unsupported features
|
||||||
|
UNSUPPORTED_MULTITASK_STRATEGIES = {"enqueue"}
|
||||||
|
UNSUPPORTED_ACTIONS = {"rollback"}
|
||||||
|
|
||||||
|
# Default stream modes
|
||||||
|
DEFAULT_STREAM_MODES = ["values", "messages"]
|
||||||
|
CONTEXT_CONFIGURABLE_KEYS = frozenset({
|
||||||
|
"model_name",
|
||||||
|
"mode",
|
||||||
|
"thinking_enabled",
|
||||||
|
"reasoning_effort",
|
||||||
|
"is_plan_mode",
|
||||||
|
"subagent_enabled",
|
||||||
|
"max_concurrent_subagents",
|
||||||
|
})
|
||||||
|
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _as_json_mapping(value: JSONValue | None) -> JSONMapping | None:
|
||||||
|
return value if isinstance(value, dict) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _as_string_list(value: JSONValue | None) -> list[str] | None:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return None
|
||||||
|
return [item for item in value if isinstance(item, str)]
|
||||||
|
|
||||||
|
def build(self, request: AdaptedRunRequest) -> RunSpec:
    """Assemble a ``RunSpec`` from an adapted request.

    Phase-1 constraints are validated first, so an unsupported feature raises
    ``UnsupportedRunFeatureError`` before anything else is built. Body fields
    of the wrong type are not rejected: string fields fall back to ``None``,
    enumerated fields fall back to their default (``cancel`` / ``keep`` /
    ``reject``), and list fields keep only their string items.
    """
    body = request.body

    def _string_or_none(key: str):
        candidate = body.get(key)
        return candidate if isinstance(candidate, str) else None

    def _one_of(key: str, allowed: set, fallback: str):
        candidate = body.get(key)
        return candidate if candidate in allowed else fallback

    def _interrupt_points(key: str):
        # "*" means every node; otherwise only string node names are kept.
        candidate = body.get(key)
        return "*" if candidate == "*" else self._as_string_list(candidate)

    # Validate phase1 constraints before building anything.
    self._validate_constraints(body)

    scope = self._build_scope(request)
    stream_modes = self._normalize_stream_modes(body.get("stream_mode"))
    checkpoint_request = self._build_checkpoint_request(body)

    context = self._as_json_mapping(body.get("context"))
    metadata = self._as_json_mapping(body.get("metadata"))

    config = self._build_runnable_config(
        thread_id=scope.thread_id,
        request_config=self._as_json_mapping(body.get("config")),
        metadata=metadata,
        assistant_id=body.get("assistant_id"),
        context=context,
    )

    graph_input = self._normalize_input(self._as_json_mapping(body.get("input")))

    return RunSpec(
        intent=request.intent,
        scope=scope,
        assistant_id=_string_or_none("assistant_id"),
        input=graph_input,
        command=self._as_json_mapping(body.get("command")),
        runnable_config=config,
        context=context,
        metadata=metadata or {},
        stream_modes=stream_modes,
        stream_subgraphs=bool(body.get("stream_subgraphs", False)),
        stream_resumable=bool(body.get("stream_resumable", False)),
        on_disconnect=_one_of("on_disconnect", {"cancel", "continue"}, "cancel"),
        on_completion=_one_of("on_completion", {"delete", "keep"}, "keep"),
        multitask_strategy=_one_of("multitask_strategy", {"reject", "interrupt"}, "reject"),
        interrupt_before=_interrupt_points("interrupt_before"),
        interrupt_after=_interrupt_points("interrupt_after"),
        checkpoint_request=checkpoint_request,
        follow_up_to_run_id=_string_or_none("follow_up_to_run_id"),
        webhook=_string_or_none("webhook"),
        feedback_keys=self._as_string_list(body.get("feedback_keys")),
    )
|
||||||
|
|
||||||
|
def _validate_constraints(self, body: JSONMapping) -> None:
    """Reject request bodies that ask for features phase1 does not implement.

    Checked in order: the ``multitask_strategy`` value, the ``action`` of an
    optional ``command`` mapping, and a non-null ``after_seconds``.

    Args:
        body: Raw request payload.

    Raises:
        UnsupportedRunFeatureError: On the first unsupported feature found.
    """
    requested_strategy = body.get("multitask_strategy", "reject")
    if requested_strategy in self.UNSUPPORTED_MULTITASK_STRATEGIES:
        message = f"multitask_strategy '{requested_strategy}' is not supported in phase1. " + "Supported: reject, interrupt"
        raise UnsupportedRunFeatureError(message)

    # A missing or non-mapping command simply has no action to reject.
    command_mapping = self._as_json_mapping(body.get("command"))
    requested_action = command_mapping.get("action") if command_mapping else None
    if requested_action in self.UNSUPPORTED_ACTIONS:
        raise UnsupportedRunFeatureError(f"action '{requested_action}' is not supported in phase1")

    if body.get("after_seconds") is None:
        return
    raise UnsupportedRunFeatureError("after_seconds is not supported in phase1")
|
||||||
|
|
||||||
|
def _build_scope(self, request: AdaptedRunRequest) -> RunScope:
    """Derive the RunScope for a request.

    Stateful requests reuse ``request.thread_id``; stateless requests get a
    freshly generated UUID thread id and are flagged as temporary.
    """
    if not request.is_stateless:
        assert request.thread_id is not None
        return RunScope(kind="stateful", thread_id=request.thread_id, temporary=False)

    # Stateless: there is no caller-supplied thread, so mint a throwaway one.
    throwaway_thread_id = str(uuid.uuid4())
    return RunScope(kind="stateless", thread_id=throwaway_thread_id, temporary=True)
|
||||||
|
|
||||||
|
def _normalize_stream_modes(self, stream_mode: JSONValue | None) -> list[str]:
    """Coerce ``stream_mode`` into a list of mode names.

    A string becomes a one-element list; a list keeps only its string items;
    ``None`` and every other JSON type yield a copy of
    ``DEFAULT_STREAM_MODES``. The alias ``"messages-tuple"`` is rewritten to
    ``"messages"``.
    """
    if isinstance(stream_mode, str):
        requested = [stream_mode]
    elif isinstance(stream_mode, list):
        requested = [item for item in stream_mode if isinstance(item, str)]
    else:
        # Covers None as well as numbers, booleans and mappings.
        return self.DEFAULT_STREAM_MODES.copy()

    normalized: list[str] = []
    for name in requested:
        normalized.append("messages" if name == "messages-tuple" else name)
    return normalized
|
||||||
|
|
||||||
|
def _build_checkpoint_request(self, body: JSONMapping) -> CheckpointRequest | None:
    """Return a CheckpointRequest when the body carries checkpoint data.

    ``checkpoint_id`` is honoured only when it is a string and ``checkpoint``
    only when ``_as_json_mapping`` accepts it. ``None`` is returned when
    neither survives.
    """
    raw_checkpoint_id = body.get("checkpoint_id")
    checkpoint_id = raw_checkpoint_id if isinstance(raw_checkpoint_id, str) else None
    checkpoint = self._as_json_mapping(body.get("checkpoint"))

    if checkpoint_id is None and checkpoint is None:
        return None
    return CheckpointRequest(checkpoint_id=checkpoint_id, checkpoint=checkpoint)
|
||||||
|
|
||||||
|
def _normalize_input(self, raw_input: JSONMapping | None) -> GraphInput | None:
    """Convert HTTP-friendly message dicts into LangChain message objects.

    Args:
        raw_input: The request's ``input`` mapping, or ``None``.

    Returns:
        ``None`` when no input was given; ``raw_input`` itself when it has no
        non-empty ``messages`` list; otherwise a shallow copy whose
        ``messages`` entry holds converted message objects. Items that are
        not dicts are passed through untouched.

    The role is read from ``role`` (falling back to ``type``, then
    ``"user"``). ``user``/``human`` map to ``HumanMessage``,
    ``assistant``/``ai`` to ``AIMessage`` and ``system`` to ``SystemMessage``.
    Previously every dict became a ``HumanMessage`` regardless of role, so
    earlier assistant and system turns were replayed as user text. Any other
    role (e.g. ``tool``, which needs a ``tool_call_id``) keeps the old
    ``HumanMessage`` fallback.
    """
    if raw_input is None:
        return None

    messages = raw_input.get("messages")
    if not messages or not isinstance(messages, list):
        return raw_input

    # Function-scope import: only HumanMessage is known to be in module scope.
    # NOTE(review): assumes HumanMessage comes from langchain_core.messages.
    from langchain_core.messages import AIMessage, SystemMessage

    message_class_by_role = {
        "user": HumanMessage,
        "human": HumanMessage,
        "assistant": AIMessage,
        "ai": AIMessage,
        "system": SystemMessage,
    }

    converted: list[object] = []
    for msg in messages:
        if not isinstance(msg, dict):
            converted.append(msg)
            continue
        role = msg.get("role", msg.get("type", "user"))
        content = msg.get("content", "")
        # Non-string roles (malformed JSON) must not be used as dict keys.
        message_class = message_class_by_role.get(role, HumanMessage) if isinstance(role, str) else HumanMessage
        converted.append(message_class(content=content))
    return {**raw_input, "messages": converted}
|
||||||
|
|
||||||
|
def _build_runnable_config(
    self,
    *,
    thread_id: str,
    request_config: JSONMapping | None,
    metadata: JSONMapping | None,
    assistant_id: str | None,
    context: JSONMapping | None,
) -> RunnableConfigDict:
    """Build RunnableConfig from request payload and app-side rules.

    Args:
        thread_id: Thread id seeded into ``configurable``.
        request_config: Optional ``config`` mapping from the request. When it
            has a ``context`` key that is forwarded and no ``configurable`` is
            built; otherwise its ``configurable`` dict is layered over
            ``{"thread_id": thread_id}``. All other keys are copied as-is.
        metadata: Optional request metadata, merged over any ``metadata``
            already present in ``request_config``.
        assistant_id: When set, different from ``DEFAULT_ASSISTANT_ID`` and no
            ``agent_name`` is configured, its normalized form becomes
            ``configurable["agent_name"]``.
        context: Optional request context; keys listed in
            ``CONTEXT_CONFIGURABLE_KEYS`` are copied into ``configurable``
            without overriding existing entries.

    Returns:
        A new config dict; it always carries ``recursion_limit`` (default 100).

    Raises:
        ValueError: If the normalized ``assistant_id`` contains characters
            other than lowercase letters, digits and hyphens.
    """
    config: RunnableConfigDict = {"recursion_limit": 100}

    if request_config:
        if "context" in request_config:
            config["context"] = request_config["context"]
        else:
            configurable = {"thread_id": thread_id}
            raw_configurable = request_config.get("configurable")
            if isinstance(raw_configurable, dict):
                # Request values win, including over the seeded thread_id.
                configurable.update(raw_configurable)
            config["configurable"] = configurable

        for key, value in request_config.items():
            if key not in ("configurable", "context"):
                config[key] = value
    else:
        config["configurable"] = {"thread_id": thread_id}

    configurable = config.get("configurable")
    wants_agent_name = bool(assistant_id) and assistant_id != self.DEFAULT_ASSISTANT_ID
    if wants_agent_name and isinstance(configurable, dict) and "agent_name" not in configurable:
        normalized = assistant_id.strip().lower().replace("_", "-")
        if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
            raise ValueError(
                f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization."
            )
        configurable["agent_name"] = normalized

    if metadata:
        existing_metadata = config.get("metadata")
        if isinstance(existing_metadata, dict):
            # Merge into a new dict: existing_metadata is the caller's own
            # request_config["metadata"] object and must not be mutated.
            config["metadata"] = {**existing_metadata, **metadata}
        else:
            config["metadata"] = dict(metadata)

    if context and isinstance(configurable, dict):
        for key in self.CONTEXT_CONFIGURABLE_KEYS:
            if key in context:
                configurable.setdefault(key, context[key])

    return config
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""Compatibility wrapper for the app-owned storage observer."""
|
||||||
|
|
||||||
|
from app.infra.storage.runs import StorageRunObserver
|
||||||
|
|
||||||
|
__all__ = ["StorageRunObserver"]
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
"""App-owned runs store adapters."""
|
||||||
|
|
||||||
|
from .create_store import AppRunCreateStore
|
||||||
|
from .delete_store import AppRunDeleteStore
|
||||||
|
from .query_store import AppRunQueryStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppRunCreateStore",
|
||||||
|
"AppRunDeleteStore",
|
||||||
|
"AppRunQueryStore",
|
||||||
|
]
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
"""App-owned durable run creation adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store import RunCreateStore
|
||||||
|
from deerflow.runtime.runs.types import RunRecord
|
||||||
|
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from app.infra.storage.runs import RunWriteRepository
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunCreateStore(RunCreateStore):
    """Persist the first durable row of a new run and keep thread metadata in step."""

    def __init__(self, repo: RunWriteRepository, thread_meta_storage: ThreadMetaStorage | None = None) -> None:
        self._repo = repo
        self._thread_meta_storage = thread_meta_storage

    async def create_run(self, record: RunRecord) -> None:
        """Write ``record`` through the repository, then reconcile the thread.

        When a thread-meta storage is configured and the record names an
        assistant, the thread is ensured to exist and its ``assistant_id`` is
        re-synced if it differs from the record's.
        """
        # NOTE(review): str() on a plain Enum gives "Class.member"; this
        # presumably relies on RunStatus being a StrEnum -- confirm.
        await self._repo.create(
            run_id=record.run_id,
            thread_id=record.thread_id,
            assistant_id=record.assistant_id,
            status=str(record.status),
            metadata=record.metadata,
            follow_up_to_run_id=record.follow_up_to_run_id,
            created_at=record.created_at,
        )

        thread_store = self._thread_meta_storage
        if thread_store is None or not record.assistant_id:
            return

        thread = await thread_store.ensure_thread(
            thread_id=record.thread_id,
            assistant_id=record.assistant_id,
        )
        if thread.assistant_id != record.assistant_id:
            await thread_store.sync_thread_assistant_id(
                thread_id=record.thread_id,
                assistant_id=record.assistant_id,
            )
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""App-owned durable run deletion adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store import RunDeleteStore
|
||||||
|
|
||||||
|
from app.infra.storage.runs import RunDeleteRepository
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunDeleteStore(RunDeleteStore):
    """Thin adapter that forwards run deletion to the app storage repository."""

    def __init__(self, repo: RunDeleteRepository) -> None:
        self._repo = repo

    async def delete_run(self, run_id: str) -> bool:
        """Delete the row for ``run_id`` and return the repository's result flag."""
        was_deleted = await self._repo.delete(run_id)
        return was_deleted
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""App-owned durable run query adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store import RunQueryStore
|
||||||
|
from deerflow.runtime.runs.types import RunRecord, RunStatus
|
||||||
|
|
||||||
|
from app.infra.storage.runs import RunReadRepository, RunRow
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunQueryStore(RunQueryStore):
    """Read durable run rows from app storage and expose them as RunRecord DTOs."""

    def __init__(self, repo: RunReadRepository) -> None:
        self._repo = repo

    async def get_run(self, run_id: str) -> RunRecord | None:
        """Return the record for ``run_id``, or ``None`` when the repository has no row."""
        row = await self._repo.get(run_id)
        return None if row is None else self._to_run_record(row)

    async def list_runs(
        self,
        thread_id: str,
        *,
        limit: int = 100,
    ) -> list[RunRecord]:
        """Return up to ``limit`` records of ``thread_id``, in repository order."""
        rows = await self._repo.list_by_thread(thread_id, limit=limit)
        records: list[RunRecord] = []
        for row in rows:
            records.append(self._to_run_record(row))
        return records

    def _to_run_record(self, row: RunRow) -> RunRecord:
        """Map one row dict to a RunRecord, defaulting the optional columns.

        ``run_id`` and ``thread_id`` are required keys; durable rows are never
        temporary, so ``temporary`` is always ``False``.
        """
        run_id = row["run_id"]
        thread_id = row["thread_id"]
        assistant_id = row.get("assistant_id")
        status = RunStatus(row.get("status", "pending"))
        return RunRecord(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            status=status,
            temporary=False,
            multitask_strategy=row.get("multitask_strategy", "reject"),
            metadata=row.get("metadata", {}),
            follow_up_to_run_id=row.get("follow_up_to_run_id"),
            created_at=row.get("created_at", ""),
            updated_at=row.get("updated_at", ""),
            started_at=row.get("started_at"),
            ended_at=row.get("ended_at"),
            error=row.get("error"),
        )
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
"""Shared utility helpers for the Gateway layer."""
|
|
||||||
|
|
||||||
|
|
||||||
# Code points deleted by sanitize_log_param: the C0 controls except TAB,
# DEL, the C1 controls (which include NEL, U+0085) and the Unicode
# line/paragraph separators. ESC is included so ANSI escape sequences cannot
# reach a terminal through the log.
_LOG_UNSAFE_CODEPOINTS = dict.fromkeys(
    cp for cp in (*range(0x00, 0x20), *range(0x7F, 0xA0), 0x2028, 0x2029) if cp != 0x09
)


def sanitize_log_param(value: str) -> str:
    """Strip control characters to prevent log injection.

    Removes every character that can start a new log line or drive a
    terminal. The original implementation removed only ``\\n``, ``\\r`` and
    ``\\x00``; those are still removed, together with the remaining C0/C1
    controls, DEL, U+2028 and U+2029. TAB is kept because it cannot forge a
    new log record.

    Args:
        value: Untrusted text about to be interpolated into a log message.

    Returns:
        ``value`` with the unsafe characters deleted (not replaced).
    """
    return value.translate(_LOG_UNSAFE_CODEPOINTS)
|
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Application-owned infrastructure adapters and wiring."""
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Run event store backends owned by app infrastructure."""
|
||||||
|
|
||||||
|
from .factory import build_run_event_store
|
||||||
|
from .jsonl_store import JsonlRunEventStore
|
||||||
|
|
||||||
|
__all__ = ["JsonlRunEventStore", "build_run_event_store"]
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""Factory for app-owned run event store backends."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from app.infra.storage import AppRunEventStore
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
from .jsonl_store import JsonlRunEventStore
|
||||||
|
|
||||||
|
|
||||||
|
def build_run_event_store(session_factory: async_sessionmaker[AsyncSession]) -> AppRunEventStore | JsonlRunEventStore:
    """Instantiate the run event store named by ``run_events.backend``.

    ``"db"`` yields an ``AppRunEventStore`` over ``session_factory``;
    ``"jsonl"`` yields a ``JsonlRunEventStore`` rooted at
    ``run_events.jsonl_base_dir``.

    Raises:
        ValueError: For any other backend name.
    """
    run_events_config = get_app_config().run_events
    backend = run_events_config.backend

    if backend == "db":
        return AppRunEventStore(session_factory)
    if backend != "jsonl":
        raise ValueError(f"Unsupported run event backend: {backend}")
    return JsonlRunEventStore(base_dir=Path(run_events_config.jsonl_base_dir))
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
"""JSONL run event store backend owned by app infrastructure."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class JsonlRunEventStore:
|
||||||
|
"""Append-only JSONL implementation of the runs RunEventStore protocol."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_dir: Path | str = ".deer-flow/run-events",
|
||||||
|
) -> None:
|
||||||
|
self._base_dir = Path(base_dir)
|
||||||
|
self._locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._locks_guard = asyncio.Lock()
|
||||||
|
|
||||||
|
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
if not events:
|
||||||
|
return []
|
||||||
|
|
||||||
|
grouped: dict[str, list[dict[str, Any]]] = {}
|
||||||
|
for event in events:
|
||||||
|
grouped.setdefault(str(event["thread_id"]), []).append(event)
|
||||||
|
|
||||||
|
records_by_thread: dict[str, list[dict[str, Any]]] = {}
|
||||||
|
for thread_id, thread_events in grouped.items():
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
records_by_thread[thread_id] = self._append_thread_events(thread_id, thread_events)
|
||||||
|
|
||||||
|
indexes = {thread_id: 0 for thread_id in records_by_thread}
|
||||||
|
ordered: list[dict[str, Any]] = []
|
||||||
|
for event in events:
|
||||||
|
thread_id = str(event["thread_id"])
|
||||||
|
index = indexes[thread_id]
|
||||||
|
ordered.append(records_by_thread[thread_id][index])
|
||||||
|
indexes[thread_id] = index + 1
|
||||||
|
return ordered
|
||||||
|
|
||||||
|
async def list_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
events = [event for event in await self._read_thread_events(thread_id) if event.get("category") == "message"]
|
||||||
|
if before_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||||
|
return events[-limit:]
|
||||||
|
if after_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||||
|
return events[:limit]
|
||||||
|
return events[-limit:]
|
||||||
|
|
||||||
|
async def list_events(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
event_types: list[str] | None = None,
|
||||||
|
limit: int = 500,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
event_type_set = set(event_types or [])
|
||||||
|
events = [
|
||||||
|
event
|
||||||
|
for event in await self._read_thread_events(thread_id)
|
||||||
|
if event.get("run_id") == run_id and (not event_type_set or event.get("event_type") in event_type_set)
|
||||||
|
]
|
||||||
|
return events[:limit]
|
||||||
|
|
||||||
|
async def list_messages_by_run(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
events = [
|
||||||
|
event
|
||||||
|
for event in await self._read_thread_events(thread_id)
|
||||||
|
if event.get("run_id") == run_id and event.get("category") == "message"
|
||||||
|
]
|
||||||
|
if before_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||||
|
return events[-limit:]
|
||||||
|
if after_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||||
|
return events[:limit]
|
||||||
|
return events[-limit:]
|
||||||
|
|
||||||
|
async def count_messages(self, thread_id: str) -> int:
|
||||||
|
return len(await self.list_messages(thread_id, limit=10**9))
|
||||||
|
|
||||||
|
async def delete_by_thread(self, thread_id: str) -> int:
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
count = len(self._read_thread_events_sync(thread_id))
|
||||||
|
shutil.rmtree(self._thread_dir(thread_id), ignore_errors=True)
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
events = self._read_thread_events_sync(thread_id)
|
||||||
|
kept = [event for event in events if event.get("run_id") != run_id]
|
||||||
|
deleted = len(events) - len(kept)
|
||||||
|
if deleted:
|
||||||
|
self._write_thread_events(thread_id, kept)
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
async def _thread_lock(self, thread_id: str) -> asyncio.Lock:
|
||||||
|
async with self._locks_guard:
|
||||||
|
lock = self._locks.get(thread_id)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
self._locks[thread_id] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
def _append_thread_events(self, thread_id: str, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
thread_dir = self._thread_dir(thread_id)
|
||||||
|
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
seq = self._read_seq(thread_id)
|
||||||
|
records: list[dict[str, Any]] = []
|
||||||
|
with self._events_path(thread_id).open("a", encoding="utf-8") as file:
|
||||||
|
for event in events:
|
||||||
|
seq += 1
|
||||||
|
record = self._normalize_event(event, seq=seq)
|
||||||
|
file.write(json.dumps(record, ensure_ascii=False, default=str))
|
||||||
|
file.write("\n")
|
||||||
|
records.append(record)
|
||||||
|
self._write_seq(thread_id, seq)
|
||||||
|
return records
|
||||||
|
|
||||||
|
def _normalize_event(self, event: dict[str, Any], *, seq: int) -> dict[str, Any]:
|
||||||
|
created_at = event.get("created_at")
|
||||||
|
if isinstance(created_at, datetime):
|
||||||
|
created_at_value = created_at.isoformat()
|
||||||
|
elif created_at:
|
||||||
|
created_at_value = str(created_at)
|
||||||
|
else:
|
||||||
|
created_at_value = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"thread_id": str(event["thread_id"]),
|
||||||
|
"run_id": str(event["run_id"]),
|
||||||
|
"seq": seq,
|
||||||
|
"event_type": str(event["event_type"]),
|
||||||
|
"category": str(event["category"]),
|
||||||
|
"content": event.get("content", ""),
|
||||||
|
"metadata": dict(event.get("metadata") or {}),
|
||||||
|
"created_at": created_at_value,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _read_thread_events(self, thread_id: str) -> list[dict[str, Any]]:
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
return self._read_thread_events_sync(thread_id)
|
||||||
|
|
||||||
|
def _read_thread_events_sync(self, thread_id: str) -> list[dict[str, Any]]:
|
||||||
|
path = self._events_path(thread_id)
|
||||||
|
if not path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
events: list[dict[str, Any]] = []
|
||||||
|
with path.open(encoding="utf-8") as file:
|
||||||
|
for line in file:
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped:
|
||||||
|
events.append(json.loads(stripped))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def _write_thread_events(self, thread_id: str, events: Iterable[dict[str, Any]]) -> None:
|
||||||
|
thread_dir = self._thread_dir(thread_id)
|
||||||
|
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
temp_path = self._events_path(thread_id).with_suffix(".jsonl.tmp")
|
||||||
|
with temp_path.open("w", encoding="utf-8") as file:
|
||||||
|
for event in events:
|
||||||
|
file.write(json.dumps(event, ensure_ascii=False, default=str))
|
||||||
|
file.write("\n")
|
||||||
|
temp_path.replace(self._events_path(thread_id))
|
||||||
|
|
||||||
|
def _read_seq(self, thread_id: str) -> int:
|
||||||
|
path = self._seq_path(thread_id)
|
||||||
|
if not path.exists():
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
return int(path.read_text(encoding="utf-8").strip() or "0")
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _write_seq(self, thread_id: str, seq: int) -> None:
|
||||||
|
self._seq_path(thread_id).write_text(str(seq), encoding="utf-8")
|
||||||
|
|
||||||
|
def _thread_dir(self, thread_id: str) -> Path:
|
||||||
|
return self._base_dir / "threads" / thread_id
|
||||||
|
|
||||||
|
def _events_path(self, thread_id: str) -> Path:
|
||||||
|
return self._thread_dir(thread_id) / "events.jsonl"
|
||||||
|
|
||||||
|
def _seq_path(self, thread_id: str) -> Path:
|
||||||
|
return self._thread_dir(thread_id) / "seq"
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""Storage-facing adapters owned by the app layer."""
|
||||||
|
|
||||||
|
from .run_events import AppRunEventStore
|
||||||
|
from .runs import FeedbackStoreAdapter, RunStoreAdapter, StorageRunObserver
|
||||||
|
from .thread_meta import ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppRunEventStore",
|
||||||
|
"FeedbackStoreAdapter",
|
||||||
|
"RunStoreAdapter",
|
||||||
|
"StorageRunObserver",
|
||||||
|
"ThreadMetaStorage",
|
||||||
|
"ThreadMetaStoreAdapter",
|
||||||
|
]
|
||||||
@@ -0,0 +1,166 @@
|
|||||||
|
"""App-owned adapter from runs callbacks to storage run event repository."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from store.repositories import RunEvent, RunEventCreate, build_run_event_repository, build_thread_meta_repository
|
||||||
|
|
||||||
|
from deerflow.runtime.actor_context import get_actor_context
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunEventStore:
    """Implements the harness RunEventStore protocol using storage repositories.

    Every public method first checks thread visibility for the current actor
    (see ``_thread_visible``). Reads and deletes on a hidden thread return an
    empty result; ``put_batch`` raises instead.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Append ``events`` in one transaction and return the stored rows as dicts.

        Raises:
            PermissionError: If the current actor may not see one of the
                target threads; nothing is written in that case.
        """
        if not events:
            return []

        # Check each distinct thread once: a batch usually targets a single
        # thread and every check opens its own session.
        thread_ids = sorted({str(event["thread_id"]) for event in events})
        denied = [thread_id for thread_id in thread_ids if not await self._thread_visible(thread_id)]
        if denied:
            raise PermissionError(f"actor is not allowed to append events for thread(s): {', '.join(denied)}")

        async with self._session_factory() as session:
            try:
                repo = build_run_event_repository(session)
                rows = await repo.append_batch([_event_create_from_dict(event) for event in events])
                await session.commit()
            except Exception:
                await session.rollback()
                raise

        return [_event_to_dict(row) for row in rows]

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[dict[str, Any]]:
        """Return message events of a thread; ``[]`` when the thread is hidden from the actor."""
        if not await self._thread_visible(thread_id):
            return []
        async with self._session_factory() as session:
            repo = build_run_event_repository(session)
            rows = await repo.list_messages(
                thread_id,
                limit=limit,
                before_seq=before_seq,
                after_seq=after_seq,
            )
            return [_event_to_dict(row) for row in rows]

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
    ) -> list[dict[str, Any]]:
        """Return events of one run, optionally filtered by ``event_types``."""
        if not await self._thread_visible(thread_id):
            return []
        async with self._session_factory() as session:
            repo = build_run_event_repository(session)
            rows = await repo.list_events(thread_id, run_id, event_types=event_types, limit=limit)
            return [_event_to_dict(row) for row in rows]

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[dict[str, Any]]:
        """Return message events of one run; ``[]`` when the thread is hidden."""
        if not await self._thread_visible(thread_id):
            return []
        async with self._session_factory() as session:
            repo = build_run_event_repository(session)
            rows = await repo.list_messages_by_run(
                thread_id,
                run_id,
                limit=limit,
                before_seq=before_seq,
                after_seq=after_seq,
            )
            return [_event_to_dict(row) for row in rows]

    async def count_messages(self, thread_id: str) -> int:
        """Return the repository's message count; 0 when the thread is hidden."""
        if not await self._thread_visible(thread_id):
            return 0
        async with self._session_factory() as session:
            repo = build_run_event_repository(session)
            return await repo.count_messages(thread_id)

    async def delete_by_thread(self, thread_id: str) -> int:
        """Delete all events of a thread and return the count; 0 when hidden."""
        if not await self._thread_visible(thread_id):
            return 0
        async with self._session_factory() as session:
            try:
                repo = build_run_event_repository(session)
                count = await repo.delete_by_thread(thread_id)
                await session.commit()
            except Exception:
                await session.rollback()
                raise
            return count

    async def delete_by_run(self, thread_id: str, run_id: str) -> int:
        """Delete all events of one run and return the count; 0 when hidden."""
        if not await self._thread_visible(thread_id):
            return 0
        async with self._session_factory() as session:
            try:
                repo = build_run_event_repository(session)
                count = await repo.delete_by_run(thread_id, run_id)
                await session.commit()
            except Exception:
                await session.rollback()
                raise
            return count

    async def _thread_visible(self, thread_id: str) -> bool:
        """Tell whether the current actor may touch ``thread_id``.

        Visible when there is no actor or the actor has no ``user_id``, when
        the thread has no metadata row, when the thread has no owner, or when
        the owner equals the actor's ``user_id``.
        """
        actor = get_actor_context()
        if actor is None or actor.user_id is None:
            return True

        async with self._session_factory() as session:
            thread_repo = build_thread_meta_repository(session)
            thread = await thread_repo.get_thread_meta(thread_id)

        if thread is None:
            return True
        return thread.user_id is None or thread.user_id == actor.user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _event_create_from_dict(event: dict[str, Any]) -> RunEventCreate:
    """Translate a harness event dict into a ``RunEventCreate`` payload.

    ``thread_id``, ``run_id``, ``event_type`` and ``category`` are required
    and stringified. ``content`` defaults to ``""``, ``metadata`` to ``{}``.
    A string ``created_at`` is parsed as ISO-8601; any other value is
    forwarded unchanged.
    """
    raw_created_at = event.get("created_at")
    fields: dict[str, Any] = {
        "thread_id": str(event["thread_id"]),
        "run_id": str(event["run_id"]),
        "event_type": str(event["event_type"]),
        "category": str(event["category"]),
        "content": event.get("content", ""),
        "metadata": dict(event.get("metadata") or {}),
    }
    if isinstance(raw_created_at, str):
        fields["created_at"] = datetime.fromisoformat(raw_created_at)
    else:
        fields["created_at"] = raw_created_at
    return RunEventCreate(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def _event_to_dict(event: RunEvent) -> dict[str, Any]:
    """Flatten a stored ``RunEvent`` into the dict shape the harness consumes.

    Attributes are copied verbatim except ``created_at``, which is emitted as
    an ISO-8601 string.
    """
    copied_attributes = ("thread_id", "run_id", "event_type", "category", "content", "metadata", "seq")
    payload: dict[str, Any] = {name: getattr(event, name) for name in copied_attributes}
    payload["created_at"] = event.created_at.isoformat()
    return payload
|
||||||
@@ -0,0 +1,515 @@
|
|||||||
|
"""Run lifecycle persistence adapters owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Protocol, TypedDict, Unpack, cast
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from store.repositories import FeedbackCreate, Run, RunCreate, build_feedback_repository, build_run_repository
|
||||||
|
|
||||||
|
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||||
|
from deerflow.runtime.serialization import serialize_lc_object
|
||||||
|
from deerflow.runtime.runs.observer import LifecycleEventType, RunLifecycleEvent, RunObserver
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue
|
||||||
|
|
||||||
|
from .thread_meta import ThreadMetaStorage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RunCreateFields(TypedDict, total=False):
    """Optional keyword arguments accepted by ``RunWriteRepository.create``."""

    # Lifecycle state and timestamps, all carried as strings.
    status: str
    created_at: str
    started_at: str
    ended_at: str
    # Identity and linkage; each may be explicitly None.
    assistant_id: str | None
    user_id: str | None
    follow_up_to_run_id: str | None
    # Free-form JSON payloads stored with the run.
    metadata: dict[str, JSONValue]
    kwargs: dict[str, JSONValue]
|
||||||
|
|
||||||
|
|
||||||
|
class RunStatusUpdateFields(TypedDict, total=False):
    """Optional keyword arguments accepted by ``RunWriteRepository.update_status``."""

    # Timestamps carried as strings.
    started_at: str
    ended_at: str
    # Free-form JSON metadata sent along with the status change.
    metadata: dict[str, JSONValue]
|
||||||
|
|
||||||
|
|
||||||
|
class RunCompletionFields(TypedDict, total=False):
    """Optional keyword arguments accepted by ``RunWriteRepository.update_run_completion``."""

    # Token accounting for the whole run.
    total_input_tokens: int
    total_output_tokens: int
    total_tokens: int
    llm_call_count: int
    # Token counts attributed to individual parts of the run.
    lead_agent_tokens: int
    subagent_tokens: int
    middleware_tokens: int
    # Conversation summary values.
    message_count: int
    last_ai_message: str | None
    first_human_message: str | None
    # Error text; may be explicitly None.
    error: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class RunRow(TypedDict, total=False):
    """Dict shape of a durable run row as returned by the read repository.

    All keys are optional (``total=False``). Timestamps are ISO-8601 strings.
    """

    run_id: str
    thread_id: str
    assistant_id: str | None
    status: str
    multitask_strategy: str
    follow_up_to_run_id: str | None
    metadata: dict[str, JSONValue]
    created_at: str
    updated_at: str
    started_at: str | None
    ended_at: str | None
    error: str | None
    # Keys below are produced by ``_run_to_row`` but were missing from this
    # declaration, so the literal it builds did not type-check as a RunRow.
    # NOTE(review): ``_run_to_row`` may emit further keys that belong here too.
    user_id: str | None
    # NOTE(review): assumed to be a nullable string -- confirm against the Run model.
    model_name: str | None
    kwargs: dict[str, JSONValue]
    total_input_tokens: int
    total_output_tokens: int
    total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class RunReadRepository(Protocol):
    """Structural interface for reading durable run rows.

    ``user_id`` defaults to the ``AUTO`` sentinel.
    NOTE(review): presumably AUTO means "resolve from the current actor
    context" -- confirm in deerflow.runtime.actor_context.
    """

    async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
        """Fetch the row for ``run_id``; ``None`` is a permitted result."""
        ...

    async def list_by_thread(self, thread_id: str, *, limit: int = 100, user_id: str | None | object = AUTO) -> list[RunRow]:
        """List rows belonging to ``thread_id``, bounded by ``limit``."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class RunWriteRepository(Protocol):
|
||||||
|
"""Protocol for durable run writes."""
|
||||||
|
|
||||||
|
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None: ...
|
||||||
|
async def update_status(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
status: str,
|
||||||
|
**kwargs: Unpack[RunStatusUpdateFields],
|
||||||
|
) -> None: ...
|
||||||
|
async def set_error(self, run_id: str, error: str) -> None: ...
|
||||||
|
async def update_run_completion(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
**kwargs: Unpack[RunCompletionFields],
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class RunDeleteRepository(Protocol):
    """Protocol for durable run deletion."""

    async def delete(
        self,
        run_id: str,
    ) -> bool:
        """Delete ``run_id`` and report the outcome as a boolean."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class _RepositoryContext:
    """Async context manager that scopes one session to one repository use.

    Entering opens a session from ``session_factory`` and returns the
    repository that ``build_repo`` constructs on top of it.  Exiting always
    closes the session; when ``commit`` is true it first commits on a clean
    exit or rolls back when the body raised.
    """

    def __init__(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        build_repo: Callable[[AsyncSession], object],
        *,
        commit: bool,
    ) -> None:
        self._session_factory = session_factory
        self._build_repo = build_repo
        self._commit = commit
        self._session: AsyncSession | None = None

    async def __aenter__(self):
        """Open a session and return the repository bound to it."""
        session = self._session_factory()
        try:
            repo = self._build_repo(session)
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so the session
            # has to be released here or it would never be closed.
            await session.close()
            raise
        self._session = session
        return repo

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Commit or roll back (write contexts only), then close the session."""
        # Detach first so a repeated exit cannot act on a closed session.
        session, self._session = self._session, None
        if session is None:
            return
        try:
            if self._commit:
                if exc_type is None:
                    await session.commit()
                else:
                    await session.rollback()
        finally:
            await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_to_row(row: Run) -> RunRow:
    """Flatten a stored ``Run`` record into the plain ``RunRow`` mapping.

    Timestamps are rendered with ``isoformat()``; a missing ``updated_time``
    becomes an empty string.
    """
    last_update = row.updated_time

    identity = {
        "run_id": row.run_id,
        "thread_id": row.thread_id,
        "assistant_id": row.assistant_id,
        "user_id": row.user_id,
        "status": row.status,
        "model_name": row.model_name,
        "multitask_strategy": row.multitask_strategy,
        "follow_up_to_run_id": row.follow_up_to_run_id,
    }
    payloads = {
        "metadata": cast(dict[str, JSONValue], row.metadata),
        "kwargs": cast(dict[str, JSONValue], row.kwargs),
    }
    timestamps = {
        "created_at": row.created_time.isoformat(),
        "updated_at": last_update.isoformat() if last_update else "",
    }
    # Remaining columns are copied verbatim under their own names.
    passthrough = {
        column: getattr(row, column)
        for column in (
            "total_input_tokens",
            "total_output_tokens",
            "total_tokens",
            "llm_call_count",
            "lead_agent_tokens",
            "subagent_tokens",
            "middleware_tokens",
            "message_count",
            "first_human_message",
            "last_ai_message",
            "error",
        )
    }
    return cast(RunRow, {**identity, **payloads, **timestamps, **passthrough})
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackStoreAdapter:
    """Expose feedback route semantics on top of storage package repositories.

    Each public method opens its own repository context: ``_read()`` for
    queries and ``_transaction()`` for writes.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    @staticmethod
    def _check_rating(rating: int) -> None:
        """Raise ``ValueError`` unless ``rating`` is ``1`` or ``-1``."""
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")

    async def create(
        self,
        *,
        run_id: str,
        thread_id: str,
        rating: int,
        owner_id: str | None = None,
        user_id: str | None = None,
        message_id: str | None = None,
        comment: str | None = None,
    ) -> dict[str, object]:
        """Store a new feedback entry under a fresh UUID and return it as a dict.

        ``user_id`` wins over ``owner_id``; ``owner_id`` is used only when
        ``user_id`` is ``None``.

        Raises:
            ValueError: if ``rating`` is neither ``1`` nor ``-1``.
        """
        self._check_rating(rating)
        effective_user_id = user_id if user_id is not None else owner_id
        async with self._transaction() as repo:
            row = await repo.create_feedback(
                FeedbackCreate(
                    feedback_id=str(uuid.uuid4()),
                    run_id=run_id,
                    thread_id=thread_id,
                    rating=rating,
                    user_id=effective_user_id,
                    message_id=message_id,
                    comment=comment,
                )
            )
        return _feedback_to_dict(row)

    async def get(self, feedback_id: str) -> dict[str, object] | None:
        """Return the feedback entry with ``feedback_id``, or ``None`` if absent."""
        async with self._read() as repo:
            row = await repo.get_feedback(feedback_id)
        return _feedback_to_dict(row) if row is not None else None

    async def list_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 100,
        user_id: str | None = None,
    ) -> list[dict[str, object]]:
        """List feedback for ``run_id`` within ``thread_id``, optionally for one user.

        Filtering by thread and user is applied before the result is cut to
        the first ``limit`` entries.
        """
        async with self._read() as repo:
            rows = await repo.list_feedback_by_run(run_id)
        filtered = [row for row in rows if row.thread_id == thread_id]
        if user_id is not None:
            filtered = [row for row in filtered if row.user_id == user_id]
        return [_feedback_to_dict(row) for row in filtered][:limit]

    async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict[str, object]]:
        """List the first ``limit`` feedback entries recorded for ``thread_id``."""
        async with self._read() as repo:
            rows = await repo.list_feedback_by_thread(thread_id)
        return [_feedback_to_dict(row) for row in rows][:limit]

    async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict[str, object]:
        """Count total, positive and negative feedback for one run of a thread."""
        # Query the repository directly: going through list_by_run() would cap
        # the counts at its default limit of 100 rows.
        async with self._read() as repo:
            rows = await repo.list_feedback_by_run(run_id)
        ratings = [row.rating for row in rows if row.thread_id == thread_id]
        positive = sum(1 for rating in ratings if rating == 1)
        negative = sum(1 for rating in ratings if rating == -1)
        return {"run_id": run_id, "total": len(ratings), "positive": positive, "negative": negative}

    async def delete(self, feedback_id: str) -> bool:
        """Delete one feedback entry and return the repository's boolean result."""
        async with self._transaction() as repo:
            return await repo.delete_feedback(feedback_id)

    async def upsert(
        self,
        *,
        run_id: str,
        thread_id: str,
        rating: int,
        user_id: str,
        comment: str | None = None,
    ) -> dict[str, object]:
        """Create or replace the feedback ``user_id`` left on a run.

        An existing entry for the same run, thread and user is deleted and
        re-created under the same ``feedback_id`` within one transaction.

        Raises:
            ValueError: if ``rating`` is neither ``1`` nor ``-1``.
        """
        self._check_rating(rating)
        async with self._transaction() as repo:
            rows = await repo.list_feedback_by_run(run_id)
            existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
            feedback_id = existing.feedback_id if existing is not None else str(uuid.uuid4())
            if existing is not None:
                await repo.delete_feedback(existing.feedback_id)
            row = await repo.create_feedback(
                FeedbackCreate(
                    feedback_id=feedback_id,
                    run_id=run_id,
                    thread_id=thread_id,
                    rating=rating,
                    user_id=user_id,
                    comment=comment,
                )
            )
        return _feedback_to_dict(row)

    async def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str) -> bool:
        """Delete the feedback ``user_id`` left on a run; ``False`` when none exists."""
        async with self._transaction() as repo:
            rows = await repo.list_feedback_by_run(run_id)
            existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
            if existing is None:
                return False
            return await repo.delete_feedback(existing.feedback_id)

    async def list_by_thread_grouped(self, thread_id: str, *, user_id: str) -> dict[str, dict[str, object]]:
        """Map ``run_id`` to the feedback ``user_id`` left in ``thread_id``.

        When several entries share a run id, the last one returned by the
        repository wins.
        """
        # Filter by user before any truncation: list_by_thread() would keep
        # only the thread's first 100 rows and could hide this user's entries.
        async with self._read() as repo:
            rows = await repo.list_feedback_by_thread(thread_id)
        return {row.run_id: _feedback_to_dict(row) for row in rows if row.user_id == user_id}

    def _read(self) -> _RepositoryContext:
        """Return a repository context that does not commit on exit."""
        return _RepositoryContext(self._session_factory, build_feedback_repository, commit=False)

    def _transaction(self) -> _RepositoryContext:
        """Return a repository context that commits on a clean exit."""
        return _RepositoryContext(self._session_factory, build_feedback_repository, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _feedback_to_dict(row) -> dict[str, object]:
    """Convert a feedback record into the dict shape used by the adapter.

    ``owner_id`` duplicates ``user_id``; ``created_at`` is the record's
    ``created_time`` rendered with ``isoformat()``.
    """
    record: dict[str, object] = {
        name: getattr(row, name) for name in ("feedback_id", "run_id", "thread_id", "user_id")
    }
    # Both keys carry the same value.
    record["owner_id"] = record["user_id"]
    record.update(message_id=row.message_id, rating=row.rating, comment=row.comment)
    record["created_at"] = row.created_time.isoformat()
    return record
|
||||||
|
|
||||||
|
|
||||||
|
class RunStoreAdapter:
    """Expose runs facade storage semantics on top of storage package repositories.

    Reads use a non-committing repository context and writes a committing
    one; both are created per call from the injected session factory.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
        """Return the run as a ``RunRow``, or ``None`` if missing or owned by another user."""
        owner = resolve_user_id(user_id, method_name="RunStoreAdapter.get")
        async with self._read() as repo:
            record = await repo.get_run(run_id)
        if record is None:
            return None
        owner_mismatch = owner is not None and record.user_id != owner
        return None if owner_mismatch else _run_to_row(record)

    async def list_by_thread(
        self,
        thread_id: str,
        *,
        limit: int = 100,
        user_id: str | None | object = AUTO,
    ) -> list[RunRow]:
        """List runs of ``thread_id``; the user filter is applied after the limited fetch."""
        owner = resolve_user_id(user_id, method_name="RunStoreAdapter.list_by_thread")
        async with self._read() as repo:
            records = await repo.list_runs_by_thread(thread_id, limit=limit, offset=0)
        return [
            _run_to_row(record)
            for record in records
            if owner is None or record.user_id == owner
        ]

    async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None:
        """Insert a new run; ``status`` defaults to ``"pending"`` when not supplied."""
        # metadata / kwargs pass through serialize_lc_object before storage.
        serialized_metadata = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("metadata") or {}))
        serialized_kwargs = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("kwargs") or {}))
        owner = resolve_user_id(kwargs.get("user_id", AUTO), method_name="RunStoreAdapter.create")
        async with self._transaction() as repo:
            new_run = RunCreate(
                run_id=run_id,
                thread_id=thread_id,
                assistant_id=kwargs.get("assistant_id"),
                user_id=owner,
                status=kwargs.get("status", "pending"),
                metadata=dict(serialized_metadata),
                kwargs=dict(serialized_kwargs),
                follow_up_to_run_id=kwargs.get("follow_up_to_run_id"),
            )
            await repo.create_run(new_run)

    async def delete(self, run_id: str, *, user_id: str | None | object = AUTO) -> bool:
        """Delete the run; ``False`` when it is missing or owned by another user."""
        async with self._transaction() as repo:
            record = await repo.get_run(run_id)
            if record is None:
                return False
            # The user id is resolved only once the run is known to exist.
            owner = resolve_user_id(user_id, method_name="RunStoreAdapter.delete")
            if owner is not None and record.user_id != owner:
                return False
            await repo.delete_run(run_id)
            return True

    async def update_status(
        self,
        run_id: str,
        status: str,
        **kwargs: Unpack[RunStatusUpdateFields],
    ) -> None:
        """Set the status of ``run_id``."""
        # NOTE(review): the extra fields in ``kwargs`` are accepted but are not
        # forwarded to the repository - confirm this is intended.
        async with self._transaction() as repo:
            await repo.update_run_status(run_id, status)

    async def set_error(self, run_id: str, error: str) -> None:
        """Mark ``run_id`` with status ``"error"`` and store the error message."""
        async with self._transaction() as repo:
            await repo.update_run_status(run_id, "error", error=error)

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        **kwargs: Unpack[RunCompletionFields],
    ) -> None:
        """Store the final status and completion statistics of ``run_id``.

        Counters missing from ``kwargs`` default to ``0`` and the message and
        error fields default to ``None``.
        """
        counter_names = (
            "total_input_tokens",
            "total_output_tokens",
            "total_tokens",
            "llm_call_count",
            "lead_agent_tokens",
            "subagent_tokens",
            "middleware_tokens",
            "message_count",
        )
        nullable_names = ("last_ai_message", "first_human_message", "error")
        fields = {name: kwargs.get(name, 0) for name in counter_names}
        fields.update({name: kwargs.get(name) for name in nullable_names})
        async with self._transaction() as repo:
            await repo.update_run_completion(run_id, status=status, **fields)

    def _read(self) -> _RepositoryContext:
        """Return a run repository context that does not commit."""
        return _RepositoryContext(self._session_factory, build_run_repository, commit=False)

    def _transaction(self) -> _RepositoryContext:
        """Return a run repository context that commits on a clean exit."""
        return _RepositoryContext(self._session_factory, build_run_repository, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRunObserver(RunObserver):
    """Persist run lifecycle state into app-owned repositories.

    Both collaborators are optional; a handler does nothing for a
    collaborator that was not supplied.
    """

    def __init__(
        self,
        run_write_repo: RunWriteRepository | None = None,
        thread_meta_storage: ThreadMetaStorage | None = None,
    ) -> None:
        self._run_write_repo = run_write_repo
        self._thread_meta_storage = thread_meta_storage

    async def on_event(self, event: RunLifecycleEvent) -> None:
        """Persist ``event``; failures are logged and never propagated."""
        try:
            await self._dispatch(event)
        except Exception:
            # Deliberate best-effort boundary: a storage failure is logged
            # with its traceback instead of reaching the event emitter.
            logger.exception(
                "StorageRunObserver failed to persist event %s for run %s",
                event.event_type,
                event.run_id,
            )

    async def _dispatch(self, event: RunLifecycleEvent) -> None:
        """Route ``event`` to its handler; unknown event types are ignored."""
        handlers = {
            LifecycleEventType.RUN_STARTED: self._handle_run_started,
            LifecycleEventType.RUN_COMPLETED: self._handle_run_completed,
            LifecycleEventType.RUN_FAILED: self._handle_run_failed,
            LifecycleEventType.RUN_CANCELLED: self._handle_run_cancelled,
            LifecycleEventType.THREAD_STATUS_UPDATED: self._handle_thread_status,
        }

        handler = handlers.get(event.event_type)
        if handler:
            await handler(event)

    async def _persist_terminal(
        self,
        event: RunLifecycleEvent,
        payload: dict,
        *,
        status: str,
        error: str | None = None,
    ) -> None:
        """Write a terminal ``status`` for the event's run.

        A dict under ``payload["completion_data"]`` is stored through
        ``update_run_completion``.  Otherwise the status is written with the
        event time as ``ended_at`` and, when ``error`` is given, followed by
        ``set_error``.
        """
        completion_data = payload.get("completion_data")
        if isinstance(completion_data, dict):
            # run_id / status are passed explicitly below; dropping them (and
            # overriding "error") avoids a "multiple values for keyword
            # argument" TypeError when completion_data carries the same keys.
            fields = {key: value for key, value in completion_data.items() if key not in ("run_id", "status")}
            if error is not None:
                fields["error"] = error
            await self._run_write_repo.update_run_completion(
                run_id=event.run_id,
                status=status,
                **cast(RunCompletionFields, fields),
            )
            return

        await self._run_write_repo.update_status(
            run_id=event.run_id,
            status=status,
            ended_at=event.occurred_at.isoformat(),
        )
        if error is not None:
            await self._run_write_repo.set_error(run_id=event.run_id, error=error)

    async def _handle_run_started(self, event: RunLifecycleEvent) -> None:
        """Mark the run as ``"running"`` with the event time as ``started_at``."""
        if self._run_write_repo:
            await self._run_write_repo.update_status(
                run_id=event.run_id,
                status="running",
                started_at=event.occurred_at.isoformat(),
            )

    async def _handle_run_completed(self, event: RunLifecycleEvent) -> None:
        """Store the ``"success"`` outcome and sync the thread title if present."""
        payload = dict(event.payload) if event.payload else {}
        if self._run_write_repo:
            await self._persist_terminal(event, payload, status="success")

        if self._thread_meta_storage and "title" in payload:
            await self._thread_meta_storage.sync_thread_title(
                thread_id=event.thread_id,
                title=payload["title"],
            )

    async def _handle_run_failed(self, event: RunLifecycleEvent) -> None:
        """Store the ``"error"`` outcome with the payload's error message."""
        if self._run_write_repo:
            payload = dict(event.payload) if event.payload else {}
            error = payload.get("error", "Unknown error")
            await self._persist_terminal(event, payload, status="error", error=str(error))

    async def _handle_run_cancelled(self, event: RunLifecycleEvent) -> None:
        """Store the ``"interrupted"`` outcome of a cancelled run."""
        if self._run_write_repo:
            payload = dict(event.payload) if event.payload else {}
            await self._persist_terminal(event, payload, status="interrupted")

    async def _handle_thread_status(self, event: RunLifecycleEvent) -> None:
        """Sync the thread status from the payload, defaulting to ``"idle"``."""
        if self._thread_meta_storage:
            payload = dict(event.payload) if event.payload else {}
            status = payload.get("status", "idle")
            await self._thread_meta_storage.sync_thread_status(
                thread_id=event.thread_id,
                status=status,
            )
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
"""Thread metadata storage adapter owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from store.repositories import build_thread_meta_repository
|
||||||
|
from store.repositories.contracts import (
|
||||||
|
ThreadMeta,
|
||||||
|
ThreadMetaCreate,
|
||||||
|
ThreadMetaRepositoryProtocol,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMetaStoreAdapter:
    """Use storage package thread repositories with per-call sessions.

    Every method wraps exactly one repository call in its own context:
    committing for writes, non-committing for reads.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        """Create a thread metadata record and return what the repository produced."""
        async with self._transaction() as repo:
            created = await repo.create_thread_meta(data)
            return created

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        """Return the repository's lookup result for ``thread_id``."""
        async with self._read() as repo:
            found = await repo.get_thread_meta(thread_id)
            return found

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        assistant_id: str | None = None,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Forward the given fields to the repository's ``update_thread_meta``."""
        changes = {
            "assistant_id": assistant_id,
            "display_name": display_name,
            "status": status,
            "metadata": metadata,
        }
        async with self._transaction() as repo:
            await repo.update_thread_meta(thread_id, **changes)

    async def delete_thread(self, thread_id: str) -> None:
        """Delete ``thread_id`` through the repository."""
        async with self._transaction() as repo:
            await repo.delete_thread(thread_id)

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        """Search threads, passing every filter through to the repository unchanged."""
        criteria = {
            "metadata": metadata,
            "status": status,
            "user_id": user_id,
            "assistant_id": assistant_id,
            "limit": limit,
            "offset": offset,
        }
        async with self._read() as repo:
            matches = await repo.search_threads(**criteria)
            return matches

    def _read(self):
        """Return a non-committing repository context."""
        return _ThreadMetaRepositoryContext(self._session_factory, commit=False)

    def _transaction(self):
        """Return a repository context that commits on a clean exit."""
        return _ThreadMetaRepositoryContext(self._session_factory, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _ThreadMetaRepositoryContext:
    """Async context manager pairing one session with a thread-meta repository.

    On exit the session is always closed.  When ``commit`` is true it is
    committed after a clean body and rolled back after a failing one.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, commit: bool) -> None:
        self._session_factory = session_factory
        self._commit = commit
        self._session: AsyncSession | None = None

    async def __aenter__(self):
        """Open a session and return the repository built on it."""
        self._session = self._session_factory()
        return build_thread_meta_repository(self._session)

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Finish the unit of work, then close the session."""
        session = self._session
        if session is None:
            return
        try:
            if not self._commit:
                # Nothing to finalize; the ``finally`` clause still closes.
                return
            finalize = session.commit if exc_type is None else session.rollback
            await finalize()
        finally:
            await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMetaStorage:
    """App-facing adapter around the storage thread metadata contract.

    Wraps a ``ThreadMetaRepositoryProtocol`` implementation, adding
    user-scoped lookups and small single-field sync helpers.
    """

    def __init__(self, repo: ThreadMetaRepositoryProtocol) -> None:
        self._repo = repo

    @staticmethod
    def _strip_or_none(value: str | None) -> str | None:
        """Strip ``value``; ``None`` and blank results both become ``None``."""
        if value is None:
            return None
        return value.strip() or None

    async def get_thread(self, thread_id: str, *, user_id: str | None | object = AUTO) -> ThreadMeta | None:
        """Return the thread, or ``None`` if missing or owned by another user."""
        record = await self._repo.get_thread_meta(thread_id)
        if record is None:
            return None
        # The user id is resolved only after the record was found.
        owner = resolve_user_id(user_id, method_name="ThreadMetaStorage.get_thread")
        visible = owner is None or record.user_id == owner
        return record if visible else None

    async def ensure_thread(
        self,
        *,
        thread_id: str,
        assistant_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        user_id: str | None | object = AUTO,
    ) -> ThreadMeta:
        """Return the user's existing thread, creating it when none is visible."""
        owner = resolve_user_id(user_id, method_name="ThreadMetaStorage.ensure_thread")
        current = await self.get_thread(thread_id, user_id=owner)
        if current is not None:
            return current

        blueprint = ThreadMetaCreate(
            thread_id=thread_id,
            assistant_id=assistant_id,
            user_id=owner,
            metadata=metadata or {},
        )
        return await self._repo.create_thread_meta(blueprint)

    async def ensure_thread_running(
        self,
        *,
        thread_id: str,
        assistant_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> ThreadMeta | None:
        """Put the thread into status ``"running"``, creating it if necessary.

        An existing thread only has its status updated and is then re-read.
        """
        current = await self._repo.get_thread_meta(thread_id)
        if current is not None:
            await self._repo.update_thread_meta(thread_id, status="running")
            return await self._repo.get_thread_meta(thread_id)

        blueprint = ThreadMetaCreate(
            thread_id=thread_id,
            assistant_id=assistant_id,
            status="running",
            metadata=metadata or {},
        )
        return await self._repo.create_thread_meta(blueprint)

    async def sync_thread_title(self, *, thread_id: str, title: str) -> None:
        """Store ``title`` as the thread's ``display_name``."""
        await self._repo.update_thread_meta(thread_id, display_name=title)

    async def sync_thread_assistant_id(self, *, thread_id: str, assistant_id: str) -> None:
        """Store the thread's ``assistant_id``."""
        await self._repo.update_thread_meta(thread_id, assistant_id=assistant_id)

    async def sync_thread_status(self, *, thread_id: str, status: str) -> None:
        """Store the thread's ``status``."""
        await self._repo.update_thread_meta(thread_id, status=status)

    async def sync_thread_metadata(
        self,
        *,
        thread_id: str,
        metadata: dict[str, Any],
    ) -> None:
        """Pass ``metadata`` to the repository's update for this thread."""
        await self._repo.update_thread_meta(thread_id, metadata=metadata)

    async def delete_thread(self, thread_id: str) -> None:
        """Delete the thread through the repository."""
        await self._repo.delete_thread(thread_id)

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None | object = AUTO,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        """Search threads after normalizing the string filters.

        ``status``, the resolved user id and ``assistant_id`` are stripped,
        and blank values are passed to the repository as ``None``.
        """
        status_filter = self._strip_or_none(status)
        owner = resolve_user_id(user_id, method_name="ThreadMetaStorage.search_threads")
        owner_filter = self._strip_or_none(owner)
        assistant_filter = self._strip_or_none(assistant_id)

        return await self._repo.search_threads(
            metadata=metadata,
            status=status_filter,
            user_id=owner_filter,
            assistant_id=assistant_filter,
            limit=limit,
            offset=offset,
        )
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ThreadMetaStorage", "ThreadMetaStoreAdapter"]
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""App-owned stream bridge adapters and factory."""
|
||||||
|
|
||||||
|
from .factory import build_stream_bridge
|
||||||
|
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||||
|
|
||||||
|
__all__ = ["MemoryStreamBridge", "RedisStreamBridge", "build_stream_bridge"]
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Concrete stream bridge adapters owned by the app layer."""
|
||||||
|
|
||||||
|
from .memory import MemoryStreamBridge
|
||||||
|
from .redis import RedisStreamBridge
|
||||||
|
|
||||||
|
__all__ = ["MemoryStreamBridge", "RedisStreamBridge"]
|
||||||
@@ -0,0 +1,450 @@
|
|||||||
|
"""In-memory stream bridge implementation owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from deerflow.runtime.stream_bridge import (
|
||||||
|
CANCELLED_SENTINEL,
|
||||||
|
END_SENTINEL,
|
||||||
|
HEARTBEAT_SENTINEL,
|
||||||
|
TERMINAL_STATES,
|
||||||
|
ResumeResult,
|
||||||
|
StreamBridge,
|
||||||
|
StreamEvent,
|
||||||
|
StreamStatus,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.stream_bridge.exceptions import (
|
||||||
|
BridgeClosedError,
|
||||||
|
StreamCapacityExceededError,
|
||||||
|
StreamTerminatedError,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _RunStream:
    # Per-run state kept by the in-memory stream bridge.

    # Guards the fields below; publishers call notify_all() on it.
    condition: asyncio.Condition = field(default_factory=asyncio.Condition)

    # Retained events plus the lookup from event id to absolute offset.
    events: list[StreamEvent] = field(default_factory=list)
    id_to_offset: dict[str, int] = field(default_factory=dict)
    # Incremented once per event evicted from the front of ``events``.
    start_offset: int = 0
    # Running total of the estimated sizes of the retained events.
    current_bytes: int = 0
    # Per-stream counter used when building event ids.
    seq: int = 0

    # Lifecycle state and timing.
    status: StreamStatus = StreamStatus.ACTIVE
    created_at: float = field(default_factory=time.monotonic)
    last_publish_at: float | None = None
    ended_at: float | None = None

    # NOTE(review): the remaining fields are not touched by the code visible
    # here; presumably subscriber and human-in-the-loop bookkeeping - confirm.
    subscriber_count: int = 0
    last_subscribe_at: float | None = None
    awaiting_input: bool = False
    awaiting_since: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStreamBridge(StreamBridge):
|
||||||
|
"""Per-run in-memory event log implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
max_events_per_stream: int = 256,
|
||||||
|
max_bytes_per_stream: int = 10 * 1024 * 1024,
|
||||||
|
max_active_streams: int = 1000,
|
||||||
|
stream_eviction_policy: Literal["reject", "lru"] = "lru",
|
||||||
|
terminal_retention_ttl: float = 300.0,
|
||||||
|
active_no_publish_timeout: float = 600.0,
|
||||||
|
orphan_timeout: float = 60.0,
|
||||||
|
max_stream_age: float = 86400.0,
|
||||||
|
hitl_extended_timeout: float = 7200.0,
|
||||||
|
cleanup_interval: float = 30.0,
|
||||||
|
queue_maxsize: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
if queue_maxsize is not None:
|
||||||
|
max_events_per_stream = queue_maxsize
|
||||||
|
|
||||||
|
self._max_events = max_events_per_stream
|
||||||
|
self._max_bytes = max_bytes_per_stream
|
||||||
|
self._max_streams = max_active_streams
|
||||||
|
self._eviction_policy = stream_eviction_policy
|
||||||
|
self._terminal_ttl = terminal_retention_ttl
|
||||||
|
self._active_timeout = active_no_publish_timeout
|
||||||
|
self._orphan_timeout = orphan_timeout
|
||||||
|
self._max_age = max_stream_age
|
||||||
|
self._hitl_timeout = hitl_extended_timeout
|
||||||
|
self._cleanup_interval = cleanup_interval
|
||||||
|
self._streams: dict[str, _RunStream] = {}
|
||||||
|
self._registry_lock = asyncio.Lock()
|
||||||
|
self._closed = False
|
||||||
|
self._cleanup_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self._cleanup_task is None:
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info(
|
||||||
|
"MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)",
|
||||||
|
self._max_events,
|
||||||
|
self._max_bytes,
|
||||||
|
self._max_streams,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
async with self._registry_lock:
|
||||||
|
self._closed = True
|
||||||
|
if self._cleanup_task is not None:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
for stream in self._streams.values():
|
||||||
|
async with stream.condition:
|
||||||
|
stream.status = StreamStatus.CLOSED
|
||||||
|
stream.condition.notify_all()
|
||||||
|
|
||||||
|
self._streams.clear()
|
||||||
|
logger.info("MemoryStreamBridge closed")
|
||||||
|
|
||||||
|
async def _get_or_create_stream(self, run_id: str) -> _RunStream:
    """Return the stream registered for ``run_id``, creating it on first use.

    Raises:
        BridgeClosedError: if the bridge is closed and no stream is cached.
        StreamCapacityExceededError: if the registry is full and the policy
            is ``"reject"`` or no terminal stream can be evicted.
    """
    # Fast path: no lock needed when the stream is already registered.
    cached = self._streams.get(run_id)
    if cached is not None:
        return cached

    async with self._registry_lock:
        if self._closed:
            raise BridgeClosedError("Stream bridge is closed")

        # Re-check under the lock: another task may have created it meanwhile.
        raced = self._streams.get(run_id)
        if raced is not None:
            return raced

        registry_full = len(self._streams) >= self._max_streams
        if registry_full and self._eviction_policy == "reject":
            raise StreamCapacityExceededError(
                f"Max {self._max_streams} active streams reached"
            )
        if registry_full:
            removed = self._evict_oldest_terminal()
            if removed is None:
                raise StreamCapacityExceededError("All streams active, cannot evict")
            logger.info("Evicted stream %s to make room", removed)

        created = _RunStream()
        self._streams[run_id] = created
        logger.debug("Created stream for run %s", run_id)
        return created
|
||||||
|
|
||||||
|
def _evict_oldest_terminal(self) -> str | None:
    """Remove the terminal stream with the smallest ``ended_at``.

    Returns the evicted run id, or ``None`` when no stream is both in a
    terminal state and has an ``ended_at`` timestamp.
    """
    victim: str | None = None
    victim_ended_at: float = float("inf")
    for candidate_id, candidate in self._streams.items():
        if candidate.status not in TERMINAL_STATES:
            continue
        ended = candidate.ended_at
        # Keep the first strictly-older candidate; ties keep the earlier one.
        if ended is None or not ended < victim_ended_at:
            continue
        victim, victim_ended_at = candidate_id, ended

    if victim is None:
        return None
    del self._streams[victim]
    return victim
|
||||||
|
|
||||||
|
def _next_id(self, stream: _RunStream) -> str:
|
||||||
|
stream.seq += 1
|
||||||
|
return f"{int(time.time() * 1000)}-{stream.seq}"
|
||||||
|
|
||||||
|
def _estimate_size(self, event: StreamEvent) -> int:
|
||||||
|
base = len(event.id) + len(event.event) + 100
|
||||||
|
if event.data is None:
|
||||||
|
return base
|
||||||
|
if isinstance(event.data, str):
|
||||||
|
return base + len(event.data)
|
||||||
|
if isinstance(event.data, (dict, list)):
|
||||||
|
try:
|
||||||
|
return base + len(json.dumps(event.data, default=str))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return base + 200
|
||||||
|
return base + 50
|
||||||
|
|
||||||
|
def _evict_overflow(self, stream: _RunStream) -> None:
|
||||||
|
while len(stream.events) > self._max_events or stream.current_bytes > self._max_bytes:
|
||||||
|
if not stream.events:
|
||||||
|
break
|
||||||
|
evicted = stream.events.pop(0)
|
||||||
|
stream.id_to_offset.pop(evicted.id, None)
|
||||||
|
stream.current_bytes -= self._estimate_size(evicted)
|
||||||
|
stream.start_offset += 1
|
||||||
|
|
||||||
|
async def publish(self, run_id: str, event: str, data: Any) -> str:
    """Append an event to the stream for ``run_id`` and notify its waiters.

    The stream is looked up, or created if missing, via
    ``_get_or_create_stream``; errors raised there propagate.

    Returns:
        The id assigned to the new event.

    Raises:
        StreamTerminatedError: the stream's status is not ``ACTIVE``.
    """
    target = await self._get_or_create_stream(run_id)
    async with target.condition:
        if target.status != StreamStatus.ACTIVE:
            raise StreamTerminatedError(
                f"Cannot publish to {target.status.value} stream"
            )

        record = StreamEvent(id=self._next_id(target), event=event, data=data)
        # start_offset is the absolute position of events[0], so this is the
        # absolute position the new record is about to occupy.
        target.id_to_offset[record.id] = target.start_offset + len(target.events)
        target.events.append(record)
        target.current_bytes += self._estimate_size(record)
        target.last_publish_at = time.monotonic()

        self._evict_overflow(target)
        target.condition.notify_all()
        return record.id
|
||||||
|
|
||||||
|
async def publish_end(self, run_id: str) -> str:
    """Terminate the stream for ``run_id`` with status ``ENDED``.

    Delegates to ``publish_terminal`` and returns the id it reports.
    """
    terminal_id = await self.publish_terminal(run_id, StreamStatus.ENDED)
    return terminal_id
|
||||||
|
|
||||||
|
async def publish_terminal(
    self,
    run_id: str,
    kind: StreamStatus,
    data: Any = None,
) -> str:
    """Append a terminal event to the stream for ``run_id`` and finish it.

    Args:
        run_id: Run whose stream is being terminated.
        kind: Terminal status to record; must be in ``TERMINAL_STATES``.
        data: Optional payload attached to the terminal event.

    Returns:
        The id of the terminal event. If the stream is no longer ``ACTIVE``,
        nothing is appended: the id of the most recent buffered
        end/cancel/error/dead_letter event is returned instead, or ``""``
        when none is buffered.

    Raises:
        ValueError: ``kind`` is not in ``TERMINAL_STATES``.
    """
    if kind not in TERMINAL_STATES:
        raise ValueError(f"Invalid terminal kind: {kind}")

    target = await self._get_or_create_stream(run_id)
    async with target.condition:
        if target.status != StreamStatus.ACTIVE:
            # Already finished: report the terminal event that is on record.
            return next(
                (
                    past.id
                    for past in reversed(target.events)
                    if past.event in ("end", "cancel", "error", "dead_letter")
                ),
                "",
            )

        names_by_status = {
            StreamStatus.ENDED: "end",
            StreamStatus.CANCELLED: "cancel",
            StreamStatus.ERRORED: "error",
        }
        # Resolve the name before allocating an id so a failed lookup does
        # not consume a sequence number.
        event_name = names_by_status[kind]
        record = StreamEvent(id=self._next_id(target), event=event_name, data=data)

        target.id_to_offset[record.id] = target.start_offset + len(target.events)
        target.events.append(record)
        target.current_bytes += self._estimate_size(record)

        target.status = kind
        target.ended_at = time.monotonic()
        target.awaiting_input = False
        target.condition.notify_all()

        logger.debug("Stream %s terminal: %s", run_id, kind.value)
        return record.id
|
||||||
|
|
||||||
|
async def cancel(self, run_id: str) -> None:
    """Terminate the stream for ``run_id`` with status ``CANCELLED``.

    Delegates to ``publish_terminal``; the id it reports is discarded.
    """
    await self.publish_terminal(run_id, StreamStatus.CANCELLED)
    return None
|
||||||
|
|
||||||
|
async def subscribe(
    self,
    run_id: str,
    *,
    last_event_id: str | None = None,
    heartbeat_interval: float = 15.0,
) -> AsyncIterator[StreamEvent]:
    """Yield the events of the stream for ``run_id`` as they become available.

    Iteration starts at the offset chosen by ``_resolve_resume_point`` for
    ``last_event_id``. If that offset has been evicted in the meantime, the
    cursor jumps forward to the oldest buffered event.

    The generator stops after yielding one of:

    * ``CANCELLED_SENTINEL`` - the bridge or the stream is closed;
    * an event named end / cancel / error / dead_letter;
    * ``END_SENTINEL`` - the stream is in a terminal state and nothing is
      left to replay.

    While nothing is available it waits on the stream's condition for up to
    ``heartbeat_interval`` seconds. If the wait ends with still nothing to
    deliver and the stream not terminal, ``HEARTBEAT_SENTINEL`` is yielded.

    ``stream.subscriber_count`` is incremented before iteration starts and
    decremented again, never below zero, when the generator finishes.
    """
    stream = await self._get_or_create_stream(run_id)
    cursor = self._resolve_resume_point(stream, last_event_id).next_offset

    async with stream.condition:
        stream.subscriber_count += 1
        stream.last_subscribe_at = time.monotonic()

    terminal_event_names = ("end", "cancel", "error", "dead_letter")
    try:
        while True:
            outgoing: StreamEvent | None = None
            finished = False
            waited = False

            # Decide what to do while holding the lock, but never yield
            # under it: the consumer may take arbitrarily long to resume us.
            async with stream.condition:
                if self._closed or stream.status == StreamStatus.CLOSED:
                    outgoing, finished = CANCELLED_SENTINEL, True
                elif cursor < stream.start_offset:
                    # Our position was evicted; resume at the oldest event.
                    cursor = stream.start_offset
                else:
                    index = cursor - stream.start_offset
                    if 0 <= index < len(stream.events):
                        outgoing = stream.events[index]
                        cursor += 1
                        finished = outgoing.event in terminal_event_names
                    elif stream.status in TERMINAL_STATES:
                        outgoing, finished = END_SENTINEL, True
                    else:
                        waited = True
                        try:
                            await asyncio.wait_for(
                                stream.condition.wait(),
                                timeout=heartbeat_interval,
                            )
                        except TimeoutError:
                            pass

            if outgoing is not None:
                yield outgoing
                if finished:
                    return
                continue

            if not waited:
                continue

            # The wait is over; emit a heartbeat only if there is still
            # nothing to deliver and the stream has not finished.
            async with stream.condition:
                index = cursor - stream.start_offset
                idle = (
                    not (0 <= index < len(stream.events))
                    and stream.status not in TERMINAL_STATES
                )
            if idle:
                yield HEARTBEAT_SENTINEL
    finally:
        async with stream.condition:
            stream.subscriber_count = max(0, stream.subscriber_count - 1)
|
||||||
|
|
||||||
|
async def mark_awaiting_input(self, run_id: str) -> None:
    """Flag an ``ACTIVE`` stream as waiting for input and record when.

    Unknown run ids and streams that are not ``ACTIVE`` are left untouched.
    """
    target = self._streams.get(run_id)
    if target is None:
        return

    async with target.condition:
        if target.status != StreamStatus.ACTIVE:
            return
        target.awaiting_since = time.monotonic()
        target.awaiting_input = True
        logger.debug("Stream %s marked as awaiting input", run_id)
|
||||||
|
|
||||||
|
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
    """Remove the stream for ``run_id``, optionally after a delay.

    Args:
        run_id: Run whose stream should be removed.
        delay: Seconds to sleep before removing; zero or less skips the sleep.

    The removal itself is done by ``_do_cleanup`` with the reason ``"manual"``.
    """
    wait_first = delay > 0
    if wait_first:
        await asyncio.sleep(delay)
    await self._do_cleanup(run_id, "manual")
|
||||||
|
|
||||||
|
async def _do_cleanup(self, run_id: str, reason: str) -> None:
|
||||||
|
async with self._registry_lock:
|
||||||
|
stream = self._streams.pop(run_id, None)
|
||||||
|
if stream is not None:
|
||||||
|
async with stream.condition:
|
||||||
|
stream.status = StreamStatus.CLOSED
|
||||||
|
stream.condition.notify_all()
|
||||||
|
logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason)
|
||||||
|
|
||||||
|
async def _mark_dead_letter(self, run_id: str, reason: str) -> None:
|
||||||
|
stream = self._streams.get(run_id)
|
||||||
|
if stream is None:
|
||||||
|
return
|
||||||
|
async with stream.condition:
|
||||||
|
if stream.status != StreamStatus.ACTIVE:
|
||||||
|
return
|
||||||
|
entry = StreamEvent(
|
||||||
|
id=self._next_id(stream),
|
||||||
|
event="dead_letter",
|
||||||
|
data={"reason": reason, "timestamp": time.time()},
|
||||||
|
)
|
||||||
|
absolute_offset = stream.start_offset + len(stream.events)
|
||||||
|
stream.events.append(entry)
|
||||||
|
stream.id_to_offset[entry.id] = absolute_offset
|
||||||
|
stream.current_bytes += self._estimate_size(entry)
|
||||||
|
stream.status = StreamStatus.ERRORED
|
||||||
|
stream.ended_at = time.monotonic()
|
||||||
|
stream.condition.notify_all()
|
||||||
|
logger.warning("Stream %s marked as dead letter: %s", run_id, reason)
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
while not self._closed:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self._cleanup_interval)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
to_cleanup: list[tuple[str, str]] = []
|
||||||
|
to_mark_dead: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
async with self._registry_lock:
|
||||||
|
for run_id, stream in list(self._streams.items()):
|
||||||
|
if now - stream.created_at > self._max_age:
|
||||||
|
to_cleanup.append((run_id, "max_age_exceeded"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if stream.status == StreamStatus.ACTIVE:
|
||||||
|
timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout
|
||||||
|
last_activity = stream.last_publish_at or stream.created_at
|
||||||
|
if now - last_activity > timeout:
|
||||||
|
to_mark_dead.append((run_id, "no_publish_timeout"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if stream.status in TERMINAL_STATES and stream.ended_at:
|
||||||
|
if stream.subscriber_count > 0:
|
||||||
|
continue
|
||||||
|
last_sub = stream.last_subscribe_at or stream.ended_at
|
||||||
|
if now - last_sub > self._orphan_timeout:
|
||||||
|
to_cleanup.append((run_id, "orphan"))
|
||||||
|
continue
|
||||||
|
if now - stream.ended_at > self._terminal_ttl:
|
||||||
|
to_cleanup.append((run_id, "ttl_expired"))
|
||||||
|
|
||||||
|
for run_id, reason in to_mark_dead:
|
||||||
|
await self._mark_dead_letter(run_id, reason)
|
||||||
|
for run_id, reason in to_cleanup:
|
||||||
|
await self._do_cleanup(run_id, reason)
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
    """Return a snapshot of registry-wide counters.

    The result holds the number of registered streams, how many of them are
    ``ACTIVE`` and how many are in ``TERMINAL_STATES``, the summed buffered
    event count, byte counter and subscriber count, and the bridge's closed
    flag.
    """
    active = terminal = total_events = total_bytes = total_subs = 0
    for stream in self._streams.values():
        if stream.status == StreamStatus.ACTIVE:
            active += 1
        if stream.status in TERMINAL_STATES:
            terminal += 1
        total_events += len(stream.events)
        total_bytes += stream.current_bytes
        total_subs += stream.subscriber_count

    return {
        "total_streams": len(self._streams),
        "active_streams": active,
        "terminal_streams": terminal,
        "total_events": total_events,
        "total_bytes": total_bytes,
        "total_subscribers": total_subs,
        "closed": self._closed,
    }
|
||||||
|
|
||||||
|
def _resolve_resume_point(
    self,
    stream: _RunStream,
    last_event_id: str | None,
) -> ResumeResult:
    """Translate a client's ``last_event_id`` into the offset to resume from.

    Returns a ``ResumeResult`` whose ``status`` is one of:

    * ``"fresh"`` - no id supplied; start at the oldest buffered event.
    * ``"resumed"`` - the id is still buffered; continue right after it.
    * ``"invalid"`` - the id is not of the form ``"<int>-<int>"``.
    * ``"evicted"`` - the id sorts before the oldest buffered event, so the
      event it names is gone; ``gap_count`` carries ``stream.start_offset``.
    * ``"unknown"`` - well-formed but neither buffered nor older than the
      buffer.

    Except for ``"resumed"``, ``next_offset`` is ``stream.start_offset``.

    An id counts as older when its timestamp is smaller than the oldest
    buffered event's, or when the timestamps are equal and its sequence
    number is smaller. The sequence tie-break covers events evicted in the
    same millisecond as the oldest survivor, which a timestamp-only
    comparison reports as ``"unknown"``.
    """
    if last_event_id is None:
        return ResumeResult(next_offset=stream.start_offset, status="fresh")

    known_offset = stream.id_to_offset.get(last_event_id)
    if known_offset is not None:
        return ResumeResult(next_offset=known_offset + 1, status="resumed")

    parts = last_event_id.split("-")
    if len(parts) != 2:
        return ResumeResult(next_offset=stream.start_offset, status="invalid")
    try:
        event_ts = int(parts[0])
        event_seq = int(parts[1])
    except ValueError:
        return ResumeResult(next_offset=stream.start_offset, status="invalid")

    if stream.events:
        try:
            oldest_parts = stream.events[0].id.split("-")
            oldest_ts = int(oldest_parts[0])
            # The sequence part is only consulted on a timestamp tie, so a
            # stored id without one still supports the timestamp comparison.
            is_older = event_ts < oldest_ts or (
                event_ts == oldest_ts and event_seq < int(oldest_parts[1])
            )
        except (ValueError, IndexError):
            is_older = False

        if is_older:
            return ResumeResult(
                next_offset=stream.start_offset,
                status="evicted",
                gap_count=stream.start_offset,
            )

    return ResumeResult(next_offset=stream.start_offset, status="unknown")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["MemoryStreamBridge"]
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Redis-backed stream bridge placeholder owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deerflow.runtime.stream_bridge import StreamBridge, StreamEvent
|
||||||
|
|
||||||
|
|
||||||
|
class RedisStreamBridge(StreamBridge):
    """Placeholder for the app-owned Redis stream bridge.

    Phase 1 deliberately keeps Redis out of the harness package. The real
    implementation is meant to live here once cross-process streaming is
    introduced. Until then the constructor only stores the URL and every
    bridge operation raises ``NotImplementedError``.
    """

    # One shared message for every operation that is not implemented yet.
    _PENDING_MESSAGE = "Redis stream bridge will be implemented in app infra"

    def __init__(self, *, redis_url: str) -> None:
        """Remember ``redis_url``; no connection is opened."""
        self._redis_url = redis_url

    async def publish(self, run_id: str, event: str, data: Any) -> str:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(self._PENDING_MESSAGE)

    async def publish_end(self, run_id: str) -> str:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(self._PENDING_MESSAGE)

    def subscribe(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
        heartbeat_interval: float = 15.0,
    ) -> AsyncIterator[StreamEvent]:
        """Not implemented yet; raises ``NotImplementedError`` when called."""
        raise NotImplementedError(self._PENDING_MESSAGE)

    async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(self._PENDING_MESSAGE)
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"""App-owned stream bridge factory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
|
|
||||||
|
from deerflow.config.stream_bridge_config import get_stream_bridge_config
|
||||||
|
from deerflow.runtime.stream_bridge import StreamBridge
|
||||||
|
|
||||||
|
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def build_stream_bridge(config=None) -> AbstractAsyncContextManager[StreamBridge]:
    """Return an async context manager for the configured stream bridge.

    ``config`` is passed unchanged to ``_build_stream_bridge_impl``, which
    does the actual construction when the context is entered.
    """
    bridge_context = _build_stream_bridge_impl(config)
    return bridge_context
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _build_stream_bridge_impl(config=None) -> AsyncIterator[StreamBridge]:
|
||||||
|
if config is None:
|
||||||
|
config = get_stream_bridge_config()
|
||||||
|
|
||||||
|
if config is None or config.type == "memory":
|
||||||
|
maxsize = config.queue_maxsize if config is not None else 256
|
||||||
|
bridge = MemoryStreamBridge(queue_maxsize=maxsize)
|
||||||
|
await bridge.start()
|
||||||
|
logger.info("Stream bridge initialised: memory (queue_maxsize=%d)", maxsize)
|
||||||
|
try:
|
||||||
|
yield bridge
|
||||||
|
finally:
|
||||||
|
await bridge.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
if config.type == "redis":
|
||||||
|
if not config.redis_url:
|
||||||
|
raise ValueError("Redis stream bridge requires redis_url")
|
||||||
|
bridge = RedisStreamBridge(redis_url=config.redis_url)
|
||||||
|
await bridge.start()
|
||||||
|
logger.info("Stream bridge initialised: redis (%s)", config.redis_url)
|
||||||
|
try:
|
||||||
|
yield bridge
|
||||||
|
finally:
|
||||||
|
await bridge.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown stream bridge type: {config.type!r}")
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user