mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 15:11:09 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 565ab432fc | |||
| df63c104a7 | |||
| 7b9d224b3a | |||
| 0572ef44b9 | |||
| 839563f308 | |||
| 62bdfe3abc | |||
| b61ce3527b | |||
| 2d5f6f1b3d | |||
| 69bf3dafd8 | |||
| 6cbec13495 | |||
| 31e5b586a1 | |||
| e75a2ff29a | |||
| 185f5649dd |
+2
-23
@@ -1,6 +1,3 @@
|
|||||||
# Serper API Key (Google Search) - https://serper.dev
|
|
||||||
SERPER_API_KEY=your-serper-api-key
|
|
||||||
|
|
||||||
# TAVILY API Key
|
# TAVILY API Key
|
||||||
TAVILY_API_KEY=your-tavily-api-key
|
TAVILY_API_KEY=your-tavily-api-key
|
||||||
|
|
||||||
@@ -9,9 +6,8 @@ JINA_API_KEY=your-jina-api-key
|
|||||||
|
|
||||||
# InfoQuest API Key
|
# InfoQuest API Key
|
||||||
INFOQUEST_API_KEY=your-infoquest-api-key
|
INFOQUEST_API_KEY=your-infoquest-api-key
|
||||||
# Browser CORS allowlist for split-origin or port-forwarded deployments (comma-separated exact origins).
|
# CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001
|
||||||
# Leave unset when using the unified nginx endpoint, e.g. http://localhost:2026.
|
# CORS_ORIGINS=http://localhost:3000
|
||||||
# GATEWAY_CORS_ORIGINS=http://localhost:3000,http://127.0.0.1:3000
|
|
||||||
|
|
||||||
# Optional:
|
# Optional:
|
||||||
# FIRECRAWL_API_KEY=your-firecrawl-api-key
|
# FIRECRAWL_API_KEY=your-firecrawl-api-key
|
||||||
@@ -28,7 +24,6 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
# SLACK_BOT_TOKEN=your-slack-bot-token
|
# SLACK_BOT_TOKEN=your-slack-bot-token
|
||||||
# SLACK_APP_TOKEN=your-slack-app-token
|
# SLACK_APP_TOKEN=your-slack-app-token
|
||||||
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
|
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
|
||||||
# DISCORD_BOT_TOKEN=your-discord-bot-token
|
|
||||||
|
|
||||||
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
|
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
|
||||||
# LANGSMITH_TRACING=true
|
# LANGSMITH_TRACING=true
|
||||||
@@ -44,19 +39,3 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
#
|
#
|
||||||
# WECOM_BOT_ID=your-wecom-bot-id
|
# WECOM_BOT_ID=your-wecom-bot-id
|
||||||
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
||||||
# DINGTALK_CLIENT_ID=your-dingtalk-client-id
|
|
||||||
# DINGTALK_CLIENT_SECRET=your-dingtalk-client-secret
|
|
||||||
|
|
||||||
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
|
|
||||||
# GATEWAY_ENABLE_DOCS=false
|
|
||||||
|
|
||||||
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
|
|
||||||
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
|
|
||||||
# /api/* rewrites). They default to localhost values that match `make dev` and
|
|
||||||
# `make start`, so most local users do not need to set them.
|
|
||||||
#
|
|
||||||
# Override only when the Gateway is not on localhost:8001 (e.g. when the
|
|
||||||
# frontend and gateway run on different hosts, in containers with a service
|
|
||||||
# alias, or behind a different port). docker-compose already sets these.
|
|
||||||
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
|
||||||
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
name: Publish Containers
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- "v*"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
|
|
||||||
backend-container:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
packages: write
|
|
||||||
attestations: write
|
|
||||||
id-token: write
|
|
||||||
env:
|
|
||||||
REGISTRY: ghcr.io
|
|
||||||
IMAGE_NAME: ${{ github.repository }}-backend
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
- name: Log in to the Container registry
|
|
||||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
|
||||||
with:
|
|
||||||
registry: ${{ env.REGISTRY }}
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
|
||||||
with:
|
|
||||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
|
||||||
tags: |
|
|
||||||
type=ref,event=tag
|
|
||||||
type=ref,event=branch
|
|
||||||
type=sha
|
|
||||||
type=raw,value=latest,enable={{is_default_branch}}
|
|
||||||
- name: Build and push Docker image
|
|
||||||
id: push
|
|
||||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
file: backend/Dockerfile
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
|
|
||||||
- name: Generate artifact attestation
|
|
||||||
uses: actions/attest-build-provenance@v2
|
|
||||||
with:
|
|
||||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
|
||||||
subject-digest: ${{ steps.push.outputs.digest }}
|
|
||||||
push-to-registry: true
|
|
||||||
|
|
||||||
frontend-container:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
packages: write
|
|
||||||
attestations: write
|
|
||||||
id-token: write
|
|
||||||
env:
|
|
||||||
REGISTRY: ghcr.io
|
|
||||||
IMAGE_NAME: ${{ github.repository }}-frontend
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
- name: Log in to the Container registry
|
|
||||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
|
||||||
with:
|
|
||||||
registry: ${{ env.REGISTRY }}
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
|
||||||
with:
|
|
||||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
|
||||||
tags: |
|
|
||||||
type=ref,event=tag
|
|
||||||
type=ref,event=branch
|
|
||||||
type=sha
|
|
||||||
type=raw,value=latest,enable={{is_default_branch}}
|
|
||||||
- name: Build and push Docker image
|
|
||||||
id: push
|
|
||||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
file: frontend/Dockerfile
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
|
|
||||||
- name: Generate artifact attestation
|
|
||||||
uses: actions/attest-build-provenance@v2
|
|
||||||
with:
|
|
||||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
|
||||||
subject-digest: ${{ steps.push.outputs.digest }}
|
|
||||||
push-to-registry: true
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
name: E2E Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ 'main' ]
|
|
||||||
paths:
|
|
||||||
- 'frontend/**'
|
|
||||||
- '.github/workflows/e2e-tests.yml'
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
paths:
|
|
||||||
- 'frontend/**'
|
|
||||||
- '.github/workflows/e2e-tests.yml'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
e2e-tests:
|
|
||||||
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
timeout-minutes: 15
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: '22'
|
|
||||||
|
|
||||||
- name: Enable Corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Use pinned pnpm version
|
|
||||||
run: corepack prepare pnpm@10.26.2 --activate
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
working-directory: frontend
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Install Playwright Chromium
|
|
||||||
working-directory: frontend
|
|
||||||
run: npx playwright install chromium --with-deps
|
|
||||||
|
|
||||||
- name: Run E2E tests
|
|
||||||
working-directory: frontend
|
|
||||||
run: pnpm exec playwright test
|
|
||||||
env:
|
|
||||||
SKIP_ENV_VALIDATION: '1'
|
|
||||||
|
|
||||||
- name: Upload Playwright report
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
if: ${{ !cancelled() }}
|
|
||||||
with:
|
|
||||||
name: playwright-report
|
|
||||||
path: frontend/playwright-report/
|
|
||||||
retention-days: 7
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
name: Frontend Unit Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ 'main' ]
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
frontend-unit-tests:
|
|
||||||
if: github.event.pull_request.draft == false
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
timeout-minutes: 15
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: '22'
|
|
||||||
|
|
||||||
- name: Enable Corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Use pinned pnpm version
|
|
||||||
run: corepack prepare pnpm@10.26.2 --activate
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
working-directory: frontend
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Run unit tests of frontend
|
|
||||||
working-directory: frontend
|
|
||||||
run: make test
|
|
||||||
@@ -40,7 +40,6 @@ coverage/
|
|||||||
skills/custom/*
|
skills/custom/*
|
||||||
logs/
|
logs/
|
||||||
log/
|
log/
|
||||||
debug.log
|
|
||||||
|
|
||||||
# Local git hooks (keep only on this machine, do not push)
|
# Local git hooks (keep only on this machine, do not push)
|
||||||
.githooks/
|
.githooks/
|
||||||
@@ -56,7 +55,5 @@ web/
|
|||||||
backend/Dockerfile.langgraph
|
backend/Dockerfile.langgraph
|
||||||
config.yaml.bak
|
config.yaml.bak
|
||||||
.playwright-mcp
|
.playwright-mcp
|
||||||
/frontend/test-results/
|
|
||||||
/frontend/playwright-report/
|
|
||||||
.gstack/
|
.gstack/
|
||||||
.worktrees
|
.worktrees
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
repos:
|
|
||||||
# Backend: ruff lint + format via uv (uses the same ruff version as backend deps)
|
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: ruff
|
|
||||||
name: ruff lint
|
|
||||||
entry: bash -c 'cd backend && uv run ruff check --fix "${@/#backend\//}"' --
|
|
||||||
language: system
|
|
||||||
types_or: [python]
|
|
||||||
files: ^backend/
|
|
||||||
- id: ruff-format
|
|
||||||
name: ruff format
|
|
||||||
entry: bash -c 'cd backend && uv run ruff format "${@/#backend\//}"' --
|
|
||||||
language: system
|
|
||||||
types_or: [python]
|
|
||||||
files: ^backend/
|
|
||||||
|
|
||||||
# Frontend: eslint + prettier (must run from frontend/ for node_modules resolution)
|
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: frontend-eslint
|
|
||||||
name: eslint (frontend)
|
|
||||||
entry: bash -c 'cd frontend && npx eslint --fix "${@/#frontend\//}"' --
|
|
||||||
language: system
|
|
||||||
types_or: [javascript, tsx, ts]
|
|
||||||
files: ^frontend/
|
|
||||||
|
|
||||||
- id: frontend-prettier
|
|
||||||
name: prettier (frontend)
|
|
||||||
entry: bash -c 'cd frontend && npx prettier --write "${@/#frontend\//}"' --
|
|
||||||
language: system
|
|
||||||
files: ^frontend/
|
|
||||||
types_or: [javascript, tsx, ts, json, css]
|
|
||||||
+26
-25
@@ -46,12 +46,12 @@ Docker provides a consistent, isolated environment with all dependencies pre-con
|
|||||||
All services will start with hot-reload enabled:
|
All services will start with hot-reload enabled:
|
||||||
- Frontend changes are automatically reloaded
|
- Frontend changes are automatically reloaded
|
||||||
- Backend changes trigger automatic restart
|
- Backend changes trigger automatic restart
|
||||||
- Gateway-hosted LangGraph-compatible runtime supports hot-reload
|
- LangGraph server supports hot-reload
|
||||||
|
|
||||||
4. **Access the application**:
|
4. **Access the application**:
|
||||||
- Web Interface: http://localhost:2026
|
- Web Interface: http://localhost:2026
|
||||||
- API Gateway: http://localhost:2026/api/*
|
- API Gateway: http://localhost:2026/api/*
|
||||||
- LangGraph-compatible API: http://localhost:2026/api/langgraph/*
|
- LangGraph: http://localhost:2026/api/langgraph/*
|
||||||
|
|
||||||
#### Docker Commands
|
#### Docker Commands
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ Use these as practical starting points for development and review environments:
|
|||||||
If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket:
|
If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
unable to get image 'deer-flow-gateway': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock
|
unable to get image 'deer-flow-dev-langgraph': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock
|
||||||
```
|
```
|
||||||
|
|
||||||
Recommended fix: add your current user to the `docker` group so Docker commands work without `sudo`.
|
Recommended fix: add your current user to the `docker` group so Docker commands work without `sudo`.
|
||||||
@@ -131,8 +131,9 @@ Host Machine
|
|||||||
Docker Compose (deer-flow-dev)
|
Docker Compose (deer-flow-dev)
|
||||||
├→ nginx (port 2026) ← Reverse proxy
|
├→ nginx (port 2026) ← Reverse proxy
|
||||||
├→ web (port 3000) ← Frontend with hot-reload
|
├→ web (port 3000) ← Frontend with hot-reload
|
||||||
├→ gateway (port 8001) ← Gateway API + LangGraph-compatible runtime with hot-reload
|
├→ api (port 8001) ← Gateway API with hot-reload
|
||||||
└→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode
|
├→ langgraph (port 2024) ← LangGraph server with hot-reload
|
||||||
|
└→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode
|
||||||
```
|
```
|
||||||
|
|
||||||
**Benefits of Docker Development**:
|
**Benefits of Docker Development**:
|
||||||
@@ -165,7 +166,7 @@ Required tools:
|
|||||||
|
|
||||||
1. **Configure the application** (same as Docker setup above)
|
1. **Configure the application** (same as Docker setup above)
|
||||||
|
|
||||||
2. **Install dependencies** (this also sets up pre-commit hooks):
|
2. **Install dependencies**:
|
||||||
```bash
|
```bash
|
||||||
make install
|
make install
|
||||||
```
|
```
|
||||||
@@ -183,13 +184,17 @@ Required tools:
|
|||||||
|
|
||||||
If you need to start services individually:
|
If you need to start services individually:
|
||||||
|
|
||||||
1. **Start backend service**:
|
1. **Start backend services**:
|
||||||
```bash
|
```bash
|
||||||
# Terminal 1: Start Gateway API + embedded agent runtime (port 8001)
|
# Terminal 1: Start LangGraph Server (port 2024)
|
||||||
cd backend
|
cd backend
|
||||||
make dev
|
make dev
|
||||||
|
|
||||||
# Terminal 2: Start Frontend (port 3000)
|
# Terminal 2: Start Gateway API (port 8001)
|
||||||
|
cd backend
|
||||||
|
make gateway
|
||||||
|
|
||||||
|
# Terminal 3: Start Frontend (port 3000)
|
||||||
cd frontend
|
cd frontend
|
||||||
pnpm dev
|
pnpm dev
|
||||||
```
|
```
|
||||||
@@ -207,10 +212,10 @@ If you need to start services individually:
|
|||||||
|
|
||||||
The nginx configuration provides:
|
The nginx configuration provides:
|
||||||
- Unified entry point on port 2026
|
- Unified entry point on port 2026
|
||||||
- Rewrites `/api/langgraph/*` to Gateway's LangGraph-compatible API (8001)
|
- Routes `/api/langgraph/*` to LangGraph Server (2024)
|
||||||
- Routes other `/api/*` endpoints to Gateway API (8001)
|
- Routes other `/api/*` endpoints to Gateway API (8001)
|
||||||
- Routes non-API requests to Frontend (3000)
|
- Routes non-API requests to Frontend (3000)
|
||||||
- Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist
|
- Centralized CORS handling
|
||||||
- SSE/streaming support for real-time agent responses
|
- SSE/streaming support for real-time agent responses
|
||||||
- Optimized timeouts for long-running operations
|
- Optimized timeouts for long-running operations
|
||||||
|
|
||||||
@@ -230,8 +235,8 @@ deer-flow/
|
|||||||
│ └── nginx.local.conf # Nginx config for local dev
|
│ └── nginx.local.conf # Nginx config for local dev
|
||||||
├── backend/ # Backend application
|
├── backend/ # Backend application
|
||||||
│ ├── src/
|
│ ├── src/
|
||||||
│ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001)
|
│ │ ├── gateway/ # Gateway API (port 8001)
|
||||||
│ │ ├── agents/ # LangGraph agent runtime used by Gateway
|
│ │ ├── agents/ # LangGraph agents (port 2024)
|
||||||
│ │ ├── mcp/ # Model Context Protocol integration
|
│ │ ├── mcp/ # Model Context Protocol integration
|
||||||
│ │ ├── skills/ # Skills system
|
│ │ ├── skills/ # Skills system
|
||||||
│ │ └── sandbox/ # Sandbox execution
|
│ │ └── sandbox/ # Sandbox execution
|
||||||
@@ -251,7 +256,8 @@ Browser
|
|||||||
↓
|
↓
|
||||||
Nginx (port 2026) ← Unified entry point
|
Nginx (port 2026) ← Unified entry point
|
||||||
├→ Frontend (port 3000) ← / (non-API requests)
|
├→ Frontend (port 3000) ← / (non-API requests)
|
||||||
└→ Gateway API (port 8001) ← /api/* and /api/langgraph/* (LangGraph-compatible agent interactions)
|
├→ Gateway API (port 8001) ← /api/models, /api/mcp, /api/skills, /api/threads/*/artifacts
|
||||||
|
└→ LangGraph Server (port 2024) ← /api/langgraph/* (agent interactions)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Development Workflow
|
## Development Workflow
|
||||||
@@ -292,24 +298,19 @@ Nginx (port 2026) ← Unified entry point
|
|||||||
```bash
|
```bash
|
||||||
# Backend tests
|
# Backend tests
|
||||||
cd backend
|
cd backend
|
||||||
make test
|
uv run pytest
|
||||||
|
|
||||||
# Frontend unit tests
|
# Frontend checks
|
||||||
cd frontend
|
cd frontend
|
||||||
make test
|
pnpm check
|
||||||
|
|
||||||
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
|
|
||||||
cd frontend
|
|
||||||
make test-e2e
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### PR Regression Checks
|
### PR Regression Checks
|
||||||
|
|
||||||
Every pull request triggers the following CI workflows:
|
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including:
|
||||||
|
|
||||||
- **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
|
- `tests/test_provisioner_kubeconfig.py`
|
||||||
- **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
|
- `tests/test_docker_sandbox_mode_detection.py`
|
||||||
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
|
|
||||||
|
|
||||||
## Code Style
|
## Code Style
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# DeerFlow - Unified Development Environment
|
# DeerFlow - Unified Development Environment
|
||||||
|
|
||||||
.PHONY: help config config-upgrade check install setup doctor detect-thread-boundaries dev dev-daemon start start-daemon stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
.PHONY: help config config-upgrade check install setup doctor dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
|
||||||
|
|
||||||
BASH ?= bash
|
BASH ?= bash
|
||||||
BACKEND_UV_RUN = cd backend && uv run
|
BACKEND_UV_RUN = cd backend && uv run
|
||||||
@@ -23,23 +23,28 @@ help:
|
|||||||
@echo " make config - Generate local config files (aborts if config already exists)"
|
@echo " make config - Generate local config files (aborts if config already exists)"
|
||||||
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
|
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
|
||||||
@echo " make check - Check if all required tools are installed"
|
@echo " make check - Check if all required tools are installed"
|
||||||
@echo " make detect-thread-boundaries - Inventory async/thread boundary points"
|
@echo " make install - Install all dependencies (frontend + backend)"
|
||||||
@echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)"
|
|
||||||
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
|
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
|
||||||
@echo " make dev - Start all services in development mode (with hot-reloading)"
|
@echo " make dev - Start all services in development mode (with hot-reloading)"
|
||||||
|
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
|
||||||
@echo " make dev-daemon - Start dev services in background (daemon mode)"
|
@echo " make dev-daemon - Start dev services in background (daemon mode)"
|
||||||
|
@echo " make dev-daemon-pro - Start dev daemon + Gateway mode (experimental)"
|
||||||
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
|
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
|
||||||
|
@echo " make start-pro - Start in prod + Gateway mode (experimental)"
|
||||||
@echo " make start-daemon - Start prod services in background (daemon mode)"
|
@echo " make start-daemon - Start prod services in background (daemon mode)"
|
||||||
|
@echo " make start-daemon-pro - Start prod daemon + Gateway mode (experimental)"
|
||||||
@echo " make stop - Stop all running services"
|
@echo " make stop - Stop all running services"
|
||||||
@echo " make clean - Clean up processes and temporary files"
|
@echo " make clean - Clean up processes and temporary files"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Docker Production Commands:"
|
@echo "Docker Production Commands:"
|
||||||
@echo " make up - Build and start production Docker services (localhost:2026)"
|
@echo " make up - Build and start production Docker services (localhost:2026)"
|
||||||
|
@echo " make up-pro - Build and start production Docker in Gateway mode (experimental)"
|
||||||
@echo " make down - Stop and remove production Docker containers"
|
@echo " make down - Stop and remove production Docker containers"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Docker Development Commands:"
|
@echo "Docker Development Commands:"
|
||||||
@echo " make docker-init - Pull the sandbox image"
|
@echo " make docker-init - Pull the sandbox image"
|
||||||
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
|
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
|
||||||
|
@echo " make docker-start-pro - Start Docker in Gateway mode (experimental, no LangGraph container)"
|
||||||
@echo " make docker-stop - Stop Docker development services"
|
@echo " make docker-stop - Stop Docker development services"
|
||||||
@echo " make docker-logs - View Docker development logs"
|
@echo " make docker-logs - View Docker development logs"
|
||||||
@echo " make docker-logs-frontend - View Docker frontend logs"
|
@echo " make docker-logs-frontend - View Docker frontend logs"
|
||||||
@@ -52,9 +57,6 @@ setup:
|
|||||||
doctor:
|
doctor:
|
||||||
@$(BACKEND_UV_RUN) python ../scripts/doctor.py
|
@$(BACKEND_UV_RUN) python ../scripts/doctor.py
|
||||||
|
|
||||||
detect-thread-boundaries:
|
|
||||||
@$(PYTHON) ./scripts/detect_thread_boundaries.py
|
|
||||||
|
|
||||||
config:
|
config:
|
||||||
@$(PYTHON) ./scripts/configure.py
|
@$(PYTHON) ./scripts/configure.py
|
||||||
|
|
||||||
@@ -71,8 +73,6 @@ install:
|
|||||||
@cd backend && uv sync
|
@cd backend && uv sync
|
||||||
@echo "Installing frontend dependencies..."
|
@echo "Installing frontend dependencies..."
|
||||||
@cd frontend && pnpm install
|
@cd frontend && pnpm install
|
||||||
@echo "Installing pre-commit hooks..."
|
|
||||||
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
|
|
||||||
@echo "✓ All dependencies installed"
|
@echo "✓ All dependencies installed"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "=========================================="
|
@echo "=========================================="
|
||||||
@@ -99,7 +99,7 @@ setup-sandbox:
|
|||||||
echo ""; \
|
echo ""; \
|
||||||
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
|
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
|
||||||
echo "Detected Apple Container on macOS, pulling image..."; \
|
echo "Detected Apple Container on macOS, pulling image..."; \
|
||||||
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
|
container pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
|
||||||
fi; \
|
fi; \
|
||||||
if command -v docker >/dev/null 2>&1; then \
|
if command -v docker >/dev/null 2>&1; then \
|
||||||
echo "Pulling image using Docker..."; \
|
echo "Pulling image using Docker..."; \
|
||||||
@@ -121,21 +121,41 @@ dev:
|
|||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev
|
||||||
|
|
||||||
|
# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
|
||||||
|
dev-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway
|
||||||
|
|
||||||
# Start all services in production mode (with optimizations)
|
# Start all services in production mode (with optimizations)
|
||||||
start:
|
start:
|
||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod
|
||||||
|
|
||||||
|
# Start all services in prod + Gateway mode (experimental)
|
||||||
|
start-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway
|
||||||
|
|
||||||
# Start all services in daemon mode (background)
|
# Start all services in daemon mode (background)
|
||||||
dev-daemon:
|
dev-daemon:
|
||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --daemon
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --daemon
|
||||||
|
|
||||||
|
# Start daemon + Gateway mode (experimental)
|
||||||
|
dev-daemon-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway --daemon
|
||||||
|
|
||||||
# Start prod services in daemon mode (background)
|
# Start prod services in daemon mode (background)
|
||||||
start-daemon:
|
start-daemon:
|
||||||
@$(PYTHON) ./scripts/check.py
|
@$(PYTHON) ./scripts/check.py
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --daemon
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --daemon
|
||||||
|
|
||||||
|
# Start prod daemon + Gateway mode (experimental)
|
||||||
|
start-daemon-pro:
|
||||||
|
@$(PYTHON) ./scripts/check.py
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway --daemon
|
||||||
|
|
||||||
# Stop all services
|
# Stop all services
|
||||||
stop:
|
stop:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --stop
|
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --stop
|
||||||
@@ -160,6 +180,10 @@ docker-init:
|
|||||||
docker-start:
|
docker-start:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start
|
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start
|
||||||
|
|
||||||
|
# Start Docker in Gateway mode (experimental)
|
||||||
|
docker-start-pro:
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start --gateway
|
||||||
|
|
||||||
# Stop Docker development environment
|
# Stop Docker development environment
|
||||||
docker-stop:
|
docker-stop:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh stop
|
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh stop
|
||||||
@@ -182,6 +206,10 @@ docker-logs-gateway:
|
|||||||
up:
|
up:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh
|
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh
|
||||||
|
|
||||||
|
# Build and start production services in Gateway mode
|
||||||
|
up-pro:
|
||||||
|
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh --gateway
|
||||||
|
|
||||||
# Stop and remove production containers
|
# Stop and remove production containers
|
||||||
down:
|
down:
|
||||||
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh down
|
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh down
|
||||||
|
|||||||
@@ -243,9 +243,10 @@ make up # Build images and start all production services
|
|||||||
make down # Stop and remove containers
|
make down # Stop and remove containers
|
||||||
```
|
```
|
||||||
|
|
||||||
Access: http://localhost:2026
|
> [!NOTE]
|
||||||
|
> The LangGraph agent server currently runs via `langgraph dev` (the open-source CLI server).
|
||||||
|
|
||||||
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
Access: http://localhost:2026
|
||||||
|
|
||||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||||
|
|
||||||
@@ -253,7 +254,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
|||||||
|
|
||||||
If you prefer running services locally:
|
If you prefer running services locally:
|
||||||
|
|
||||||
Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root. Set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or `DEER_FLOW_CONFIG_PATH` to point at a specific config file. Runtime state defaults to `.deer-flow` under the project root and can be moved with `DEER_FLOW_HOME`; skills default to `skills/` under the project root and can be moved with `DEER_FLOW_SKILLS_PATH`. Run `make doctor` to verify your setup before starting.
|
Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root (can be overridden via `DEER_FLOW_CONFIG_PATH`). Run `make doctor` to verify your setup before starting.
|
||||||
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
|
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
|
||||||
|
|
||||||
1. **Check prerequisites**:
|
1. **Check prerequisites**:
|
||||||
@@ -263,7 +264,7 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
|
|||||||
|
|
||||||
2. **Install dependencies**:
|
2. **Install dependencies**:
|
||||||
```bash
|
```bash
|
||||||
make install # Install backend + frontend dependencies + pre-commit hooks
|
make install # Install backend + frontend dependencies
|
||||||
```
|
```
|
||||||
|
|
||||||
3. **(Optional) Pre-pull sandbox image**:
|
3. **(Optional) Pre-pull sandbox image**:
|
||||||
@@ -288,31 +289,53 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
|
|||||||
|
|
||||||
#### Startup Modes
|
#### Startup Modes
|
||||||
|
|
||||||
DeerFlow runs the agent runtime inside the Gateway API. Development mode enables hot-reload; production mode uses a pre-built frontend.
|
DeerFlow supports multiple startup modes across two dimensions:
|
||||||
|
|
||||||
|
- **Dev / Prod** — dev enables hot-reload; prod uses pre-built frontend
|
||||||
|
- **Standard / Gateway** — standard uses a separate LangGraph server (4 processes); Gateway mode (experimental) embeds the agent runtime in the Gateway API (3 processes)
|
||||||
|
|
||||||
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
||||||
|---|---|---|---|---|
|
|---|---|---|---|---|
|
||||||
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
||||||
|
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
|
||||||
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
||||||
|
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
|
||||||
|
|
||||||
| Action | Local | Docker Dev | Docker Prod |
|
| Action | Local | Docker Dev | Docker Prod |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
||||||
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
||||||
|
|
||||||
Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx.
|
> **Gateway mode** eliminates the LangGraph server process — the Gateway API handles agent execution directly via async tasks, managing its own concurrency.
|
||||||
|
|
||||||
|
#### Why Gateway Mode?
|
||||||
|
|
||||||
|
In standard mode, DeerFlow runs a dedicated [LangGraph Platform](https://langchain-ai.github.io/langgraph/) server alongside the Gateway API. This architecture works well but has trade-offs:
|
||||||
|
|
||||||
|
| | Standard Mode | Gateway Mode |
|
||||||
|
|---|---|---|
|
||||||
|
| **Architecture** | Gateway (REST API) + LangGraph (agent runtime) | Gateway embeds agent runtime |
|
||||||
|
| **Concurrency** | `--n-jobs-per-worker` per worker (requires license) | `--workers` × async tasks (no per-worker cap) |
|
||||||
|
| **Containers / Processes** | 4 (frontend, gateway, langgraph, nginx) | 3 (frontend, gateway, nginx) |
|
||||||
|
| **Resource usage** | Higher (two Python runtimes) | Lower (single Python runtime) |
|
||||||
|
| **LangGraph Platform license** | Required for production images | Not required |
|
||||||
|
| **Cold start** | Slower (two services to initialize) | Faster |
|
||||||
|
|
||||||
|
Both modes are functionally equivalent — the same agents, tools, and skills work in either mode.
|
||||||
|
|
||||||
#### Docker Production Deployment
|
#### Docker Production Deployment
|
||||||
|
|
||||||
`deploy.sh` supports building and starting separately:
|
`deploy.sh` supports building and starting separately. Images are mode-agnostic — runtime mode is selected at start time:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# One-step (build + start)
|
# One-step (build + start)
|
||||||
deploy.sh
|
deploy.sh # standard mode (default)
|
||||||
|
deploy.sh --gateway # gateway mode
|
||||||
|
|
||||||
# Two-step (build once, start later)
|
# Two-step (build once, start with any mode)
|
||||||
deploy.sh build # build all images
|
deploy.sh build # build all images
|
||||||
deploy.sh start # start pre-built images
|
deploy.sh start # start in standard mode
|
||||||
|
deploy.sh start --gateway # start in gateway mode
|
||||||
|
|
||||||
# Stop
|
# Stop
|
||||||
deploy.sh down
|
deploy.sh down
|
||||||
@@ -347,14 +370,13 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
|
|||||||
| Feishu / Lark | WebSocket | Moderate |
|
| Feishu / Lark | WebSocket | Moderate |
|
||||||
| WeChat | Tencent iLink (long-polling) | Moderate |
|
| WeChat | Tencent iLink (long-polling) | Moderate |
|
||||||
| WeCom | WebSocket | Moderate |
|
| WeCom | WebSocket | Moderate |
|
||||||
| DingTalk | Stream Push (WebSocket) | Moderate |
|
|
||||||
|
|
||||||
**Configuration in `config.yaml`:**
|
**Configuration in `config.yaml`:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api)
|
# LangGraph Server URL (default: http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL (default: http://localhost:8001)
|
# Gateway API URL (default: http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -417,19 +439,11 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # Client ID of your DingTalk application
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # Client Secret of your DingTalk application
|
|
||||||
allowed_users: [] # empty = allow all
|
|
||||||
card_template_id: "" # Optional: AI Card template ID for streaming typewriter effect
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
|
- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
|
||||||
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
|
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
|
||||||
- IM channel workers call Gateway's LangGraph-compatible API internally and automatically attach process-local internal auth plus the CSRF cookie/header pair required for thread and run creation.
|
|
||||||
|
|
||||||
Set the corresponding API keys in your `.env` file:
|
Set the corresponding API keys in your `.env` file:
|
||||||
|
|
||||||
@@ -452,10 +466,6 @@ WECHAT_ILINK_BOT_ID=your_ilink_bot_id
|
|||||||
# WeCom
|
# WeCom
|
||||||
WECOM_BOT_ID=your_bot_id
|
WECOM_BOT_ID=your_bot_id
|
||||||
WECOM_BOT_SECRET=your_bot_secret
|
WECOM_BOT_SECRET=your_bot_secret
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Telegram Setup**
|
**Telegram Setup**
|
||||||
@@ -494,15 +504,7 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
|
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
|
||||||
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
|
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
|
||||||
|
|
||||||
**DingTalk Setup**
|
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://langgraph:2024` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
|
|
||||||
1. Create a DingTalk application in the [DingTalk Developer Console](https://open.dingtalk.com/) and enable **Robot** capability.
|
|
||||||
2. Set the message receiving mode to **Stream Mode** in the robot configuration page.
|
|
||||||
3. Copy the `Client ID` and `Client Secret`, set `DINGTALK_CLIENT_ID` and `DINGTALK_CLIENT_SECRET` in `.env`, and enable the channel in `config.yaml`.
|
|
||||||
4. *(Optional)* To enable streaming AI Card replies (typewriter effect), create an **AI Card** template on the [DingTalk Card Platform](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), then set `card_template_id` in `config.yaml` to the template ID. You also need to apply for the `Card.Streaming.Write` and `Card.Instance.Write` permissions.
|
|
||||||
|
|
||||||
|
|
||||||
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://gateway:8001/api` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
|
||||||
|
|
||||||
**Commands**
|
**Commands**
|
||||||
|
|
||||||
@@ -628,7 +630,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
|
|||||||
|
|
||||||
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
|
||||||
|
|
||||||
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
|
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
|
||||||
|
|
||||||
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
|
||||||
|
|
||||||
@@ -656,8 +658,6 @@ This is the difference between a chatbot with tool access and an agent with an a
|
|||||||
|
|
||||||
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
|
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
|
||||||
|
|
||||||
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
|
|
||||||
|
|
||||||
### Long-Term Memory
|
### Long-Term Memory
|
||||||
|
|
||||||
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
|
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
|
||||||
|
|||||||
+3
-22
@@ -228,7 +228,7 @@ make down # Stop and remove containers
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Le runtime d'agent s'exécute actuellement dans la Gateway. nginx réécrit `/api/langgraph/*` vers l'API compatible LangGraph servie par la Gateway.
|
> Le serveur d'agents LangGraph fonctionne actuellement via `langgraph dev` (le serveur CLI open source).
|
||||||
|
|
||||||
Accès : http://localhost:2026
|
Accès : http://localhost:2026
|
||||||
|
|
||||||
@@ -290,14 +290,13 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca
|
|||||||
| Telegram | Bot API (long-polling) | Facile |
|
| Telegram | Bot API (long-polling) | Facile |
|
||||||
| Slack | Socket Mode | Modérée |
|
| Slack | Socket Mode | Modérée |
|
||||||
| Feishu / Lark | WebSocket | Modérée |
|
| Feishu / Lark | WebSocket | Modérée |
|
||||||
| DingTalk | Stream Push (WebSocket) | Modérée |
|
|
||||||
|
|
||||||
**Configuration dans `config.yaml` :**
|
**Configuration dans `config.yaml` :**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api)
|
# LangGraph Server URL (default: http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL (default: http://localhost:8001)
|
# Gateway API URL (default: http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -342,13 +341,6 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # ClientId depuis DingTalk Open Platform
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret depuis DingTalk Open Platform
|
|
||||||
allowed_users: [] # vide = tout le monde autorisé
|
|
||||||
card_template_id: "" # Optionnel : ID de modèle AI Card pour l'effet machine à écrire en streaming
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Définissez les clés API correspondantes dans votre fichier `.env` :
|
Définissez les clés API correspondantes dans votre fichier `.env` :
|
||||||
@@ -364,10 +356,6 @@ SLACK_APP_TOKEN=xapp-...
|
|||||||
# Feishu / Lark
|
# Feishu / Lark
|
||||||
FEISHU_APP_ID=cli_xxxx
|
FEISHU_APP_ID=cli_xxxx
|
||||||
FEISHU_APP_SECRET=your_app_secret
|
FEISHU_APP_SECRET=your_app_secret
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Configuration Telegram**
|
**Configuration Telegram**
|
||||||
@@ -390,13 +378,6 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
3. Dans **Events**, abonnez-vous à `im.message.receive_v1` et sélectionnez le mode **Long Connection**.
|
3. Dans **Events**, abonnez-vous à `im.message.receive_v1` et sélectionnez le mode **Long Connection**.
|
||||||
4. Copiez l'App ID et l'App Secret. Définissez `FEISHU_APP_ID` et `FEISHU_APP_SECRET` dans `.env` et activez le canal dans `config.yaml`.
|
4. Copiez l'App ID et l'App Secret. Définissez `FEISHU_APP_ID` et `FEISHU_APP_SECRET` dans `.env` et activez le canal dans `config.yaml`.
|
||||||
|
|
||||||
**Configuration DingTalk**
|
|
||||||
|
|
||||||
1. Créez une application sur [DingTalk Open Platform](https://open.dingtalk.com/) et activez la capacité **Robot**.
|
|
||||||
2. Dans la page de configuration du robot, définissez le mode de réception des messages sur **Stream**.
|
|
||||||
3. Copiez le `Client ID` et le `Client Secret`. Définissez `DINGTALK_CLIENT_ID` et `DINGTALK_CLIENT_SECRET` dans `.env` et activez le canal dans `config.yaml`.
|
|
||||||
4. *(Optionnel)* Pour activer les réponses en streaming AI Card (effet machine à écrire), créez un modèle **AI Card** sur la [plateforme de cartes DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), puis définissez `card_template_id` dans `config.yaml` avec l'ID du modèle. Vous devez également demander les permissions `Card.Streaming.Write` et `Card.Instance.Write`.
|
|
||||||
|
|
||||||
**Commandes**
|
**Commandes**
|
||||||
|
|
||||||
Une fois un canal connecté, vous pouvez interagir avec DeerFlow directement depuis le chat :
|
Une fois un canal connecté, vous pouvez interagir avec DeerFlow directement depuis le chat :
|
||||||
|
|||||||
+3
-22
@@ -181,7 +181,7 @@ make down # コンテナを停止して削除
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Agentランタイムは現在Gateway内で実行されます。`/api/langgraph/*`はnginxによってGatewayのLangGraph-compatible APIへ書き換えられます。
|
> LangGraphエージェントサーバーは現在`langgraph dev`(オープンソースCLIサーバー)経由で実行されます。
|
||||||
|
|
||||||
アクセス: http://localhost:2026
|
アクセス: http://localhost:2026
|
||||||
|
|
||||||
@@ -243,14 +243,13 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート
|
|||||||
| Telegram | Bot API(ロングポーリング) | 簡単 |
|
| Telegram | Bot API(ロングポーリング) | 簡単 |
|
||||||
| Slack | Socket Mode | 中程度 |
|
| Slack | Socket Mode | 中程度 |
|
||||||
| Feishu / Lark | WebSocket | 中程度 |
|
| Feishu / Lark | WebSocket | 中程度 |
|
||||||
| DingTalk | Stream Push(WebSocket) | 中程度 |
|
|
||||||
|
|
||||||
**`config.yaml`での設定:**
|
**`config.yaml`での設定:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL(デフォルト: http://localhost:8001/api)
|
# LangGraphサーバーURL(デフォルト: http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL(デフォルト: http://localhost:8001)
|
# Gateway API URL(デフォルト: http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -295,13 +294,6 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # DingTalk Open PlatformのClientId
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # DingTalk Open PlatformのClientSecret
|
|
||||||
allowed_users: [] # 空 = 全員許可
|
|
||||||
card_template_id: "" # オプション:ストリーミングタイプライター効果用のAIカードテンプレートID
|
|
||||||
```
|
```
|
||||||
|
|
||||||
対応するAPIキーを`.env`ファイルに設定します:
|
対応するAPIキーを`.env`ファイルに設定します:
|
||||||
@@ -317,10 +309,6 @@ SLACK_APP_TOKEN=xapp-...
|
|||||||
# Feishu / Lark
|
# Feishu / Lark
|
||||||
FEISHU_APP_ID=cli_xxxx
|
FEISHU_APP_ID=cli_xxxx
|
||||||
FEISHU_APP_SECRET=your_app_secret
|
FEISHU_APP_SECRET=your_app_secret
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Telegramのセットアップ**
|
**Telegramのセットアップ**
|
||||||
@@ -343,13 +331,6 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
3. **イベント**で`im.message.receive_v1`を購読し、**ロングコネクション**モードを選択。
|
3. **イベント**で`im.message.receive_v1`を購読し、**ロングコネクション**モードを選択。
|
||||||
4. App IDとApp Secretをコピー。`.env`に`FEISHU_APP_ID`と`FEISHU_APP_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
|
4. App IDとApp Secretをコピー。`.env`に`FEISHU_APP_ID`と`FEISHU_APP_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
|
||||||
|
|
||||||
**DingTalkのセットアップ**
|
|
||||||
|
|
||||||
1. [DingTalk Open Platform](https://open.dingtalk.com/)でアプリを作成し、**ロボット**機能を有効化します。
|
|
||||||
2. ロボット設定ページでメッセージ受信モードを**Streamモード**に設定します。
|
|
||||||
3. `Client ID`と`Client Secret`をコピー。`.env`に`DINGTALK_CLIENT_ID`と`DINGTALK_CLIENT_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
|
|
||||||
4. *(オプション)* ストリーミングAIカード返信(タイプライター効果)を有効にするには、[DingTalkカードプラットフォーム](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)で**AIカード**テンプレートを作成し、`config.yaml`の`card_template_id`にテンプレートIDを設定します。`Card.Streaming.Write` および `Card.Instance.Write` 権限の申請も必要です。
|
|
||||||
|
|
||||||
**コマンド**
|
**コマンド**
|
||||||
|
|
||||||
チャネル接続後、チャットから直接DeerFlowと対話できます:
|
チャネル接続後、チャットから直接DeerFlowと対話できます:
|
||||||
|
|||||||
@@ -256,7 +256,6 @@ DeerFlow принимает задачи прямо из мессенджеро
|
|||||||
| Telegram | Bot API (long-polling) | Просто |
|
| Telegram | Bot API (long-polling) | Просто |
|
||||||
| Slack | Socket Mode | Средне |
|
| Slack | Socket Mode | Средне |
|
||||||
| Feishu / Lark | WebSocket | Средне |
|
| Feishu / Lark | WebSocket | Средне |
|
||||||
| DingTalk | Stream Push (WebSocket) | Средне |
|
|
||||||
|
|
||||||
**Конфигурация в `config.yaml`:**
|
**Конфигурация в `config.yaml`:**
|
||||||
|
|
||||||
@@ -279,13 +278,6 @@ channels:
|
|||||||
enabled: true
|
enabled: true
|
||||||
bot_token: $TELEGRAM_BOT_TOKEN
|
bot_token: $TELEGRAM_BOT_TOKEN
|
||||||
allowed_users: []
|
allowed_users: []
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # ClientId с DingTalk Open Platform
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret с DingTalk Open Platform
|
|
||||||
allowed_users: [] # пусто = разрешить всем
|
|
||||||
card_template_id: "" # Опционально: ID шаблона AI Card для потокового эффекта печатной машинки
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Настройка Telegram**
|
**Настройка Telegram**
|
||||||
@@ -293,13 +285,6 @@ channels:
|
|||||||
1. Напишите [@BotFather](https://t.me/BotFather), отправьте `/newbot` и скопируйте HTTP API-токен.
|
1. Напишите [@BotFather](https://t.me/BotFather), отправьте `/newbot` и скопируйте HTTP API-токен.
|
||||||
2. Укажите `TELEGRAM_BOT_TOKEN` в `.env` и включите канал в `config.yaml`.
|
2. Укажите `TELEGRAM_BOT_TOKEN` в `.env` и включите канал в `config.yaml`.
|
||||||
|
|
||||||
**Настройка DingTalk**
|
|
||||||
|
|
||||||
1. Создайте приложение на [DingTalk Open Platform](https://open.dingtalk.com/) и включите возможность **Робот**.
|
|
||||||
2. На странице настроек робота установите режим приёма сообщений на **Stream**.
|
|
||||||
3. Скопируйте `Client ID` и `Client Secret`. Укажите `DINGTALK_CLIENT_ID` и `DINGTALK_CLIENT_SECRET` в `.env` и включите канал в `config.yaml`.
|
|
||||||
4. *(Опционально)* Для включения потоковых ответов AI Card (эффект печатной машинки) создайте шаблон **AI Card** на [платформе карточек DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), затем укажите `card_template_id` в `config.yaml` с ID шаблона. Также необходимо запросить разрешения `Card.Streaming.Write` и `Card.Instance.Write`.
|
|
||||||
|
|
||||||
**Доступные команды**
|
**Доступные команды**
|
||||||
|
|
||||||
| Команда | Описание |
|
| Команда | Описание |
|
||||||
|
|||||||
+4
-23
@@ -184,7 +184,7 @@ make down # 停止并移除容器
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 当前 Agent 运行时嵌入在 Gateway 中运行,`/api/langgraph/*` 会由 nginx 重写到 Gateway 的 LangGraph-compatible API。
|
> 当前 LangGraph agent server 通过开源 CLI 服务 `langgraph dev` 运行。
|
||||||
|
|
||||||
访问地址:http://localhost:2026
|
访问地址:http://localhost:2026
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ make down # 停止并移除容器
|
|||||||
|
|
||||||
如果你更希望直接在本地启动各个服务:
|
如果你更希望直接在本地启动各个服务:
|
||||||
|
|
||||||
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`。可以用 `DEER_FLOW_PROJECT_ROOT` 显式指定项目根目录,也可以用 `DEER_FLOW_CONFIG_PATH` 指向某个具体配置文件。运行期状态默认写到项目根目录下的 `.deer-flow`,可用 `DEER_FLOW_HOME` 覆盖;skills 默认读取项目根目录下的 `skills/`,可用 `DEER_FLOW_SKILLS_PATH` 覆盖。
|
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`,也可以通过 `DEER_FLOW_CONFIG_PATH` 覆盖。
|
||||||
在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。
|
在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。
|
||||||
|
|
||||||
1. **检查依赖环境**:
|
1. **检查依赖环境**:
|
||||||
@@ -248,14 +248,13 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应
|
|||||||
| Slack | Socket Mode | 中等 |
|
| Slack | Socket Mode | 中等 |
|
||||||
| Feishu / Lark | WebSocket | 中等 |
|
| Feishu / Lark | WebSocket | 中等 |
|
||||||
| 企业微信智能机器人 | WebSocket | 中等 |
|
| 企业微信智能机器人 | WebSocket | 中等 |
|
||||||
| 钉钉 | Stream Push(WebSocket) | 中等 |
|
|
||||||
|
|
||||||
**`config.yaml` 中的配置示例:**
|
**`config.yaml` 中的配置示例:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
channels:
|
channels:
|
||||||
# LangGraph-compatible Gateway API base URL(默认:http://localhost:8001/api)
|
# LangGraph Server URL(默认:http://localhost:2024)
|
||||||
langgraph_url: http://localhost:8001/api
|
langgraph_url: http://localhost:2024
|
||||||
# Gateway API URL(默认:http://localhost:8001)
|
# Gateway API URL(默认:http://localhost:8001)
|
||||||
gateway_url: http://localhost:8001
|
gateway_url: http://localhost:8001
|
||||||
|
|
||||||
@@ -305,13 +304,6 @@ channels:
|
|||||||
context:
|
context:
|
||||||
thinking_enabled: true
|
thinking_enabled: true
|
||||||
subagent_enabled: true
|
subagent_enabled: true
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID # 钉钉开放平台 ClientId
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET # 钉钉开放平台 ClientSecret
|
|
||||||
allowed_users: [] # 留空表示允许所有人
|
|
||||||
card_template_id: "" # 可选:AI 卡片模板 ID,用于流式打字机效果
|
|
||||||
```
|
```
|
||||||
|
|
||||||
说明:
|
说明:
|
||||||
@@ -335,10 +327,6 @@ FEISHU_APP_SECRET=your_app_secret
|
|||||||
# 企业微信智能机器人
|
# 企业微信智能机器人
|
||||||
WECOM_BOT_ID=your_bot_id
|
WECOM_BOT_ID=your_bot_id
|
||||||
WECOM_BOT_SECRET=your_bot_secret
|
WECOM_BOT_SECRET=your_bot_secret
|
||||||
|
|
||||||
# 钉钉
|
|
||||||
DINGTALK_CLIENT_ID=your_client_id
|
|
||||||
DINGTALK_CLIENT_SECRET=your_client_secret
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Telegram 配置**
|
**Telegram 配置**
|
||||||
@@ -369,13 +357,6 @@ DINGTALK_CLIENT_SECRET=your_client_secret
|
|||||||
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
|
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
|
||||||
5. 当前支持文本、图片和文件入站消息;agent 生成的最终图片/文件也会回传到企业微信会话中。
|
5. 当前支持文本、图片和文件入站消息;agent 生成的最终图片/文件也会回传到企业微信会话中。
|
||||||
|
|
||||||
**钉钉配置**
|
|
||||||
|
|
||||||
1. 在 [钉钉开放平台](https://open.dingtalk.com/) 创建应用,并启用 **机器人** 能力。
|
|
||||||
2. 在机器人配置页面设置消息接收模式为 **Stream模式**。
|
|
||||||
3. 复制 `Client ID` 和 `Client Secret`,在 `.env` 中设置 `DINGTALK_CLIENT_ID` 和 `DINGTALK_CLIENT_SECRET`,并在 `config.yaml` 中启用该渠道。
|
|
||||||
4. *(可选)* 如需开启流式 AI 卡片回复(打字机效果),请在[钉钉卡片平台](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)创建 **AI 卡片**模板,然后在 `config.yaml` 中将 `card_template_id` 设为该模板 ID。同时需要申请 `Card.Streaming.Write` 和 `Card.Instance.Write` 权限。
|
|
||||||
|
|
||||||
**命令**
|
**命令**
|
||||||
|
|
||||||
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:
|
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:
|
||||||
|
|||||||
+51
-60
@@ -7,13 +7,15 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
DeerFlow is a LangGraph-based AI super agent system with a full-stack architecture. The backend provides a "super agent" with sandbox execution, persistent memory, subagent delegation, and extensible tool integration - all operating in per-thread isolated environments.
|
DeerFlow is a LangGraph-based AI super agent system with a full-stack architecture. The backend provides a "super agent" with sandbox execution, persistent memory, subagent delegation, and extensible tool integration - all operating in per-thread isolated environments.
|
||||||
|
|
||||||
**Architecture**:
|
**Architecture**:
|
||||||
- **Gateway API** (port 8001): REST API plus embedded LangGraph-compatible agent runtime
|
- **LangGraph Server** (port 2024): Agent runtime and workflow execution
|
||||||
|
- **Gateway API** (port 8001): REST API for models, MCP, skills, memory, artifacts, uploads, and local thread cleanup
|
||||||
- **Frontend** (port 3000): Next.js web interface
|
- **Frontend** (port 3000): Next.js web interface
|
||||||
- **Nginx** (port 2026): Unified reverse proxy entry point
|
- **Nginx** (port 2026): Unified reverse proxy entry point
|
||||||
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
|
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
|
||||||
|
|
||||||
**Runtime**:
|
**Runtime Modes**:
|
||||||
- `make dev`, Docker dev, and production all run the agent runtime in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Nginx exposes that runtime at `/api/langgraph/*` and rewrites it to Gateway's native `/api/*` routers.
|
- **Standard mode** (`make dev`): LangGraph Server handles agent execution as a separate process. 4 processes total.
|
||||||
|
- **Gateway mode** (`make dev-pro`, experimental): Agent runtime embedded in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Service manages its own concurrency via async tasks. 3 processes total, no LangGraph Server.
|
||||||
|
|
||||||
**Project Structure**:
|
**Project Structure**:
|
||||||
```
|
```
|
||||||
@@ -23,7 +25,7 @@ deer-flow/
|
|||||||
├── extensions_config.json # MCP servers and skills configuration
|
├── extensions_config.json # MCP servers and skills configuration
|
||||||
├── backend/ # Backend application (this directory)
|
├── backend/ # Backend application (this directory)
|
||||||
│ ├── Makefile # Backend-only commands (dev, gateway, lint)
|
│ ├── Makefile # Backend-only commands (dev, gateway, lint)
|
||||||
│ ├── langgraph.json # LangGraph Studio graph configuration
|
│ ├── langgraph.json # LangGraph server configuration
|
||||||
│ ├── packages/
|
│ ├── packages/
|
||||||
│ │ └── harness/ # deerflow-harness package (import: deerflow.*)
|
│ │ └── harness/ # deerflow-harness package (import: deerflow.*)
|
||||||
│ │ ├── pyproject.toml
|
│ │ ├── pyproject.toml
|
||||||
@@ -81,15 +83,16 @@ When making code changes, you MUST update the relevant documentation:
|
|||||||
```bash
|
```bash
|
||||||
make check # Check system requirements
|
make check # Check system requirements
|
||||||
make install # Install all dependencies (frontend + backend)
|
make install # Install all dependencies (frontend + backend)
|
||||||
make dev # Start all services (Gateway + Frontend + Nginx), with config.yaml preflight
|
make dev # Start all services (LangGraph + Gateway + Frontend + Nginx), with config.yaml preflight
|
||||||
make start # Start production services locally
|
make dev-pro # Gateway mode (experimental): skip LangGraph, agent runtime embedded in Gateway
|
||||||
|
make start-pro # Production + Gateway mode (experimental)
|
||||||
make stop # Stop all services
|
make stop # Stop all services
|
||||||
```
|
```
|
||||||
|
|
||||||
**Backend directory** (for backend development only):
|
**Backend directory** (for backend development only):
|
||||||
```bash
|
```bash
|
||||||
make install # Install backend dependencies
|
make install # Install backend dependencies
|
||||||
make dev # Run Gateway API with reload (port 8001)
|
make dev # Run LangGraph server only (port 2024)
|
||||||
make gateway # Run Gateway API only (port 8001)
|
make gateway # Run Gateway API only (port 8001)
|
||||||
make test # Run all backend tests
|
make test # Run all backend tests
|
||||||
make lint # Lint with ruff
|
make lint # Lint with ruff
|
||||||
@@ -112,7 +115,7 @@ CI runs these regression tests for every pull request via [.github/workflows/bac
|
|||||||
The backend is split into two layers with a strict dependency direction:
|
The backend is split into two layers with a strict dependency direction:
|
||||||
|
|
||||||
- **Harness** (`packages/harness/deerflow/`): Publishable agent framework package (`deerflow-harness`). Import prefix: `deerflow.*`. Contains agent orchestration, tools, sandbox, models, MCP, skills, config — everything needed to build and run agents.
|
- **Harness** (`packages/harness/deerflow/`): Publishable agent framework package (`deerflow-harness`). Import prefix: `deerflow.*`. Contains agent orchestration, tools, sandbox, models, MCP, skills, config — everything needed to build and run agents.
|
||||||
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram, DingTalk).
|
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram).
|
||||||
|
|
||||||
**Dependency rule**: App imports deerflow, but deerflow never imports app. This boundary is enforced by `tests/test_harness_boundary.py` which runs in CI.
|
**Dependency rule**: App imports deerflow, but deerflow never imports app. This boundary is enforced by `tests/test_harness_boundary.py` which runs in CI.
|
||||||
|
|
||||||
@@ -153,26 +156,20 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
### Middleware Chain
|
### Middleware Chain
|
||||||
|
|
||||||
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
|
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||||
|
|
||||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
||||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
|
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||||
5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
|
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
||||||
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||||
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
|
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||||
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
||||||
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||||
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||||
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
|
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
|
||||||
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
||||||
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
|
||||||
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
|
||||||
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
|
|
||||||
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
|
|
||||||
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
|
|
||||||
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
|
||||||
|
|
||||||
### Configuration System
|
### Configuration System
|
||||||
|
|
||||||
@@ -205,9 +202,7 @@ Configuration priority:
|
|||||||
|
|
||||||
### Gateway API (`app/gateway/`)
|
### Gateway API (`app/gateway/`)
|
||||||
|
|
||||||
FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWAY_ENABLE_DOCS=false` to disable `/docs`, `/redoc`, and `/openapi.json` in production (default: enabled).
|
FastAPI application on port 8001 with health check at `GET /health`.
|
||||||
|
|
||||||
CORS is same-origin by default when requests enter through nginx on port 2026. Split-origin or port-forwarded browser clients must opt in with `GATEWAY_CORS_ORIGINS` (comma-separated exact origins); Gateway `CORSMiddleware` and `CSRFMiddleware` both read that variable so browser CORS and auth-origin checks stay aligned.
|
|
||||||
|
|
||||||
**Routers**:
|
**Routers**:
|
||||||
|
|
||||||
@@ -225,33 +220,27 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
|
|||||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||||
|
|
||||||
**RunManager / RunStore contract**:
|
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||||
- `RunManager.get()` is async; direct callers must `await` it.
|
|
||||||
- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
|
|
||||||
- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
|
|
||||||
- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
|
|
||||||
|
|
||||||
Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.
|
|
||||||
|
|
||||||
### Sandbox System (`packages/harness/deerflow/sandbox/`)
|
### Sandbox System (`packages/harness/deerflow/sandbox/`)
|
||||||
|
|
||||||
**Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir`
|
**Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle
|
**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle
|
||||||
**Implementations**:
|
**Implementations**:
|
||||||
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
- `LocalSandboxProvider` - Singleton local filesystem execution with path mappings
|
||||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
||||||
|
|
||||||
**Virtual Path System**:
|
**Virtual Path System**:
|
||||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||||
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||||
- Translation: `LocalSandboxProvider` builds per-thread `PathMapping`s for the user-data prefixes at acquire time; `tools.py` keeps `replace_virtual_path()` / `replace_virtual_paths_in_command()` as a defense-in-depth layer (and for path validation). AIO has the directories volume-mounted at the same virtual paths inside its container, so both implementations accept `/mnt/user-data/...` natively.
|
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||||
- Detection: `is_local_sandbox()` accepts both `sandbox_id == "local"` (legacy / no-thread) and `sandbox_id.startswith("local:")` (per-thread)
|
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||||
|
|
||||||
**Sandbox Tools** (in `packages/harness/deerflow/sandbox/tools.py`):
|
**Sandbox Tools** (in `packages/harness/deerflow/sandbox/tools.py`):
|
||||||
- `bash` - Execute commands with path translation and error handling
|
- `bash` - Execute commands with path translation and error handling
|
||||||
- `ls` - Directory listing (tree format, max 2 levels)
|
- `ls` - Directory listing (tree format, max 2 levels)
|
||||||
- `read_file` - Read file contents with optional line range
|
- `read_file` - Read file contents with optional line range
|
||||||
- `write_file` - Write/append to files, creates directories; overwrites by default and exposes the `append` argument in the model-facing schema for end-of-file writes
|
- `write_file` - Write/append to files, creates directories
|
||||||
- `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process
|
- `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process
|
||||||
|
|
||||||
### Subagent System (`packages/harness/deerflow/subagents/`)
|
### Subagent System (`packages/harness/deerflow/subagents/`)
|
||||||
@@ -271,10 +260,8 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
|
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
|
||||||
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
|
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
|
||||||
- `view_image` - Read image as base64 (added only if model supports vision)
|
- `view_image` - Read image as base64 (added only if model supports vision)
|
||||||
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
|
|
||||||
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
|
|
||||||
4. **Subagent tool** (if enabled):
|
4. **Subagent tool** (if enabled):
|
||||||
- `task` - Delegate to subagent (description, prompt, subagent_type)
|
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
|
||||||
|
|
||||||
**Community tools** (`packages/harness/deerflow/community/`):
|
**Community tools** (`packages/harness/deerflow/community/`):
|
||||||
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
|
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
|
||||||
@@ -322,10 +309,9 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
|
|
||||||
### IM Channels System (`app/channels/`)
|
### IM Channels System (`app/channels/`)
|
||||||
|
|
||||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via the LangGraph Server.
|
Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via the LangGraph Server.
|
||||||
|
|
||||||
|
**Architecture**: Channels communicate with the LangGraph Server through `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side.
|
||||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
|
||||||
|
|
||||||
**Components**:
|
**Components**:
|
||||||
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
||||||
@@ -333,25 +319,23 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
||||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||||
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
- `slack.py` / `feishu.py` / `telegram.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place)
|
||||||
|
|
||||||
**Message Flow**:
|
**Message Flow**:
|
||||||
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
||||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||||
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
3. For chat: look up/create thread on LangGraph Server
|
||||||
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||||
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||||
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||||
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
7. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||||
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
8. Outbound → channel callbacks → platform reply
|
||||||
9. Outbound → channel callbacks → platform reply
|
|
||||||
|
|
||||||
**Configuration** (`config.yaml` -> `channels`):
|
**Configuration** (`config.yaml` -> `channels`):
|
||||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
- `langgraph_url` - LangGraph Server URL (default: `http://localhost:2024`)
|
||||||
- `gateway_url` - Gateway API URL for auxiliary commands (default: `http://localhost:8001`)
|
- `gateway_url` - Gateway API URL for auxiliary commands (default: `http://localhost:8001`)
|
||||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://langgraph:2024` / `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token)
|
||||||
|
|
||||||
|
|
||||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||||
|
|
||||||
@@ -364,11 +348,10 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
**Per-User Isolation**:
|
**Per-User Isolation**:
|
||||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||||
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
|
|
||||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||||
- Absolute `storage_path` in config opts out of per-user isolation
|
- Absolute `storage_path` in config opts out of per-user isolation
|
||||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
|
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||||
|
|
||||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||||
@@ -421,9 +404,9 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
|
|||||||
|
|
||||||
`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes.
|
`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes.
|
||||||
|
|
||||||
**Architecture**: Imports the same `deerflow` modules that Gateway API uses. Shares the same config files and data directories. No FastAPI dependency.
|
**Architecture**: Imports the same `deerflow` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
|
||||||
|
|
||||||
**Agent Conversation**:
|
**Agent Conversation** (replaces LangGraph Server):
|
||||||
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
|
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
|
||||||
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
|
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
|
||||||
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
|
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
|
||||||
@@ -486,15 +469,20 @@ This starts all services and makes the application available at `http://localhos
|
|||||||
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|
||||||
|---|---|---|---|---|
|
|---|---|---|---|---|
|
||||||
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
|
||||||
|
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
|
||||||
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
|
||||||
|
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
|
||||||
|
|
||||||
| Action | Local | Docker Dev | Docker Prod |
|
| Action | Local | Docker Dev | Docker Prod |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
|
||||||
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
|
||||||
|
|
||||||
|
Gateway mode embeds the agent runtime in Gateway, no LangGraph server.
|
||||||
|
|
||||||
**Nginx routing**:
|
**Nginx routing**:
|
||||||
- `/api/langgraph/*` → Gateway embedded runtime (8001), rewritten to `/api/*`
|
- Standard mode: `/api/langgraph/*` → LangGraph Server (2024)
|
||||||
|
- Gateway mode: `/api/langgraph/*` → Gateway embedded runtime (8001) (via envsubst)
|
||||||
- `/api/*` (other) → Gateway API (8001)
|
- `/api/*` (other) → Gateway API (8001)
|
||||||
- `/` (non-API) → Frontend (3000)
|
- `/` (non-API) → Frontend (3000)
|
||||||
|
|
||||||
@@ -503,11 +491,15 @@ This starts all services and makes the application available at `http://localhos
|
|||||||
From the **backend** directory:
|
From the **backend** directory:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Gateway API
|
# Terminal 1: LangGraph server
|
||||||
|
make dev
|
||||||
|
|
||||||
|
# Terminal 2: Gateway API
|
||||||
make gateway
|
make gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
Direct access (without nginx):
|
Direct access (without nginx):
|
||||||
|
- LangGraph: `http://localhost:2024`
|
||||||
- Gateway: `http://localhost:8001`
|
- Gateway: `http://localhost:8001`
|
||||||
|
|
||||||
### Frontend Configuration
|
### Frontend Configuration
|
||||||
@@ -528,7 +520,6 @@ Multi-file upload with automatic document conversion:
|
|||||||
- Rejects directory inputs before copying so uploads stay all-or-nothing
|
- Rejects directory inputs before copying so uploads stay all-or-nothing
|
||||||
- Reuses one conversion worker per request when called from an active event loop
|
- Reuses one conversion worker per request when called from an active event loop
|
||||||
- Files stored in thread-isolated directories
|
- Files stored in thread-isolated directories
|
||||||
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
|
|
||||||
- Agent receives uploaded file list via `UploadsMiddleware`
|
- Agent receives uploaded file list via `UploadsMiddleware`
|
||||||
|
|
||||||
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
|
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
|
||||||
|
|||||||
@@ -56,8 +56,11 @@ export OPENAI_API_KEY="your-api-key"
|
|||||||
### Run the Development Server
|
### Run the Development Server
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Gateway API + embedded agent runtime
|
# Terminal 1: LangGraph server
|
||||||
make dev
|
make dev
|
||||||
|
|
||||||
|
# Terminal 2: Gateway API
|
||||||
|
make gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|||||||
@@ -50,12 +50,6 @@ COPY backend ./backend
|
|||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
||||||
|
|
||||||
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
|
|
||||||
# containers where locale configuration may be missing and the default encoding is not UTF-8.
|
|
||||||
ENV LANG=C.UTF-8
|
|
||||||
ENV LC_ALL=C.UTF-8
|
|
||||||
ENV PYTHONIOENCODING=utf-8
|
|
||||||
|
|
||||||
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
||||||
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
||||||
# source distributions in development containers.
|
# source distributions in development containers.
|
||||||
@@ -72,10 +66,6 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
|
|||||||
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
||||||
FROM python:3.12-slim-bookworm
|
FROM python:3.12-slim-bookworm
|
||||||
|
|
||||||
ENV LANG=C.UTF-8
|
|
||||||
ENV LC_ALL=C.UTF-8
|
|
||||||
ENV PYTHONIOENCODING=utf-8
|
|
||||||
|
|
||||||
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
||||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||||
|
|||||||
+1
-1
@@ -2,7 +2,7 @@ install:
|
|||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
dev:
|
dev:
|
||||||
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload
|
uv run langgraph dev --no-browser --no-reload --n-jobs-per-worker 10
|
||||||
|
|
||||||
gateway:
|
gateway:
|
||||||
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||||
|
|||||||
+33
-29
@@ -11,26 +11,31 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent
|
|||||||
│ Nginx (Port 2026) │
|
│ Nginx (Port 2026) │
|
||||||
│ Unified reverse proxy │
|
│ Unified reverse proxy │
|
||||||
└───────┬──────────────────┬───────────┘
|
└───────┬──────────────────┬───────────┘
|
||||||
│
|
│ │
|
||||||
/api/langgraph/* │ /api/* (other)
|
/api/langgraph/* │ │ /api/* (other)
|
||||||
rewritten to /api/* │
|
▼ ▼
|
||||||
▼
|
┌────────────────────┐ ┌────────────────────────┐
|
||||||
┌────────────────────────────────────────┐
|
│ LangGraph Server │ │ Gateway API (8001) │
|
||||||
│ Gateway API (8001) │
|
│ (Port 2024) │ │ FastAPI REST │
|
||||||
│ FastAPI REST + agent runtime │
|
│ │ │ │
|
||||||
│ │
|
│ ┌────────────────┐ │ │ Models, MCP, Skills, │
|
||||||
│ Models, MCP, Skills, Memory, Uploads, │
|
│ │ Lead Agent │ │ │ Memory, Uploads, │
|
||||||
│ Artifacts, Threads, Runs, Streaming │
|
│ │ ┌──────────┐ │ │ │ Artifacts │
|
||||||
│ │
|
│ │ │Middleware│ │ │ └────────────────────────┘
|
||||||
│ ┌────────────────────────────────────┐ │
|
│ │ │ Chain │ │ │
|
||||||
│ │ Lead Agent │ │
|
│ │ └──────────┘ │ │
|
||||||
│ │ Middleware Chain, Tools, Subagents │ │
|
│ │ ┌──────────┐ │ │
|
||||||
│ └────────────────────────────────────┘ │
|
│ │ │ Tools │ │ │
|
||||||
└────────────────────────────────────────┘
|
│ │ └──────────┘ │ │
|
||||||
|
│ │ ┌──────────┐ │ │
|
||||||
|
│ │ │Subagents │ │ │
|
||||||
|
│ │ └──────────┘ │ │
|
||||||
|
│ └────────────────┘ │
|
||||||
|
└────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
**Request Routing** (via Nginx):
|
**Request Routing** (via Nginx):
|
||||||
- `/api/langgraph/*` → Gateway LangGraph-compatible API - agent interactions, threads, streaming
|
- `/api/langgraph/*` → LangGraph Server - agent interactions, threads, streaming
|
||||||
- `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup
|
- `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup
|
||||||
- `/` (non-API) → Frontend - Next.js web interface
|
- `/` (non-API) → Frontend - Next.js web interface
|
||||||
|
|
||||||
@@ -74,7 +79,7 @@ Per-thread isolated execution with virtual path translation:
|
|||||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||||
- **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match
|
- **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match
|
||||||
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`write_file` overwrites by default and exposes `append` for end-of-file writes; `bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
|
- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access)
|
||||||
|
|
||||||
### Subagent System
|
### Subagent System
|
||||||
|
|
||||||
@@ -119,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
|
|||||||
| `POST /api/memory/reload` | Force memory reload |
|
| `POST /api/memory/reload` | Force memory reload |
|
||||||
| `GET /api/memory/config` | Memory configuration |
|
| `GET /api/memory/config` | Memory configuration |
|
||||||
| `GET /api/memory/status` | Combined config + data |
|
| `GET /api/memory/status` | Combined config + data |
|
||||||
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
|
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
|
||||||
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
|
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
|
||||||
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||||
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
|
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
|
||||||
@@ -188,7 +193,7 @@ export OPENAI_API_KEY="your-api-key-here"
|
|||||||
**Full Application** (from project root):
|
**Full Application** (from project root):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make dev # Starts Gateway + Frontend + Nginx
|
make dev # Starts LangGraph + Gateway + Frontend + Nginx
|
||||||
```
|
```
|
||||||
|
|
||||||
Access at: http://localhost:2026
|
Access at: http://localhost:2026
|
||||||
@@ -196,11 +201,14 @@ Access at: http://localhost:2026
|
|||||||
**Backend Only** (from backend directory):
|
**Backend Only** (from backend directory):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Gateway API + embedded agent runtime
|
# Terminal 1: LangGraph server
|
||||||
make dev
|
make dev
|
||||||
|
|
||||||
|
# Terminal 2: Gateway API
|
||||||
|
make gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
Direct access: Gateway at http://localhost:8001
|
Direct access: LangGraph at http://localhost:2024, Gateway at http://localhost:8001
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -236,16 +244,12 @@ backend/
|
|||||||
│ └── utils/ # Utilities
|
│ └── utils/ # Utilities
|
||||||
├── docs/ # Documentation
|
├── docs/ # Documentation
|
||||||
├── tests/ # Test suite
|
├── tests/ # Test suite
|
||||||
├── langgraph.json # LangGraph graph registry for tooling/Studio compatibility
|
├── langgraph.json # LangGraph server configuration
|
||||||
├── pyproject.toml # Python dependencies
|
├── pyproject.toml # Python dependencies
|
||||||
├── Makefile # Development commands
|
├── Makefile # Development commands
|
||||||
└── Dockerfile # Container build
|
└── Dockerfile # Container build
|
||||||
```
|
```
|
||||||
|
|
||||||
`langgraph.json` is not the default service entrypoint. The scripts and Docker
|
|
||||||
deployments run the Gateway embedded runtime; the file is kept for LangGraph
|
|
||||||
tooling, Studio, or direct LangGraph Server compatibility.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
@@ -358,8 +362,8 @@ If a provider is explicitly enabled but required credentials are missing, or the
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
make install # Install dependencies
|
make install # Install dependencies
|
||||||
make dev # Run Gateway API + embedded agent runtime (port 8001)
|
make dev # Run LangGraph server (port 2024)
|
||||||
make gateway # Run Gateway API without reload (port 8001)
|
make gateway # Run Gateway API (port 8001)
|
||||||
make lint # Run linter (ruff)
|
make lint # Run linter (ruff)
|
||||||
make format # Format code (ruff)
|
make format # Format code (ruff)
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Provides a pluggable channel system that connects external messaging platforms
|
Provides a pluggable channel system that connects external messaging platforms
|
||||||
(Feishu/Lark, Slack, Telegram) to the DeerFlow agent via the ChannelManager,
|
(Feishu/Lark, Slack, Telegram) to the DeerFlow agent via the ChannelManager,
|
||||||
which uses ``langgraph-sdk`` to communicate with Gateway's LangGraph-compatible API.
|
which uses ``langgraph-sdk`` to communicate with the underlying LangGraph Server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
|||||||
@@ -31,10 +31,6 @@ class Channel(ABC):
|
|||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
return self._running
|
return self._running
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# -- lifecycle ---------------------------------------------------------
|
# -- lifecycle ---------------------------------------------------------
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -1,740 +0,0 @@
|
|||||||
"""DingTalk channel implementation."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from app.channels.base import Channel
|
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DINGTALK_API_BASE = "https://api.dingtalk.com"
|
|
||||||
|
|
||||||
_TOKEN_REFRESH_MARGIN_SECONDS = 300
|
|
||||||
|
|
||||||
_CONVERSATION_TYPE_P2P = "1"
|
|
||||||
_CONVERSATION_TYPE_GROUP = "2"
|
|
||||||
|
|
||||||
_MAX_UPLOAD_SIZE_BYTES = 20 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_conversation_type(raw: Any) -> str:
|
|
||||||
"""Normalize ``conversationType`` to ``"1"`` (P2P) or ``"2"`` (group).
|
|
||||||
|
|
||||||
Stream payloads may send int or string values.
|
|
||||||
"""
|
|
||||||
if raw is None:
|
|
||||||
return _CONVERSATION_TYPE_P2P
|
|
||||||
s = str(raw).strip()
|
|
||||||
if s == _CONVERSATION_TYPE_GROUP:
|
|
||||||
return _CONVERSATION_TYPE_GROUP
|
|
||||||
return _CONVERSATION_TYPE_P2P
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|
||||||
if allowed_users is None:
|
|
||||||
return set()
|
|
||||||
if isinstance(allowed_users, str):
|
|
||||||
values = [allowed_users]
|
|
||||||
elif isinstance(allowed_users, (list, tuple, set)):
|
|
||||||
values = allowed_users
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"DingTalk allowed_users should be a list of user IDs; treating %s as one string value",
|
|
||||||
type(allowed_users).__name__,
|
|
||||||
)
|
|
||||||
values = [allowed_users]
|
|
||||||
return {str(uid) for uid in values if str(uid)}
|
|
||||||
|
|
||||||
|
|
||||||
def _is_dingtalk_command(text: str) -> bool:
|
|
||||||
if not text.startswith("/"):
|
|
||||||
return False
|
|
||||||
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_rich_text(rich_text_list: list) -> str:
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in rich_text_list:
|
|
||||||
if isinstance(item, dict) and "text" in item:
|
|
||||||
parts.append(item["text"])
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
_FENCED_CODE_BLOCK_RE = re.compile(r"```(\w*)\n(.*?)```", re.DOTALL)
|
|
||||||
_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
|
|
||||||
_HORIZONTAL_RULE_RE = re.compile(r"^-{3,}$", re.MULTILINE)
|
|
||||||
_TABLE_SEPARATOR_RE = re.compile(r"^\|[-:| ]+\|$", re.MULTILINE)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_markdown_table(text: str) -> str:
|
|
||||||
# DingTalk sampleMarkdown does not render pipe-delimited tables.
|
|
||||||
lines = text.split("\n")
|
|
||||||
result: list[str] = []
|
|
||||||
i = 0
|
|
||||||
while i < len(lines):
|
|
||||||
line = lines[i]
|
|
||||||
# Detect table: header row followed by separator row
|
|
||||||
if i + 1 < len(lines) and line.strip().startswith("|") and _TABLE_SEPARATOR_RE.match(lines[i + 1].strip()):
|
|
||||||
headers = [h.strip() for h in line.strip().strip("|").split("|")]
|
|
||||||
i += 2 # skip header + separator
|
|
||||||
while i < len(lines) and lines[i].strip().startswith("|"):
|
|
||||||
cells = [c.strip() for c in lines[i].strip().strip("|").split("|")]
|
|
||||||
for h, c in zip(headers, cells):
|
|
||||||
result.append(f"> **{h}**: {c}")
|
|
||||||
result.append("")
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
result.append(line)
|
|
||||||
i += 1
|
|
||||||
return "\n".join(result)
|
|
||||||
|
|
||||||
|
|
||||||
def _adapt_markdown_for_dingtalk(text: str) -> str:
|
|
||||||
"""Adapt markdown for DingTalk's limited sampleMarkdown renderer."""
|
|
||||||
|
|
||||||
def _code_block_to_quote(match: re.Match) -> str:
|
|
||||||
lang = match.group(1)
|
|
||||||
code = match.group(2).rstrip("\n")
|
|
||||||
prefix = f"> **{lang}**\n" if lang else ""
|
|
||||||
quoted_lines = "\n".join(f"> {line}" for line in code.split("\n"))
|
|
||||||
return f"{prefix}{quoted_lines}\n"
|
|
||||||
|
|
||||||
text = _FENCED_CODE_BLOCK_RE.sub(_code_block_to_quote, text)
|
|
||||||
text = _INLINE_CODE_RE.sub(r"**\1**", text)
|
|
||||||
text = _convert_markdown_table(text)
|
|
||||||
text = _HORIZONTAL_RULE_RE.sub("───────────", text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class DingTalkChannel(Channel):
|
|
||||||
"""DingTalk IM channel using Stream Push (WebSocket, no public IP needed)."""
|
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
|
||||||
super().__init__(name="dingtalk", bus=bus, config=config)
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._client_id: str = ""
|
|
||||||
self._client_secret: str = ""
|
|
||||||
self._allowed_users: set[str] = _normalize_allowed_users(config.get("allowed_users"))
|
|
||||||
self._cached_token: str = ""
|
|
||||||
self._token_expires_at: float = 0.0
|
|
||||||
self._token_lock = asyncio.Lock()
|
|
||||||
self._card_template_id: str = config.get("card_template_id", "")
|
|
||||||
self._card_track_ids: dict[str, str] = {}
|
|
||||||
self._dingtalk_client: Any = None
|
|
||||||
self._stream_client: Any = None
|
|
||||||
self._incoming_messages: dict[str, Any] = {}
|
|
||||||
self._incoming_messages_lock = threading.Lock()
|
|
||||||
self._card_repliers: dict[str, Any] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return bool(self._card_template_id)
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._running:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
import dingtalk_stream # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
logger.error("dingtalk-stream is not installed. Install it with: uv add dingtalk-stream")
|
|
||||||
return
|
|
||||||
|
|
||||||
client_id = self.config.get("client_id", "")
|
|
||||||
client_secret = self.config.get("client_secret", "")
|
|
||||||
|
|
||||||
if not client_id or not client_secret:
|
|
||||||
logger.error("DingTalk channel requires client_id and client_secret")
|
|
||||||
return
|
|
||||||
|
|
||||||
self._client_id = client_id
|
|
||||||
self._client_secret = client_secret
|
|
||||||
self._main_loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
if self._card_template_id:
|
|
||||||
logger.info("[DingTalk] AI Card mode enabled (template=%s)", self._card_template_id)
|
|
||||||
|
|
||||||
self._running = True
|
|
||||||
self.bus.subscribe_outbound(self._on_outbound)
|
|
||||||
|
|
||||||
self._thread = threading.Thread(
|
|
||||||
target=self._run_stream,
|
|
||||||
args=(client_id, client_secret),
|
|
||||||
daemon=True,
|
|
||||||
)
|
|
||||||
self._thread.start()
|
|
||||||
logger.info("DingTalk channel started")
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
self._running = False
|
|
||||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
|
||||||
|
|
||||||
stream_client = self._stream_client
|
|
||||||
if stream_client is not None:
|
|
||||||
try:
|
|
||||||
if hasattr(stream_client, "disconnect"):
|
|
||||||
stream_client.disconnect()
|
|
||||||
except Exception:
|
|
||||||
logger.debug("[DingTalk] error disconnecting stream client", exc_info=True)
|
|
||||||
|
|
||||||
self._dingtalk_client = None
|
|
||||||
self._stream_client = None
|
|
||||||
with self._incoming_messages_lock:
|
|
||||||
self._incoming_messages.clear()
|
|
||||||
self._card_repliers.clear()
|
|
||||||
self._card_track_ids.clear()
|
|
||||||
if self._thread:
|
|
||||||
self._thread.join(timeout=5)
|
|
||||||
self._thread = None
|
|
||||||
logger.info("DingTalk channel stopped")
|
|
||||||
|
|
||||||
def _resolve_routing(self, msg: OutboundMessage) -> tuple[str, str, str]:
|
|
||||||
"""Return (conversation_type, sender_staff_id, conversation_id).
|
|
||||||
|
|
||||||
Uses msg.chat_id as the primary routing key; metadata as fallback.
|
|
||||||
"""
|
|
||||||
conversation_type = _normalize_conversation_type(msg.metadata.get("conversation_type"))
|
|
||||||
sender_staff_id = msg.metadata.get("sender_staff_id", "")
|
|
||||||
conversation_id = msg.metadata.get("conversation_id", "")
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
|
||||||
conversation_id = msg.chat_id or conversation_id
|
|
||||||
else:
|
|
||||||
sender_staff_id = msg.chat_id or sender_staff_id
|
|
||||||
return conversation_type, sender_staff_id, conversation_id
|
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
|
||||||
conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
|
|
||||||
robot_code = self._client_id
|
|
||||||
|
|
||||||
# Card mode: stream update to existing AI card
|
|
||||||
source_key = self._make_card_source_key_from_outbound(msg)
|
|
||||||
out_track_id = self._card_track_ids.get(source_key)
|
|
||||||
|
|
||||||
# ``card_template_id`` enables ``runs.stream`` (non-final + final outbounds).
|
|
||||||
# If card creation failed, skip non-final chunks to avoid duplicate messages.
|
|
||||||
if self._card_template_id and not out_track_id and not msg.is_final:
|
|
||||||
return
|
|
||||||
|
|
||||||
if out_track_id:
|
|
||||||
try:
|
|
||||||
await self._stream_update_card(
|
|
||||||
out_track_id,
|
|
||||||
msg.text,
|
|
||||||
is_finalize=msg.is_final,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("[DingTalk] card stream failed, falling back to sampleMarkdown")
|
|
||||||
if msg.is_final:
|
|
||||||
self._card_track_ids.pop(source_key, None)
|
|
||||||
self._card_repliers.pop(out_track_id, None)
|
|
||||||
await self._send_markdown_fallback(robot_code, conversation_type, sender_staff_id, conversation_id, msg.text)
|
|
||||||
return
|
|
||||||
if msg.is_final:
|
|
||||||
self._card_track_ids.pop(source_key, None)
|
|
||||||
self._card_repliers.pop(out_track_id, None)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Non-card mode: send sampleMarkdown with retry
|
|
||||||
last_exc: Exception | None = None
|
|
||||||
for attempt in range(_max_retries):
|
|
||||||
try:
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
|
||||||
await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
|
|
||||||
else:
|
|
||||||
await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
|
|
||||||
return
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
delay = 2**attempt
|
|
||||||
logger.warning(
|
|
||||||
"[DingTalk] send failed (attempt %d/%d), retrying in %ds: %s",
|
|
||||||
attempt + 1,
|
|
||||||
_max_retries,
|
|
||||||
delay,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
logger.error("[DingTalk] send failed after %d attempts: %s", _max_retries, last_exc)
|
|
||||||
if last_exc is None:
|
|
||||||
raise RuntimeError("DingTalk send failed without an exception from any attempt")
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
async def _send_markdown_fallback(
|
|
||||||
self,
|
|
||||||
robot_code: str,
|
|
||||||
conversation_type: str,
|
|
||||||
sender_staff_id: str,
|
|
||||||
conversation_id: str,
|
|
||||||
text: str,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
|
||||||
await self._send_group_message(robot_code, conversation_id, text)
|
|
||||||
else:
|
|
||||||
await self._send_p2p_message(robot_code, sender_staff_id, text)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[DingTalk] markdown fallback also failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
|
||||||
if attachment.size > _MAX_UPLOAD_SIZE_BYTES:
|
|
||||||
logger.warning("[DingTalk] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
|
|
||||||
return False
|
|
||||||
|
|
||||||
conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
|
|
||||||
robot_code = self._client_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
media_id = await self._upload_media(attachment.actual_path, "image" if attachment.is_image else "file")
|
|
||||||
if not media_id:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if attachment.is_image:
|
|
||||||
msg_key = "sampleImageMsg"
|
|
||||||
msg_param = json.dumps({"photoURL": media_id})
|
|
||||||
else:
|
|
||||||
msg_key = "sampleFile"
|
|
||||||
msg_param = json.dumps(
|
|
||||||
{
|
|
||||||
"fileUrl": media_id,
|
|
||||||
"fileName": attachment.filename,
|
|
||||||
"fileSize": str(attachment.size),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
token = await self._get_access_token()
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
|
|
||||||
headers=self._api_headers(token),
|
|
||||||
json={
|
|
||||||
"msgKey": msg_key,
|
|
||||||
"msgParam": msg_param,
|
|
||||||
"robotCode": robot_code,
|
|
||||||
"openConversationId": conversation_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
|
|
||||||
headers=self._api_headers(token),
|
|
||||||
json={
|
|
||||||
"msgKey": msg_key,
|
|
||||||
"msgParam": msg_param,
|
|
||||||
"robotCode": robot_code,
|
|
||||||
"userIds": [sender_staff_id],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
logger.info("[DingTalk] file sent: %s", attachment.filename)
|
|
||||||
return True
|
|
||||||
except (httpx.HTTPError, OSError, ValueError, TypeError, AttributeError):
|
|
||||||
logger.exception("[DingTalk] failed to send file: %s", attachment.filename)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# -- stream client (runs in dedicated thread) --------------------------
|
|
||||||
|
|
||||||
def _run_stream(self, client_id: str, client_secret: str) -> None:
|
|
||||||
try:
|
|
||||||
import dingtalk_stream
|
|
||||||
|
|
||||||
credential = dingtalk_stream.Credential(client_id, client_secret)
|
|
||||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
|
||||||
self._stream_client = client
|
|
||||||
client.register_callback_handler(
|
|
||||||
dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
|
|
||||||
_DingTalkMessageHandler(self),
|
|
||||||
)
|
|
||||||
client.start_forever()
|
|
||||||
except Exception:
|
|
||||||
if self._running:
|
|
||||||
logger.exception("DingTalk Stream Push error")
|
|
||||||
finally:
|
|
||||||
self._stream_client = None
|
|
||||||
|
|
||||||
def _on_chatbot_message(self, message: Any) -> None:
|
|
||||||
if not self._running:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
sender_staff_id = message.sender_staff_id or ""
|
|
||||||
conversation_type = _normalize_conversation_type(message.conversation_type)
|
|
||||||
conversation_id = message.conversation_id or ""
|
|
||||||
msg_id = message.message_id or ""
|
|
||||||
sender_nick = message.sender_nick or ""
|
|
||||||
|
|
||||||
if self._allowed_users and sender_staff_id not in self._allowed_users:
|
|
||||||
logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
text = self._extract_text(message)
|
|
||||||
if not text:
|
|
||||||
logger.info("[DingTalk] empty text, ignoring message")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text=%r",
|
|
||||||
conversation_type,
|
|
||||||
msg_id,
|
|
||||||
sender_staff_id,
|
|
||||||
sender_nick,
|
|
||||||
text[:100],
|
|
||||||
)
|
|
||||||
|
|
||||||
if _is_dingtalk_command(text):
|
|
||||||
msg_type = InboundMessageType.COMMAND
|
|
||||||
else:
|
|
||||||
msg_type = InboundMessageType.CHAT
|
|
||||||
|
|
||||||
# P2P: topic_id=None (single thread per user, like Telegram private chat)
|
|
||||||
# Group: topic_id=msg_id (each new message starts a new topic, like Feishu)
|
|
||||||
topic_id: str | None = msg_id if conversation_type == _CONVERSATION_TYPE_GROUP else None
|
|
||||||
|
|
||||||
# chat_id uses conversation_id for groups, sender_staff_id for P2P
|
|
||||||
chat_id = conversation_id if conversation_type == _CONVERSATION_TYPE_GROUP else sender_staff_id
|
|
||||||
|
|
||||||
inbound = self._make_inbound(
|
|
||||||
chat_id=chat_id,
|
|
||||||
user_id=sender_staff_id,
|
|
||||||
text=text,
|
|
||||||
msg_type=msg_type,
|
|
||||||
thread_ts=msg_id,
|
|
||||||
metadata={
|
|
||||||
"conversation_type": conversation_type,
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
"sender_staff_id": sender_staff_id,
|
|
||||||
"sender_nick": sender_nick,
|
|
||||||
"message_id": msg_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
inbound.topic_id = topic_id
|
|
||||||
|
|
||||||
if self._card_template_id:
|
|
||||||
source_key = self._make_card_source_key(inbound)
|
|
||||||
with self._incoming_messages_lock:
|
|
||||||
self._incoming_messages[source_key] = message
|
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
|
||||||
logger.info("[DingTalk] publishing inbound message to bus (type=%s, msg_id=%s)", msg_type.value, msg_id)
|
|
||||||
fut = asyncio.run_coroutine_threadsafe(
|
|
||||||
self._prepare_inbound(chat_id, inbound),
|
|
||||||
self._main_loop,
|
|
||||||
)
|
|
||||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "prepare_inbound", mid))
|
|
||||||
else:
|
|
||||||
logger.warning("[DingTalk] main loop not running, cannot publish inbound message")
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[DingTalk] error processing chatbot message")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_text(message: Any) -> str:
|
|
||||||
msg_type = message.message_type
|
|
||||||
if msg_type == "text" and message.text:
|
|
||||||
return message.text.content.strip()
|
|
||||||
if msg_type == "richText" and message.rich_text_content:
|
|
||||||
return _extract_text_from_rich_text(message.rich_text_content.rich_text_list).strip()
|
|
||||||
return ""
|
|
||||||
|
|
||||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
|
||||||
# Running reply must finish before publish_inbound so AI card tracks are
|
|
||||||
# registered before the manager emits streaming outbounds.
|
|
||||||
await self._send_running_reply(chat_id, inbound)
|
|
||||||
await self.bus.publish_inbound(inbound)
|
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
|
||||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
|
||||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
|
||||||
conversation_id = inbound.metadata.get("conversation_id", "")
|
|
||||||
text = "\u23f3 Working on it..."
|
|
||||||
|
|
||||||
try:
|
|
||||||
if self._card_template_id:
|
|
||||||
source_key = self._make_card_source_key(inbound)
|
|
||||||
with self._incoming_messages_lock:
|
|
||||||
chatbot_message = self._incoming_messages.pop(source_key, None)
|
|
||||||
out_track_id = await self._create_and_deliver_card(
|
|
||||||
text,
|
|
||||||
chatbot_message=chatbot_message,
|
|
||||||
)
|
|
||||||
if out_track_id:
|
|
||||||
self._card_track_ids[source_key] = out_track_id
|
|
||||||
logger.info("[DingTalk] AI card running reply sent for chat=%s", chat_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
robot_code = self._client_id
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
|
||||||
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
|
||||||
else:
|
|
||||||
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
|
||||||
logger.info("[DingTalk] 'Working on it...' reply sent for chat=%s", chat_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[DingTalk] failed to send running reply for chat=%s", chat_id)
|
|
||||||
|
|
||||||
# -- DingTalk API helpers ----------------------------------------------
|
|
||||||
|
|
||||||
async def _get_access_token(self) -> str:
|
|
||||||
if self._cached_token and time.monotonic() < self._token_expires_at:
|
|
||||||
return self._cached_token
|
|
||||||
async with self._token_lock:
|
|
||||||
if self._cached_token and time.monotonic() < self._token_expires_at:
|
|
||||||
return self._cached_token
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/oauth2/accessToken",
|
|
||||||
json={"appKey": self._client_id, "appSecret": self._client_secret}, # DingTalk API field names
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
raise ValueError(f"DingTalk access token response must be a JSON object, got {type(data).__name__}")
|
|
||||||
|
|
||||||
access_token = data.get("accessToken")
|
|
||||||
if not isinstance(access_token, str) or not access_token.strip():
|
|
||||||
raise ValueError("DingTalk access token response did not contain a usable accessToken")
|
|
||||||
|
|
||||||
raw_expires_in = data.get("expireIn", 7200)
|
|
||||||
try:
|
|
||||||
expires_in = int(raw_expires_in)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
logger.warning("[DingTalk] invalid expireIn value %r, using default 7200s", raw_expires_in)
|
|
||||||
expires_in = 7200
|
|
||||||
|
|
||||||
self._cached_token = access_token.strip()
|
|
||||||
self._token_expires_at = time.monotonic() + expires_in - _TOKEN_REFRESH_MARGIN_SECONDS
|
|
||||||
return self._cached_token
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _api_headers(token: str) -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
"x-acs-dingtalk-access-token": token,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _send_text_message_to_user(self, robot_code: str, user_id: str, text: str) -> None:
|
|
||||||
token = await self._get_access_token()
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
|
|
||||||
headers=self._api_headers(token),
|
|
||||||
json={
|
|
||||||
"msgKey": "sampleText",
|
|
||||||
"msgParam": json.dumps({"content": text}),
|
|
||||||
"robotCode": robot_code,
|
|
||||||
"userIds": [user_id],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def _send_text_message_to_group(self, robot_code: str, conversation_id: str, text: str) -> None:
|
|
||||||
token = await self._get_access_token()
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
|
|
||||||
headers=self._api_headers(token),
|
|
||||||
json={
|
|
||||||
"msgKey": "sampleText",
|
|
||||||
"msgParam": json.dumps({"content": text}),
|
|
||||||
"robotCode": robot_code,
|
|
||||||
"openConversationId": conversation_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def _send_p2p_message(self, robot_code: str, user_id: str, text: str) -> None:
|
|
||||||
text = _adapt_markdown_for_dingtalk(text)
|
|
||||||
token = await self._get_access_token()
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
|
|
||||||
headers=self._api_headers(token),
|
|
||||||
json={
|
|
||||||
"msgKey": "sampleMarkdown",
|
|
||||||
"msgParam": json.dumps({"title": "DeerFlow", "text": text}),
|
|
||||||
"robotCode": robot_code,
|
|
||||||
"userIds": [user_id],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
if data.get("processQueryKey"):
|
|
||||||
logger.info("[DingTalk] P2P message sent to user=%s", user_id)
|
|
||||||
else:
|
|
||||||
logger.warning("[DingTalk] P2P send response: %s", data)
|
|
||||||
|
|
||||||
async def _send_group_message(
|
|
||||||
self,
|
|
||||||
robot_code: str,
|
|
||||||
conversation_id: str,
|
|
||||||
text: str,
|
|
||||||
*,
|
|
||||||
at_user_ids: list[str] | None = None, # noqa: ARG002
|
|
||||||
) -> None:
|
|
||||||
# at_user_ids accepted for call-site compatibility but not passed to the API
|
|
||||||
# (sampleMarkdown does not support @mentions).
|
|
||||||
text = _adapt_markdown_for_dingtalk(text)
|
|
||||||
token = await self._get_access_token()
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
|
|
||||||
headers=self._api_headers(token),
|
|
||||||
json={
|
|
||||||
"msgKey": "sampleMarkdown",
|
|
||||||
"msgParam": json.dumps({"title": "DeerFlow", "text": text}),
|
|
||||||
"robotCode": robot_code,
|
|
||||||
"openConversationId": conversation_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
if data.get("processQueryKey"):
|
|
||||||
logger.info("[DingTalk] group message sent to conversation=%s", conversation_id)
|
|
||||||
else:
|
|
||||||
logger.warning("[DingTalk] group send response: %s", data)
|
|
||||||
|
|
||||||
# -- AI Card streaming helpers -------------------------------------------
|
|
||||||
|
|
||||||
def _make_card_source_key(self, inbound: InboundMessage) -> str:
|
|
||||||
m = inbound.metadata
|
|
||||||
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{m.get('message_id', '')}"
|
|
||||||
|
|
||||||
def _make_card_source_key_from_outbound(self, msg: OutboundMessage) -> str:
|
|
||||||
m = msg.metadata
|
|
||||||
correlation_id = m.get("message_id") or msg.thread_ts or ""
|
|
||||||
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{correlation_id}"
|
|
||||||
|
|
||||||
async def _create_and_deliver_card(
|
|
||||||
self,
|
|
||||||
initial_text: str,
|
|
||||||
*,
|
|
||||||
chatbot_message: Any = None,
|
|
||||||
) -> str | None:
|
|
||||||
if self._dingtalk_client is None or chatbot_message is None:
|
|
||||||
logger.warning("[DingTalk] SDK client or chatbot_message unavailable, skipping AI card")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dingtalk_stream.card_replier import AICardReplier
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("[DingTalk] dingtalk-stream card_replier not available")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
replier = AICardReplier(self._dingtalk_client, chatbot_message)
|
|
||||||
card_instance_id = await replier.async_create_and_deliver_card(
|
|
||||||
card_template_id=self._card_template_id,
|
|
||||||
card_data={"content": initial_text},
|
|
||||||
)
|
|
||||||
if not card_instance_id:
|
|
||||||
return None
|
|
||||||
|
|
||||||
self._card_repliers[card_instance_id] = replier
|
|
||||||
logger.info("[DingTalk] AI card created: outTrackId=%s", card_instance_id)
|
|
||||||
return card_instance_id
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[DingTalk] failed to create AI card")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _stream_update_card(
|
|
||||||
self,
|
|
||||||
out_track_id: str,
|
|
||||||
content: str,
|
|
||||||
*,
|
|
||||||
is_finalize: bool = False,
|
|
||||||
is_error: bool = False,
|
|
||||||
) -> None:
|
|
||||||
replier = self._card_repliers.get(out_track_id)
|
|
||||||
if not replier:
|
|
||||||
raise RuntimeError(f"No AICardReplier found for track ID {out_track_id}")
|
|
||||||
|
|
||||||
await replier.async_streaming(
|
|
||||||
card_instance_id=out_track_id,
|
|
||||||
content_key="content",
|
|
||||||
content_value=content,
|
|
||||||
append=False,
|
|
||||||
finished=is_finalize,
|
|
||||||
failed=is_error,
|
|
||||||
)
|
|
||||||
|
|
||||||
# -- media upload --------------------------------------------------------
|
|
||||||
|
|
||||||
async def _upload_media(self, file_path: str | Path, media_type: str) -> str | None:
|
|
||||||
try:
|
|
||||||
file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
|
|
||||||
token = await self._get_access_token()
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{DINGTALK_API_BASE}/v1.0/files/upload",
|
|
||||||
headers={"x-acs-dingtalk-access-token": token},
|
|
||||||
files={"file": ("upload", file_bytes)},
|
|
||||||
data={"type": media_type},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
try:
|
|
||||||
payload = response.json()
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.exception("[DingTalk] failed to decode upload response JSON: %s", file_path)
|
|
||||||
return None
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
logger.warning("[DingTalk] unexpected upload response type %s for %s", type(payload).__name__, file_path)
|
|
||||||
return None
|
|
||||||
return payload.get("mediaId")
|
|
||||||
except (httpx.HTTPError, OSError):
|
|
||||||
logger.exception("[DingTalk] failed to upload media: %s", file_path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _log_future_error(fut: Any, name: str, msg_id: str) -> None:
|
|
||||||
try:
|
|
||||||
exc = fut.exception()
|
|
||||||
if exc:
|
|
||||||
logger.error("[DingTalk] %s failed for msg_id=%s: %s", name, msg_id, exc)
|
|
||||||
except (asyncio.CancelledError, asyncio.InvalidStateError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _DingTalkMessageHandler:
|
|
||||||
"""Callback handler registered with dingtalk-stream."""
|
|
||||||
|
|
||||||
def __init__(self, channel: DingTalkChannel) -> None:
|
|
||||||
self._channel = channel
|
|
||||||
|
|
||||||
def pre_start(self) -> None:
|
|
||||||
if hasattr(self, "dingtalk_client") and self.dingtalk_client is not None:
|
|
||||||
self._channel._dingtalk_client = self.dingtalk_client
|
|
||||||
|
|
||||||
async def raw_process(self, callback_message: Any) -> Any:
|
|
||||||
import dingtalk_stream
|
|
||||||
from dingtalk_stream.frames import Headers
|
|
||||||
|
|
||||||
code, message = await self.process(callback_message)
|
|
||||||
ack_message = dingtalk_stream.AckMessage()
|
|
||||||
ack_message.code = code
|
|
||||||
ack_message.headers.message_id = callback_message.headers.message_id
|
|
||||||
ack_message.headers.content_type = Headers.CONTENT_TYPE_APPLICATION_JSON
|
|
||||||
ack_message.data = {"response": message}
|
|
||||||
return ack_message
|
|
||||||
|
|
||||||
async def process(self, callback: Any) -> tuple[int, str]:
|
|
||||||
import dingtalk_stream
|
|
||||||
|
|
||||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
|
||||||
self._channel._on_chatbot_message(incoming_message)
|
|
||||||
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
|
|
||||||
@@ -1,553 +0,0 @@
|
|||||||
"""Discord channel integration using discord.py."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.channels.base import Channel
|
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_DISCORD_MAX_MESSAGE_LEN = 2000
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannel(Channel):
|
|
||||||
"""Discord bot channel.
|
|
||||||
|
|
||||||
Configuration keys (in ``config.yaml`` under ``channels.discord``):
|
|
||||||
- ``bot_token``: Discord Bot token.
|
|
||||||
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
|
|
||||||
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
|
|
||||||
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
|
|
||||||
(even when mention_only is true). Use for channels where you want the bot to respond
|
|
||||||
without mentions. Empty = mention_only applies everywhere.
|
|
||||||
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
|
|
||||||
Default: same as ``mention_only``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
|
||||||
super().__init__(name="discord", bus=bus, config=config)
|
|
||||||
self._bot_token = str(config.get("bot_token", "")).strip()
|
|
||||||
self._allowed_guilds: set[int] = set()
|
|
||||||
for guild_id in config.get("allowed_guilds", []):
|
|
||||||
try:
|
|
||||||
self._allowed_guilds.add(int(guild_id))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
continue
|
|
||||||
self._mention_only: bool = bool(config.get("mention_only", False))
|
|
||||||
self._thread_mode: bool = config.get("thread_mode", self._mention_only)
|
|
||||||
self._allowed_channels: set[str] = set()
|
|
||||||
for channel_id in config.get("allowed_channels", []):
|
|
||||||
self._allowed_channels.add(str(channel_id))
|
|
||||||
|
|
||||||
# Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
|
|
||||||
# Uses a dedicated JSON file separate from ChannelStore, which maps IM
|
|
||||||
# conversations to DeerFlow thread IDs — a different concern.
|
|
||||||
self._active_threads: dict[str, str] = {}
|
|
||||||
# Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
|
|
||||||
self._active_thread_ids: set[str] = set()
|
|
||||||
# Lock protecting _active_threads and the JSON file from concurrent access.
|
|
||||||
# _run_client (Discord loop thread) and the main thread both read/write.
|
|
||||||
self._thread_store_lock = threading.Lock()
|
|
||||||
store = config.get("channel_store")
|
|
||||||
if store is not None:
|
|
||||||
self._thread_store_path = store._path.parent / "discord_threads.json"
|
|
||||||
else:
|
|
||||||
self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"
|
|
||||||
|
|
||||||
# Typing indicator management
|
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
|
||||||
|
|
||||||
self._client = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._discord_module = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._running:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
import discord
|
|
||||||
except ImportError:
|
|
||||||
logger.error("discord.py is not installed. Install it with: uv add discord.py")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self._bot_token:
|
|
||||||
logger.error("Discord channel requires bot_token")
|
|
||||||
return
|
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
|
||||||
intents.messages = True
|
|
||||||
intents.guilds = True
|
|
||||||
intents.message_content = True
|
|
||||||
|
|
||||||
client = discord.Client(
|
|
||||||
intents=intents,
|
|
||||||
allowed_mentions=discord.AllowedMentions.none(),
|
|
||||||
)
|
|
||||||
self._client = client
|
|
||||||
self._discord_module = discord
|
|
||||||
self._main_loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_message(message) -> None:
|
|
||||||
await self._on_message(message)
|
|
||||||
|
|
||||||
self._running = True
|
|
||||||
self.bus.subscribe_outbound(self._on_outbound)
|
|
||||||
|
|
||||||
self._thread = threading.Thread(target=self._run_client, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
self._load_active_threads()
|
|
||||||
logger.info("Discord channel started")
|
|
||||||
|
|
||||||
def _load_active_threads(self) -> None:
|
|
||||||
"""Restore Discord thread mappings from the dedicated JSON file on startup."""
|
|
||||||
with self._thread_store_lock:
|
|
||||||
try:
|
|
||||||
if not self._thread_store_path.exists():
|
|
||||||
logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
|
|
||||||
return
|
|
||||||
data = json.loads(self._thread_store_path.read_text())
|
|
||||||
self._active_threads.clear()
|
|
||||||
self._active_thread_ids.clear()
|
|
||||||
for channel_id, thread_id in data.items():
|
|
||||||
self._active_threads[channel_id] = thread_id
|
|
||||||
self._active_thread_ids.add(thread_id)
|
|
||||||
if self._active_threads:
|
|
||||||
logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to load thread mappings")
|
|
||||||
|
|
||||||
def _save_thread(self, channel_id: str, thread_id: str) -> None:
|
|
||||||
"""Persist a Discord thread mapping to the dedicated JSON file."""
|
|
||||||
with self._thread_store_lock:
|
|
||||||
try:
|
|
||||||
data: dict[str, str] = {}
|
|
||||||
if self._thread_store_path.exists():
|
|
||||||
data = json.loads(self._thread_store_path.read_text())
|
|
||||||
old_id = data.get(channel_id)
|
|
||||||
data[channel_id] = thread_id
|
|
||||||
# Update reverse-lookup set
|
|
||||||
if old_id:
|
|
||||||
self._active_thread_ids.discard(old_id)
|
|
||||||
self._active_thread_ids.add(thread_id)
|
|
||||||
self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._thread_store_path.write_text(json.dumps(data, indent=2))
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
self._running = False
|
|
||||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
|
||||||
|
|
||||||
# Cancel all active typing indicator tasks
|
|
||||||
for target_id, task in list(self._typing_tasks.items()):
|
|
||||||
if not task.done():
|
|
||||||
task.cancel()
|
|
||||||
logger.debug("[Discord] cancelled typing task for target %s", target_id)
|
|
||||||
self._typing_tasks.clear()
|
|
||||||
|
|
||||||
if self._client and self._discord_loop and self._discord_loop.is_running():
|
|
||||||
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
|
|
||||||
except TimeoutError:
|
|
||||||
logger.warning("[Discord] client close timed out after 10s")
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] error while closing client")
|
|
||||||
|
|
||||||
if self._thread:
|
|
||||||
self._thread.join(timeout=10)
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
self._client = None
|
|
||||||
self._discord_loop = None
|
|
||||||
self._discord_module = None
|
|
||||||
logger.info("Discord channel stopped")
|
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
|
||||||
# Stop typing indicator once we're sending the response
|
|
||||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
|
||||||
await asyncio.wrap_future(stop_future)
|
|
||||||
|
|
||||||
target = await self._resolve_target(msg)
|
|
||||||
if target is None:
|
|
||||||
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
|
||||||
return
|
|
||||||
|
|
||||||
text = msg.text or ""
|
|
||||||
for chunk in self._split_text(text):
|
|
||||||
send_future = asyncio.run_coroutine_threadsafe(target.send(chunk), self._discord_loop)
|
|
||||||
await asyncio.wrap_future(send_future)
|
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
|
||||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
|
||||||
await asyncio.wrap_future(stop_future)
|
|
||||||
|
|
||||||
target = await self._resolve_target(msg)
|
|
||||||
if target is None:
|
|
||||||
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self._discord_module is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
|
|
||||||
file = self._discord_module.File(fp, filename=attachment.filename)
|
|
||||||
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
|
|
||||||
await asyncio.wrap_future(send_future)
|
|
||||||
logger.info("[Discord] file uploaded: %s", attachment.filename)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
|
|
||||||
"""Starts a loop to send periodic typing indicators."""
|
|
||||||
target_id = thread_ts or chat_id
|
|
||||||
if target_id in self._typing_tasks:
|
|
||||||
return # Already typing for this target
|
|
||||||
|
|
||||||
async def _typing_loop():
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await channel.trigger_typing()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
task = asyncio.create_task(_typing_loop())
|
|
||||||
self._typing_tasks[target_id] = task
|
|
||||||
|
|
||||||
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
|
|
||||||
"""Stops the typing loop for a specific target."""
|
|
||||||
target_id = thread_ts or chat_id
|
|
||||||
task = self._typing_tasks.pop(target_id, None)
|
|
||||||
if task and not task.done():
|
|
||||||
task.cancel()
|
|
||||||
logger.debug("[Discord] stopped typing indicator for target %s", target_id)
|
|
||||||
|
|
||||||
async def _add_reaction(self, message) -> None:
|
|
||||||
"""Add a checkmark reaction to acknowledge the message was received."""
|
|
||||||
try:
|
|
||||||
await message.add_reaction("✅")
|
|
||||||
except Exception:
|
|
||||||
logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)
|
|
||||||
|
|
||||||
async def _on_message(self, message) -> None:
|
|
||||||
if not self._running or not self._client:
|
|
||||||
return
|
|
||||||
|
|
||||||
if message.author.bot:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._client.user and message.author.id == self._client.user.id:
|
|
||||||
return
|
|
||||||
|
|
||||||
guild = message.guild
|
|
||||||
if self._allowed_guilds:
|
|
||||||
if guild is None or guild.id not in self._allowed_guilds:
|
|
||||||
return
|
|
||||||
|
|
||||||
text = (message.content or "").strip()
|
|
||||||
if not text:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._discord_module is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Determine whether the bot is mentioned in this message
|
|
||||||
user = self._client.user if self._client else None
|
|
||||||
if user:
|
|
||||||
bot_mention = user.mention # <@ID>
|
|
||||||
alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant)
|
|
||||||
standard_mention = f"<@{user.id}>"
|
|
||||||
else:
|
|
||||||
bot_mention = None
|
|
||||||
alt_mention = None
|
|
||||||
standard_mention = ""
|
|
||||||
has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content)
|
|
||||||
|
|
||||||
# Strip mention from text for processing
|
|
||||||
if has_mention:
|
|
||||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
|
||||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
|
||||||
|
|
||||||
# --- Determine thread/channel routing and typing target ---
|
|
||||||
thread_id = None
|
|
||||||
chat_id = None
|
|
||||||
typing_target = None # The Discord object to type into
|
|
||||||
|
|
||||||
if isinstance(message.channel, self._discord_module.Thread):
|
|
||||||
# --- Message already inside a thread ---
|
|
||||||
thread_obj = message.channel
|
|
||||||
thread_id = str(thread_obj.id)
|
|
||||||
chat_id = str(thread_obj.parent_id or thread_obj.id)
|
|
||||||
typing_target = thread_obj
|
|
||||||
|
|
||||||
# If this is a known active thread, process normally
|
|
||||||
if thread_id in self._active_thread_ids:
|
|
||||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
|
||||||
inbound = self._make_inbound(
|
|
||||||
chat_id=chat_id,
|
|
||||||
user_id=str(message.author.id),
|
|
||||||
text=text,
|
|
||||||
msg_type=msg_type,
|
|
||||||
thread_ts=thread_id,
|
|
||||||
metadata={
|
|
||||||
"guild_id": str(guild.id) if guild else None,
|
|
||||||
"channel_id": str(message.channel.id),
|
|
||||||
"message_id": str(message.id),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
inbound.topic_id = thread_id
|
|
||||||
self._publish(inbound)
|
|
||||||
# Start typing indicator in the thread
|
|
||||||
if typing_target:
|
|
||||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
|
||||||
asyncio.create_task(self._add_reaction(message))
|
|
||||||
return
|
|
||||||
|
|
||||||
# Thread not tracked (orphaned) — create new thread and handle below
|
|
||||||
logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)
|
|
||||||
thread_id = None
|
|
||||||
typing_target = None
|
|
||||||
|
|
||||||
# At this point we're guaranteed to be in a channel, not a thread
|
|
||||||
# (the Thread case is handled above). Apply mention_only for all
|
|
||||||
# non-thread messages — no special case needed.
|
|
||||||
channel_id = str(message.channel.id)
|
|
||||||
|
|
||||||
# Check if there's an active thread for this channel
|
|
||||||
if channel_id in self._active_threads:
|
|
||||||
# respect mention_only: if enabled, only process messages that mention the bot
|
|
||||||
# (unless the channel is in allowed_channels)
|
|
||||||
# Messages within a thread are always allowed through (continuation).
|
|
||||||
# At this code point we know the message is in a channel, not a thread
|
|
||||||
# (Thread case handled above), so always apply the check.
|
|
||||||
if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
|
||||||
logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
|
|
||||||
return
|
|
||||||
# mention_only + fresh @ → create new thread instead of routing to existing one
|
|
||||||
if self._mention_only and has_mention:
|
|
||||||
thread_obj = await self._create_thread(message)
|
|
||||||
if thread_obj is not None:
|
|
||||||
target_thread_id = str(thread_obj.id)
|
|
||||||
self._active_threads[channel_id] = target_thread_id
|
|
||||||
self._save_thread(channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = thread_obj
|
|
||||||
logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id)
|
|
||||||
else:
|
|
||||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel
|
|
||||||
else:
|
|
||||||
# Existing session → route to the existing thread
|
|
||||||
target_thread_id = self._active_threads[channel_id]
|
|
||||||
logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = await self._get_channel_or_thread(target_thread_id)
|
|
||||||
elif self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
|
||||||
# Not mentioned and not in an allowed channel → skip
|
|
||||||
logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
|
|
||||||
return
|
|
||||||
elif self._mention_only and has_mention:
|
|
||||||
# First mention in this channel → create thread
|
|
||||||
thread_obj = await self._create_thread(message)
|
|
||||||
if thread_obj is not None:
|
|
||||||
target_thread_id = str(thread_obj.id)
|
|
||||||
self._active_threads[channel_id] = target_thread_id
|
|
||||||
self._save_thread(channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = thread_obj # Type into the new thread
|
|
||||||
logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name)
|
|
||||||
else:
|
|
||||||
# Fallback: thread creation failed (disabled/permissions), reply in channel
|
|
||||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel # Type into the channel
|
|
||||||
elif self._thread_mode:
|
|
||||||
# thread_mode but mention_only is False → create thread anyway for conversation grouping
|
|
||||||
thread_obj = await self._create_thread(message)
|
|
||||||
if thread_obj is None:
|
|
||||||
# Thread creation failed (disabled/permissions), fall back to channel replies
|
|
||||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel # Type into the channel
|
|
||||||
else:
|
|
||||||
target_thread_id = str(thread_obj.id)
|
|
||||||
self._active_threads[channel_id] = target_thread_id
|
|
||||||
self._save_thread(channel_id, target_thread_id)
|
|
||||||
thread_id = target_thread_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = thread_obj # Type into the new thread
|
|
||||||
else:
|
|
||||||
# No threading — reply directly in channel
|
|
||||||
thread_id = channel_id
|
|
||||||
chat_id = channel_id
|
|
||||||
typing_target = message.channel # Type into the channel
|
|
||||||
|
|
||||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
|
||||||
inbound = self._make_inbound(
|
|
||||||
chat_id=chat_id,
|
|
||||||
user_id=str(message.author.id),
|
|
||||||
text=text,
|
|
||||||
msg_type=msg_type,
|
|
||||||
thread_ts=thread_id,
|
|
||||||
metadata={
|
|
||||||
"guild_id": str(guild.id) if guild else None,
|
|
||||||
"channel_id": str(message.channel.id),
|
|
||||||
"message_id": str(message.id),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
inbound.topic_id = thread_id
|
|
||||||
|
|
||||||
# Start typing indicator in the correct target (thread or channel)
|
|
||||||
if typing_target:
|
|
||||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
|
||||||
|
|
||||||
self._publish(inbound)
|
|
||||||
asyncio.create_task(self._add_reaction(message))
|
|
||||||
|
|
||||||
def _publish(self, inbound) -> None:
|
|
||||||
"""Publish an inbound message to the main event loop."""
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
|
||||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
|
||||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
|
||||||
|
|
||||||
def _run_client(self) -> None:
|
|
||||||
self._discord_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._discord_loop)
|
|
||||||
try:
|
|
||||||
self._discord_loop.run_until_complete(self._client.start(self._bot_token))
|
|
||||||
except Exception:
|
|
||||||
if self._running:
|
|
||||||
logger.exception("Discord client error")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
if self._client and not self._client.is_closed():
|
|
||||||
self._discord_loop.run_until_complete(self._client.close())
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error during Discord shutdown")
|
|
||||||
|
|
||||||
async def _create_thread(self, message):
|
|
||||||
try:
|
|
||||||
if self._discord_module is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Only TextChannel (type 0) and NewsChannel (type 10) support threads
|
|
||||||
channel_type = message.channel.type
|
|
||||||
if channel_type not in (
|
|
||||||
self._discord_module.ChannelType.text,
|
|
||||||
self._discord_module.ChannelType.news,
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"[Discord] channel type %s (%s) does not support threads",
|
|
||||||
channel_type.value,
|
|
||||||
channel_type.name,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
|
|
||||||
return await message.create_thread(name=thread_name)
|
|
||||||
except self._discord_module.errors.HTTPException as exc:
|
|
||||||
if exc.code == 50024:
|
|
||||||
logger.info(
|
|
||||||
"[Discord] cannot create thread in channel %s (error code 50024): %s",
|
|
||||||
message.channel.id,
|
|
||||||
channel_type.name if (channel_type := message.channel.type) else "unknown",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.exception(
|
|
||||||
"[Discord] failed to create thread for message=%s (HTTPException %s)",
|
|
||||||
message.id,
|
|
||||||
exc.code,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _resolve_target(self, msg: OutboundMessage):
|
|
||||||
if not self._client or not self._discord_loop:
|
|
||||||
return None
|
|
||||||
|
|
||||||
target_ids: list[str] = []
|
|
||||||
if msg.thread_ts:
|
|
||||||
target_ids.append(msg.thread_ts)
|
|
||||||
if msg.chat_id and msg.chat_id not in target_ids:
|
|
||||||
target_ids.append(msg.chat_id)
|
|
||||||
|
|
||||||
for raw_id in target_ids:
|
|
||||||
target = await self._get_channel_or_thread(raw_id)
|
|
||||||
if target is not None:
|
|
||||||
return target
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _get_channel_or_thread(self, raw_id: str):
|
|
||||||
if not self._client or not self._discord_loop:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
target_id = int(raw_id)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), self._discord_loop)
|
|
||||||
try:
|
|
||||||
return await asyncio.wrap_future(get_future)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to resolve target id=%s", raw_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _fetch_channel(self, target_id: int):
|
|
||||||
if not self._client:
|
|
||||||
return None
|
|
||||||
|
|
||||||
channel = self._client.get_channel(target_id)
|
|
||||||
if channel is not None:
|
|
||||||
return channel
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await self._client.fetch_channel(target_id)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _split_text(text: str) -> list[str]:
|
|
||||||
if not text:
|
|
||||||
return [""]
|
|
||||||
|
|
||||||
chunks: list[str] = []
|
|
||||||
remaining = text
|
|
||||||
while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
|
|
||||||
split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
|
|
||||||
if split_at <= 0:
|
|
||||||
split_at = _DISCORD_MAX_MESSAGE_LEN
|
|
||||||
chunks.append(remaining[:split_at])
|
|
||||||
remaining = remaining[split_at:].lstrip("\n")
|
|
||||||
|
|
||||||
if remaining:
|
|
||||||
chunks.append(remaining)
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
@@ -63,10 +63,6 @@ class FeishuChannel(Channel):
|
|||||||
self._GetMessageResourceRequest = None
|
self._GetMessageResourceRequest = None
|
||||||
self._thread_lock = threading.Lock()
|
self._thread_lock = threading.Lock()
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via Gateway."""
|
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via LangGraph Server."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -17,13 +17,11 @@ from langgraph_sdk.errors import ConflictError
|
|||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
|
||||||
from app.gateway.internal_auth import create_internal_auth_headers
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_LANGGRAPH_URL = "http://localhost:8001/api"
|
DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
|
||||||
DEFAULT_GATEWAY_URL = "http://localhost:8001"
|
DEFAULT_GATEWAY_URL = "http://localhost:8001"
|
||||||
DEFAULT_ASSISTANT_ID = "lead_agent"
|
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
||||||
@@ -38,8 +36,6 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
|||||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||||
|
|
||||||
CHANNEL_CAPABILITIES = {
|
CHANNEL_CAPABILITIES = {
|
||||||
"dingtalk": {"supports_streaming": False},
|
|
||||||
"discord": {"supports_streaming": False},
|
|
||||||
"feishu": {"supports_streaming": True},
|
"feishu": {"supports_streaming": True},
|
||||||
"slack": {"supports_streaming": False},
|
"slack": {"supports_streaming": False},
|
||||||
"telegram": {"supports_streaming": False},
|
"telegram": {"supports_streaming": False},
|
||||||
@@ -49,13 +45,6 @@ CHANNEL_CAPABILITIES = {
|
|||||||
|
|
||||||
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
|
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
|
||||||
|
|
||||||
_METADATA_DROP_KEYS = frozenset({"raw_message", "ref_msg"})
|
|
||||||
|
|
||||||
|
|
||||||
def _slim_metadata(meta: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Return a shallow copy of *meta* with known-large keys removed."""
|
|
||||||
return {k: v for k, v in meta.items() if k not in _METADATA_DROP_KEYS}
|
|
||||||
|
|
||||||
|
|
||||||
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
|
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
|
||||||
|
|
||||||
@@ -146,13 +135,6 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
|
|||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
def _strip_loop_warning_text(text: str) -> str:
|
|
||||||
"""Remove middleware-authored loop warning lines from display text."""
|
|
||||||
if "[LOOP DETECTED]" not in text:
|
|
||||||
return text
|
|
||||||
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_response_text(result: dict | list) -> str:
|
def _extract_response_text(result: dict | list) -> str:
|
||||||
"""Extract the last AI message text from a LangGraph runs.wait result.
|
"""Extract the last AI message text from a LangGraph runs.wait result.
|
||||||
|
|
||||||
@@ -162,7 +144,7 @@ def _extract_response_text(result: dict | list) -> str:
|
|||||||
Handles special cases:
|
Handles special cases:
|
||||||
- Regular AI text responses
|
- Regular AI text responses
|
||||||
- Clarification interrupts (``ask_clarification`` tool messages)
|
- Clarification interrupts (``ask_clarification`` tool messages)
|
||||||
- Strips loop-detection warnings attached to tool-call AI messages
|
- AI messages with tool_calls but no text content
|
||||||
"""
|
"""
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
messages = result
|
messages = result
|
||||||
@@ -192,12 +174,7 @@ def _extract_response_text(result: dict | list) -> str:
|
|||||||
# Regular AI message with text content
|
# Regular AI message with text content
|
||||||
if msg_type == "ai":
|
if msg_type == "ai":
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
has_tool_calls = bool(msg.get("tool_calls"))
|
|
||||||
if isinstance(content, str) and content:
|
if isinstance(content, str) and content:
|
||||||
if has_tool_calls:
|
|
||||||
content = _strip_loop_warning_text(content)
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
return content
|
return content
|
||||||
# content can be a list of content blocks
|
# content can be a list of content blocks
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
@@ -208,8 +185,6 @@ def _extract_response_text(result: dict | list) -> str:
|
|||||||
elif isinstance(block, str):
|
elif isinstance(block, str):
|
||||||
parts.append(block)
|
parts.append(block)
|
||||||
text = "".join(parts)
|
text = "".join(parts)
|
||||||
if has_tool_calls:
|
|
||||||
text = _strip_loop_warning_text(text)
|
|
||||||
if text:
|
if text:
|
||||||
return text
|
return text
|
||||||
return ""
|
return ""
|
||||||
@@ -434,13 +409,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
if not msg.files:
|
if not msg.files:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||||
UnsafeUploadPathError,
|
|
||||||
claim_unique_filename,
|
|
||||||
ensure_uploads_dir,
|
|
||||||
normalize_filename,
|
|
||||||
write_upload_file_no_symlink,
|
|
||||||
)
|
|
||||||
|
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||||
@@ -491,10 +460,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
|
|
||||||
dest = uploads_dir / safe_name
|
dest = uploads_dir / safe_name
|
||||||
try:
|
try:
|
||||||
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
dest.write_bytes(data)
|
||||||
except UnsafeUploadPathError:
|
|
||||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
|
||||||
continue
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
||||||
continue
|
continue
|
||||||
@@ -542,7 +508,7 @@ class ChannelManager:
|
|||||||
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
|
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
|
||||||
|
|
||||||
It reads from the MessageBus inbound queue, creates/reuses threads on
|
It reads from the MessageBus inbound queue, creates/reuses threads on
|
||||||
Gateway's LangGraph-compatible API, sends messages via ``runs.wait``, and publishes
|
the LangGraph Server, sends messages via ``runs.wait``, and publishes
|
||||||
outbound responses back through the bus.
|
outbound responses back through the bus.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -567,20 +533,12 @@ class ChannelManager:
|
|||||||
self._default_session = _as_dict(default_session)
|
self._default_session = _as_dict(default_session)
|
||||||
self._channel_sessions = dict(channel_sessions or {})
|
self._channel_sessions = dict(channel_sessions or {})
|
||||||
self._client = None # lazy init — langgraph_sdk async client
|
self._client = None # lazy init — langgraph_sdk async client
|
||||||
self._csrf_token = generate_csrf_token()
|
|
||||||
self._semaphore: asyncio.Semaphore | None = None
|
self._semaphore: asyncio.Semaphore | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _channel_supports_streaming(channel_name: str) -> bool:
|
def _channel_supports_streaming(channel_name: str) -> bool:
|
||||||
from .service import get_channel_service
|
|
||||||
|
|
||||||
service = get_channel_service()
|
|
||||||
if service:
|
|
||||||
channel = service.get_channel(channel_name)
|
|
||||||
if channel is not None:
|
|
||||||
return channel.supports_streaming
|
|
||||||
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
|
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
|
||||||
|
|
||||||
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
|
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
@@ -603,17 +561,6 @@ class ChannelManager:
|
|||||||
user_layer.get("config"),
|
user_layer.get("config"),
|
||||||
)
|
)
|
||||||
|
|
||||||
configurable = run_config.get("configurable")
|
|
||||||
if isinstance(configurable, Mapping):
|
|
||||||
configurable = dict(configurable)
|
|
||||||
else:
|
|
||||||
configurable = {}
|
|
||||||
run_config["configurable"] = configurable
|
|
||||||
# Pin channel-triggered runs to the root graph namespace so follow-up
|
|
||||||
# turns continue from the same conversation checkpoint.
|
|
||||||
configurable["checkpoint_ns"] = ""
|
|
||||||
configurable["thread_id"] = thread_id
|
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
DEFAULT_RUN_CONTEXT,
|
DEFAULT_RUN_CONTEXT,
|
||||||
self._default_session.get("context"),
|
self._default_session.get("context"),
|
||||||
@@ -638,14 +585,7 @@ class ChannelManager:
|
|||||||
if self._client is None:
|
if self._client is None:
|
||||||
from langgraph_sdk import get_client
|
from langgraph_sdk import get_client
|
||||||
|
|
||||||
self._client = get_client(
|
self._client = get_client(url=self._langgraph_url)
|
||||||
url=self._langgraph_url,
|
|
||||||
headers={
|
|
||||||
**create_internal_auth_headers(),
|
|
||||||
CSRF_HEADER_NAME: self._csrf_token,
|
|
||||||
"Cookie": f"{CSRF_COOKIE_NAME}={self._csrf_token}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
# -- lifecycle ---------------------------------------------------------
|
# -- lifecycle ---------------------------------------------------------
|
||||||
@@ -728,7 +668,7 @@ class ChannelManager:
|
|||||||
# -- chat handling -----------------------------------------------------
|
# -- chat handling -----------------------------------------------------
|
||||||
|
|
||||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||||
"""Create a new thread through Gateway and store the mapping."""
|
"""Create a new thread on the LangGraph Server and store the mapping."""
|
||||||
thread = await client.threads.create()
|
thread = await client.threads.create()
|
||||||
thread_id = thread["thread_id"]
|
thread_id = thread["thread_id"]
|
||||||
self.store.set_thread_id(
|
self.store.set_thread_id(
|
||||||
@@ -738,7 +678,7 @@ class ChannelManager:
|
|||||||
topic_id=msg.topic_id,
|
topic_id=msg.topic_id,
|
||||||
user_id=msg.user_id,
|
user_id=msg.user_id,
|
||||||
)
|
)
|
||||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
logger.info("[Manager] new thread created on LangGraph Server: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||||
@@ -787,22 +727,13 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
try:
|
result = await client.runs.wait(
|
||||||
result = await client.runs.wait(
|
thread_id,
|
||||||
thread_id,
|
assistant_id,
|
||||||
assistant_id,
|
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
config=run_config,
|
||||||
config=run_config,
|
context=run_context,
|
||||||
context=run_context,
|
)
|
||||||
multitask_strategy="reject",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
if _is_thread_busy_error(exc):
|
|
||||||
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
|
|
||||||
await self._send_error(msg, THREAD_BUSY_MESSAGE)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
@@ -830,7 +761,6 @@ class ChannelManager:
|
|||||||
artifacts=artifacts,
|
artifacts=artifacts,
|
||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
@@ -892,7 +822,6 @@ class ChannelManager:
|
|||||||
text=latest_text,
|
text=latest_text,
|
||||||
is_final=False,
|
is_final=False,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
last_published_text = latest_text
|
last_published_text = latest_text
|
||||||
@@ -937,7 +866,6 @@ class ChannelManager:
|
|||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
is_final=True,
|
is_final=True,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -957,7 +885,7 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if command == "new":
|
if command == "new":
|
||||||
# Create a new thread through Gateway
|
# Create a new thread on the LangGraph Server
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
thread = await client.threads.create()
|
thread = await client.threads.create()
|
||||||
new_thread_id = thread["thread_id"]
|
new_thread_id = thread["thread_id"]
|
||||||
@@ -996,7 +924,6 @@ class ChannelManager:
|
|||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||||
text=reply,
|
text=reply,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
|
|
||||||
@@ -1006,11 +933,7 @@ class ChannelManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as http:
|
async with httpx.AsyncClient() as http:
|
||||||
resp = await http.get(
|
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
|
||||||
f"{self._gateway_url}{path}",
|
|
||||||
timeout=10,
|
|
||||||
headers=create_internal_auth_headers(),
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -1034,6 +957,5 @@ class ChannelManager:
|
|||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||||
text=error_text,
|
text=error_text,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||||
@@ -13,13 +13,8 @@ from app.channels.store import ChannelStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
# Channel name → import path for lazy loading
|
# Channel name → import path for lazy loading
|
||||||
_CHANNEL_REGISTRY: dict[str, str] = {
|
_CHANNEL_REGISTRY: dict[str, str] = {
|
||||||
"dingtalk": "app.channels.dingtalk:DingTalkChannel",
|
|
||||||
"discord": "app.channels.discord:DiscordChannel",
|
|
||||||
"feishu": "app.channels.feishu:FeishuChannel",
|
"feishu": "app.channels.feishu:FeishuChannel",
|
||||||
"slack": "app.channels.slack:SlackChannel",
|
"slack": "app.channels.slack:SlackChannel",
|
||||||
"telegram": "app.channels.telegram:TelegramChannel",
|
"telegram": "app.channels.telegram:TelegramChannel",
|
||||||
@@ -27,17 +22,6 @@ _CHANNEL_REGISTRY: dict[str, str] = {
|
|||||||
"wecom": "app.channels.wecom:WeComChannel",
|
"wecom": "app.channels.wecom:WeComChannel",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Keys that indicate a user has configured credentials for a channel.
|
|
||||||
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
|
|
||||||
"dingtalk": ["client_id", "client_secret"],
|
|
||||||
"discord": ["bot_token"],
|
|
||||||
"feishu": ["app_id", "app_secret"],
|
|
||||||
"slack": ["bot_token", "app_token"],
|
|
||||||
"telegram": ["bot_token"],
|
|
||||||
"wecom": ["bot_id", "bot_secret"],
|
|
||||||
"wechat": ["bot_token"],
|
|
||||||
}
|
|
||||||
|
|
||||||
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
||||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||||
|
|
||||||
@@ -80,15 +64,14 @@ class ChannelService:
|
|||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
def from_app_config(cls) -> ChannelService:
|
||||||
"""Create a ChannelService from the application config."""
|
"""Create a ChannelService from the application config."""
|
||||||
if app_config is None:
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
|
|
||||||
app_config = get_app_config()
|
config = get_app_config()
|
||||||
channels_config = {}
|
channels_config = {}
|
||||||
# extra fields are allowed by AppConfig (extra="allow")
|
# extra fields are allowed by AppConfig (extra="allow")
|
||||||
extra = app_config.model_extra or {}
|
extra = config.model_extra or {}
|
||||||
if "channels" in extra:
|
if "channels" in extra:
|
||||||
channels_config = extra["channels"]
|
channels_config = extra["channels"]
|
||||||
return cls(channels_config=channels_config)
|
return cls(channels_config=channels_config)
|
||||||
@@ -104,16 +87,7 @@ class ChannelService:
|
|||||||
if not isinstance(channel_config, dict):
|
if not isinstance(channel_config, dict):
|
||||||
continue
|
continue
|
||||||
if not channel_config.get("enabled", False):
|
if not channel_config.get("enabled", False):
|
||||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
logger.info("Channel %s is disabled, skipping", name)
|
||||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
|
||||||
if has_creds:
|
|
||||||
logger.warning(
|
|
||||||
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
|
|
||||||
name,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Channel %s is disabled, skipping", name)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self._start_channel(name, channel_config)
|
await self._start_channel(name, channel_config)
|
||||||
@@ -167,19 +141,12 @@ class ChannelService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = dict(config)
|
|
||||||
config["channel_store"] = self.store
|
|
||||||
channel = channel_cls(bus=self.bus, config=config)
|
channel = channel_cls(bus=self.bus, config=config)
|
||||||
self._channels[name] = channel
|
|
||||||
await channel.start()
|
await channel.start()
|
||||||
if not channel.is_running:
|
self._channels[name] = channel
|
||||||
self._channels.pop(name, None)
|
|
||||||
logger.error("Channel %s did not enter a running state after start()", name)
|
|
||||||
return False
|
|
||||||
logger.info("Channel %s started", name)
|
logger.info("Channel %s started", name)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
self._channels.pop(name, None)
|
|
||||||
logger.exception("Failed to start channel %s", name)
|
logger.exception("Failed to start channel %s", name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -214,12 +181,12 @@ def get_channel_service() -> ChannelService | None:
|
|||||||
return _channel_service
|
return _channel_service
|
||||||
|
|
||||||
|
|
||||||
async def start_channel_service(app_config: AppConfig | None = None) -> ChannelService:
|
async def start_channel_service() -> ChannelService:
|
||||||
"""Create and start the global ChannelService from app config."""
|
"""Create and start the global ChannelService from app config."""
|
||||||
global _channel_service
|
global _channel_service
|
||||||
if _channel_service is not None:
|
if _channel_service is not None:
|
||||||
return _channel_service
|
return _channel_service
|
||||||
_channel_service = ChannelService.from_app_config(app_config)
|
_channel_service = ChannelService.from_app_config()
|
||||||
await _channel_service.start()
|
await _channel_service.start()
|
||||||
return _channel_service
|
return _channel_service
|
||||||
|
|
||||||
|
|||||||
@@ -16,31 +16,13 @@ logger = logging.getLogger(__name__)
|
|||||||
_slack_md_converter = SlackMarkdownConverter()
|
_slack_md_converter = SlackMarkdownConverter()
|
||||||
|
|
||||||
|
|
||||||
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|
||||||
if allowed_users is None:
|
|
||||||
return set()
|
|
||||||
if isinstance(allowed_users, str):
|
|
||||||
values = [allowed_users]
|
|
||||||
elif isinstance(allowed_users, list | tuple | set):
|
|
||||||
values = allowed_users
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Slack allowed_users should be a list of Slack user IDs or a single Slack user ID string; treating %s as one string value",
|
|
||||||
type(allowed_users).__name__,
|
|
||||||
)
|
|
||||||
values = [allowed_users]
|
|
||||||
return {str(user_id) for user_id in values if str(user_id)}
|
|
||||||
|
|
||||||
|
|
||||||
class SlackChannel(Channel):
|
class SlackChannel(Channel):
|
||||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||||
|
|
||||||
Configuration keys (in ``config.yaml`` under ``channels.slack``):
|
Configuration keys (in ``config.yaml`` under ``channels.slack``):
|
||||||
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
|
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
|
||||||
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
|
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
|
||||||
- ``allowed_users``: (optional) List of allowed Slack user IDs, or a
|
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all.
|
||||||
single Slack user ID string as shorthand. Empty = allow all. Other
|
|
||||||
scalar values are treated as a single string with a warning.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
||||||
@@ -48,7 +30,7 @@ class SlackChannel(Channel):
|
|||||||
self._socket_client = None
|
self._socket_client = None
|
||||||
self._web_client = None
|
self._web_client = None
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
|
|||||||
@@ -29,10 +29,6 @@ class WeComChannel(Channel):
|
|||||||
self._ws_stream_ids: dict[str, str] = {}
|
self._ws_stream_ids: dict[str, str] = {}
|
||||||
self._working_message = "Working on it..."
|
self._working_message = "Working on it..."
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _clear_ws_context(self, thread_ts: str | None) -> None:
|
def _clear_ws_context(self, thread_ts: str | None) -> None:
|
||||||
if not thread_ts:
|
if not thread_ts:
|
||||||
return
|
return
|
||||||
|
|||||||
+40
-62
@@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
@@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from app.gateway.auth_middleware import AuthMiddleware
|
from app.gateway.auth_middleware import AuthMiddleware
|
||||||
from app.gateway.config import get_gateway_config
|
from app.gateway.config import get_gateway_config
|
||||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||||
from app.gateway.deps import langgraph_runtime
|
from app.gateway.deps import langgraph_runtime
|
||||||
from app.gateway.routers import (
|
from app.gateway.routers import (
|
||||||
agents,
|
agents,
|
||||||
@@ -27,13 +27,9 @@ from app.gateway.routers import (
|
|||||||
threads,
|
threads,
|
||||||
uploads,
|
uploads,
|
||||||
)
|
)
|
||||||
from deerflow.config import app_config as deerflow_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.config.app_config import apply_logging_level
|
|
||||||
|
|
||||||
AppConfig = deerflow_app_config.AppConfig
|
# Configure logging
|
||||||
get_app_config = deerflow_app_config.get_app_config
|
|
||||||
|
|
||||||
# Default logging; lifespan overrides from config.yaml log_level.
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
@@ -42,11 +38,6 @@ logging.basicConfig(
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
|
|
||||||
# Bounds worker exit time so uvicorn's reload supervisor does not keep
|
|
||||||
# firing signals into a worker that is stuck waiting for shutdown cleanup.
|
|
||||||
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0
|
|
||||||
|
|
||||||
|
|
||||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||||
"""Startup hook: handle first boot and migrate orphan threads otherwise.
|
"""Startup hook: handle first boot and migrate orphan threads otherwise.
|
||||||
@@ -62,7 +53,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
|||||||
|
|
||||||
Subsequent boots (admin already exists):
|
Subsequent boots (admin already exists):
|
||||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
||||||
existing LangGraph thread metadata that has no user_id.
|
existing LangGraph thread metadata that has no owner_id.
|
||||||
|
|
||||||
No SQL persistence migration is needed: the four user_id columns
|
No SQL persistence migration is needed: the four user_id columns
|
||||||
(threads_meta, runs, run_events, feedback) only come into existence
|
(threads_meta, runs, run_events, feedback) only come into existence
|
||||||
@@ -75,18 +66,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
|||||||
from deerflow.persistence.engine import get_session_factory
|
from deerflow.persistence.engine import get_session_factory
|
||||||
from deerflow.persistence.user.model import UserRow
|
from deerflow.persistence.user.model import UserRow
|
||||||
|
|
||||||
try:
|
provider = get_local_provider()
|
||||||
provider = get_local_provider()
|
|
||||||
except RuntimeError:
|
|
||||||
# Auth persistence may not be initialized in some test/boot paths.
|
|
||||||
# Skip admin migration work rather than failing gateway startup.
|
|
||||||
logger.warning("Auth persistence not ready; skipping admin bootstrap check")
|
|
||||||
return
|
|
||||||
|
|
||||||
sf = get_session_factory()
|
|
||||||
if sf is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
admin_count = await provider.count_admin_users()
|
admin_count = await provider.count_admin_users()
|
||||||
|
|
||||||
if admin_count == 0:
|
if admin_count == 0:
|
||||||
@@ -98,6 +78,10 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
|||||||
|
|
||||||
# Admin already exists — run orphan thread migration for any
|
# Admin already exists — run orphan thread migration for any
|
||||||
# LangGraph thread metadata that pre-dates the auth module.
|
# LangGraph thread metadata that pre-dates the auth module.
|
||||||
|
sf = get_session_factory()
|
||||||
|
if sf is None:
|
||||||
|
return
|
||||||
|
|
||||||
async with sf() as session:
|
async with sf() as session:
|
||||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
||||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||||
@@ -163,8 +147,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
# Load config and check necessary environment variables at startup
|
# Load config and check necessary environment variables at startup
|
||||||
try:
|
try:
|
||||||
app.state.config = get_app_config()
|
get_app_config()
|
||||||
apply_logging_level(app.state.config.log_level)
|
|
||||||
logger.info("Configuration loaded successfully")
|
logger.info("Configuration loaded successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||||
@@ -177,7 +160,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
async with langgraph_runtime(app):
|
async with langgraph_runtime(app):
|
||||||
logger.info("LangGraph runtime initialised")
|
logger.info("LangGraph runtime initialised")
|
||||||
|
|
||||||
# Check admin bootstrap state and migrate orphan threads after admin exists.
|
# Ensure admin user exists (auto-create on first boot)
|
||||||
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
|
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
|
||||||
await _ensure_admin_user(app)
|
await _ensure_admin_user(app)
|
||||||
|
|
||||||
@@ -185,26 +168,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
try:
|
try:
|
||||||
from app.channels.service import start_channel_service
|
from app.channels.service import start_channel_service
|
||||||
|
|
||||||
channel_service = await start_channel_service(app.state.config)
|
channel_service = await start_channel_service()
|
||||||
logger.info("Channel service started: %s", channel_service.get_status())
|
logger.info("Channel service started: %s", channel_service.get_status())
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("No IM channels configured or channel service failed to start")
|
logger.exception("No IM channels configured or channel service failed to start")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Stop channel service on shutdown (bounded to prevent worker hang)
|
# Stop channel service on shutdown
|
||||||
try:
|
try:
|
||||||
from app.channels.service import stop_channel_service
|
from app.channels.service import stop_channel_service
|
||||||
|
|
||||||
await asyncio.wait_for(
|
await stop_channel_service()
|
||||||
stop_channel_service(),
|
|
||||||
timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
"Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
|
|
||||||
_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to stop channel service")
|
logger.exception("Failed to stop channel service")
|
||||||
|
|
||||||
@@ -217,10 +192,6 @@ def create_app() -> FastAPI:
|
|||||||
Returns:
|
Returns:
|
||||||
Configured FastAPI application instance.
|
Configured FastAPI application instance.
|
||||||
"""
|
"""
|
||||||
config = get_gateway_config()
|
|
||||||
docs_url = "/docs" if config.enable_docs else None
|
|
||||||
redoc_url = "/redoc" if config.enable_docs else None
|
|
||||||
openapi_url = "/openapi.json" if config.enable_docs else None
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="DeerFlow API Gateway",
|
title="DeerFlow API Gateway",
|
||||||
@@ -240,14 +211,14 @@ API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execu
|
|||||||
|
|
||||||
### Architecture
|
### Architecture
|
||||||
|
|
||||||
LangGraph-compatible requests are routed through nginx to this gateway.
|
LangGraph requests are handled by nginx reverse proxy.
|
||||||
This gateway provides runtime endpoints for agent runs plus custom endpoints for models, MCP configuration, skills, and artifacts.
|
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
|
||||||
""",
|
""",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
docs_url=docs_url,
|
docs_url="/docs",
|
||||||
redoc_url=redoc_url,
|
redoc_url="/redoc",
|
||||||
openapi_url=openapi_url,
|
openapi_url="/openapi.json",
|
||||||
openapi_tags=[
|
openapi_tags=[
|
||||||
{
|
{
|
||||||
"name": "models",
|
"name": "models",
|
||||||
@@ -310,18 +281,25 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
|||||||
# CSRF: Double Submit Cookie pattern for state-changing requests
|
# CSRF: Double Submit Cookie pattern for state-changing requests
|
||||||
app.add_middleware(CSRFMiddleware)
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
|
||||||
# CORS: the unified nginx endpoint is same-origin by default. Split-origin
|
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
|
||||||
# browser clients must opt in with this explicit Gateway allowlist so CORS
|
# In production, nginx handles CORS and no middleware is needed.
|
||||||
# and CSRF origin checks share the same source of truth.
|
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
|
||||||
cors_origins = sorted(get_configured_cors_origins())
|
if cors_origins_env:
|
||||||
if cors_origins:
|
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
|
||||||
app.add_middleware(
|
# Validate: wildcard origin with credentials is a security misconfiguration
|
||||||
CORSMiddleware,
|
for origin in cors_origins:
|
||||||
allow_origins=cors_origins,
|
if origin == "*":
|
||||||
allow_credentials=True,
|
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
|
||||||
allow_methods=["*"],
|
cors_origins = [o for o in cors_origins if o != "*"]
|
||||||
allow_headers=["*"],
|
break
|
||||||
)
|
if cors_origins:
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
# Include routers
|
# Include routers
|
||||||
# Models API is mounted at /api/models
|
# Models API is mounted at /api/models
|
||||||
@@ -370,7 +348,7 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
|||||||
app.include_router(runs.router)
|
app.include_router(runs.router)
|
||||||
|
|
||||||
@app.get("/health", tags=["health"])
|
@app.get("/health", tags=["health"])
|
||||||
async def health_check() -> dict[str, str]:
|
async def health_check() -> dict:
|
||||||
"""Health check endpoint.
|
"""Health check endpoint.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
load_dotenv()
|
||||||
|
|
||||||
_SECRET_FILE = ".jwt_secret"
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AuthConfig(BaseModel):
|
class AuthConfig(BaseModel):
|
||||||
@@ -32,46 +33,17 @@ class AuthConfig(BaseModel):
|
|||||||
_auth_config: AuthConfig | None = None
|
_auth_config: AuthConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
def _load_or_create_secret() -> str:
|
|
||||||
"""Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one."""
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
|
|
||||||
paths = get_paths()
|
|
||||||
secret_file = paths.base_dir / _SECRET_FILE
|
|
||||||
|
|
||||||
try:
|
|
||||||
if secret_file.exists():
|
|
||||||
secret = secret_file.read_text(encoding="utf-8").strip()
|
|
||||||
if secret:
|
|
||||||
return secret
|
|
||||||
except OSError as exc:
|
|
||||||
raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc
|
|
||||||
|
|
||||||
secret = secrets.token_urlsafe(32)
|
|
||||||
try:
|
|
||||||
secret_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
|
||||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
|
||||||
fh.write(secret)
|
|
||||||
except OSError as exc:
|
|
||||||
raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
|
|
||||||
return secret
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_config() -> AuthConfig:
|
def get_auth_config() -> AuthConfig:
|
||||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||||
global _auth_config
|
global _auth_config
|
||||||
if _auth_config is None:
|
if _auth_config is None:
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||||
if not jwt_secret:
|
if not jwt_secret:
|
||||||
jwt_secret = _load_or_create_secret()
|
jwt_secret = secrets.token_urlsafe(32)
|
||||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
|
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||||
"persisted to .jwt_secret. Sessions will survive restarts. "
|
"Sessions will be invalidated on restart. "
|
||||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
"""Local email/password authentication provider."""
|
"""Local email/password authentication provider."""
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
from app.gateway.auth.models import User
|
||||||
from app.gateway.auth.password import hash_password_async, needs_rehash, verify_password_async
|
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||||
from app.gateway.auth.providers import AuthProvider
|
from app.gateway.auth.providers import AuthProvider
|
||||||
from app.gateway.auth.repositories.base import UserRepository
|
from app.gateway.auth.repositories.base import UserRepository
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAuthProvider(AuthProvider):
|
class LocalAuthProvider(AuthProvider):
|
||||||
"""Email/password authentication provider using local database."""
|
"""Email/password authentication provider using local database."""
|
||||||
@@ -47,15 +43,6 @@ class LocalAuthProvider(AuthProvider):
|
|||||||
if not await verify_password_async(password, user.password_hash):
|
if not await verify_password_async(password, user.password_hash):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if needs_rehash(user.password_hash):
|
|
||||||
try:
|
|
||||||
user.password_hash = await hash_password_async(password)
|
|
||||||
await self._repo.update_user(user)
|
|
||||||
except Exception:
|
|
||||||
# Rehash is an opportunistic upgrade; a transient DB error must not
|
|
||||||
# prevent an otherwise-valid login from succeeding.
|
|
||||||
logger.warning("Failed to rehash password for user %s; login will still succeed", user.email, exc_info=True)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def get_user(self, user_id: str) -> User | None:
|
async def get_user(self, user_id: str) -> User | None:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class User(BaseModel):
|
|||||||
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
||||||
|
|
||||||
# Auth lifecycle
|
# Auth lifecycle
|
||||||
needs_setup: bool = Field(default=False, description="True when a reset account must complete setup")
|
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
||||||
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,66 +1,18 @@
|
|||||||
"""Password hashing utilities with versioned hash format.
|
"""Password hashing utilities using bcrypt directly."""
|
||||||
|
|
||||||
Hash format: ``$dfv<N>$<bcrypt_hash>`` where ``<N>`` is the version.
|
|
||||||
|
|
||||||
- **v1** (legacy): ``bcrypt(password)`` — plain bcrypt, susceptible to
|
|
||||||
72-byte silent truncation.
|
|
||||||
- **v2** (current): ``bcrypt(b64(sha256(password)))`` — SHA-256 pre-hash
|
|
||||||
avoids the 72-byte truncation limit so the full password contributes
|
|
||||||
to the hash.
|
|
||||||
|
|
||||||
Verification auto-detects the version and falls back to v1 for hashes
|
|
||||||
without a prefix, so existing deployments upgrade transparently on next
|
|
||||||
login.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
_CURRENT_VERSION = 2
|
|
||||||
_PREFIX_V2 = "$dfv2$"
|
|
||||||
_PREFIX_V1 = "$dfv1$"
|
|
||||||
|
|
||||||
|
|
||||||
def _pre_hash_v2(password: str) -> bytes:
|
|
||||||
"""SHA-256 pre-hash to bypass bcrypt's 72-byte limit."""
|
|
||||||
return base64.b64encode(hashlib.sha256(password.encode("utf-8")).digest())
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
"""Hash a password (current version: v2 — SHA-256 + bcrypt)."""
|
"""Hash a password using bcrypt."""
|
||||||
raw = bcrypt.hashpw(_pre_hash_v2(password), bcrypt.gensalt()).decode("utf-8")
|
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
return f"{_PREFIX_V2}{raw}"
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""Verify a password, auto-detecting the hash version.
|
"""Verify a password against its hash."""
|
||||||
|
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||||
Accepts v2 (``$dfv2$…``), v1 (``$dfv1$…``), and bare bcrypt hashes
|
|
||||||
(treated as v1 for backward compatibility with pre-versioning data).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if hashed_password.startswith(_PREFIX_V2):
|
|
||||||
bcrypt_hash = hashed_password[len(_PREFIX_V2) :]
|
|
||||||
return bcrypt.checkpw(_pre_hash_v2(plain_password), bcrypt_hash.encode("utf-8"))
|
|
||||||
|
|
||||||
if hashed_password.startswith(_PREFIX_V1):
|
|
||||||
bcrypt_hash = hashed_password[len(_PREFIX_V1) :]
|
|
||||||
else:
|
|
||||||
bcrypt_hash = hashed_password
|
|
||||||
|
|
||||||
return bcrypt.checkpw(plain_password.encode("utf-8"), bcrypt_hash.encode("utf-8"))
|
|
||||||
except ValueError:
|
|
||||||
# bcrypt raises ValueError for malformed or corrupt hashes (e.g., invalid salt).
|
|
||||||
# Fail closed rather than crashing the request.
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def needs_rehash(hashed_password: str) -> bool:
|
|
||||||
"""Return True if the hash uses an older version and should be rehashed."""
|
|
||||||
return not hashed_password.startswith(_PREFIX_V2)
|
|
||||||
|
|
||||||
|
|
||||||
async def hash_password_async(password: str) -> str:
|
async def hash_password_async(password: str) -> str:
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ class AuthProvider(ABC):
|
|||||||
|
|
||||||
Returns User if authentication succeeds, None otherwise.
|
Returns User if authentication succeeds, None otherwise.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user(self, user_id: str) -> "User | None":
|
async def get_user(self, user_id: str) -> "User | None":
|
||||||
"""Retrieve user by ID."""
|
"""Retrieve user by ID."""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
|
|
||||||
# Import User at runtime to avoid circular imports
|
# Import User at runtime to avoid circular imports
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class UserRepository(ABC):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If email already exists
|
ValueError: If email already exists
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||||
@@ -47,7 +47,7 @@ class UserRepository(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
User if found, None otherwise
|
User if found, None otherwise
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_by_email(self, email: str) -> User | None:
|
async def get_user_by_email(self, email: str) -> User | None:
|
||||||
@@ -59,7 +59,7 @@ class UserRepository(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
User if found, None otherwise
|
User if found, None otherwise
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def update_user(self, user: User) -> User:
|
async def update_user(self, user: User) -> User:
|
||||||
@@ -76,17 +76,17 @@ class UserRepository(ABC):
|
|||||||
a hard failure (not a no-op) so callers cannot mistake a
|
a hard failure (not a no-op) so callers cannot mistake a
|
||||||
concurrent-delete race for a successful update.
|
concurrent-delete race for a successful update.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def count_users(self) -> int:
|
async def count_users(self) -> int:
|
||||||
"""Return total number of registered users."""
|
"""Return total number of registered users."""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def count_admin_users(self) -> int:
|
async def count_admin_users(self) -> int:
|
||||||
"""Return number of users with system_role == 'admin'."""
|
"""Return number of users with system_role == 'admin'."""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||||
@@ -99,4 +99,4 @@ class UserRepository(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
User if found, None otherwise
|
User if found, None otherwise
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from starlette.types import ASGIApp
|
|||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
||||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
|
|
||||||
# Paths that never require authentication.
|
# Paths that never require authentication.
|
||||||
@@ -76,12 +75,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if _is_public(request.url.path):
|
if _is_public(request.url.path):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
internal_user = None
|
|
||||||
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
|
||||||
internal_user = get_internal_user()
|
|
||||||
|
|
||||||
# Non-public path: require session cookie
|
# Non-public path: require session cookie
|
||||||
if internal_user is None and not request.cookies.get("access_token"):
|
if not request.cookies.get("access_token"):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
content={
|
content={
|
||||||
@@ -105,13 +100,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
# bubble up, so we catch and render it as JSONResponse here.
|
# bubble up, so we catch and render it as JSONResponse here.
|
||||||
from app.gateway.deps import get_current_user_from_request
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
if internal_user is not None:
|
try:
|
||||||
user = internal_user
|
user = await get_current_user_from_request(request)
|
||||||
else:
|
except HTTPException as exc:
|
||||||
try:
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
except HTTPException as exc:
|
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
|
||||||
|
|
||||||
# Stamp both request.state.user (for the contextvar pattern)
|
# Stamp both request.state.user (for the contextvar pattern)
|
||||||
# and request.state.auth (so @require_permission's "auth is
|
# and request.state.auth (so @require_permission's "auth is
|
||||||
|
|||||||
@@ -30,9 +30,7 @@ Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blo
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -119,15 +117,6 @@ _ALL_PERMISSIONS: list[str] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _make_test_request_stub() -> Any:
|
|
||||||
"""Create a minimal request-like object for direct unit calls.
|
|
||||||
|
|
||||||
Used when decorated route handlers are invoked without FastAPI's
|
|
||||||
request injection. Includes fields accessed by auth helpers.
|
|
||||||
"""
|
|
||||||
return SimpleNamespace(state=SimpleNamespace(), cookies={}, _deerflow_test_bypass_auth=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def _authenticate(request: Request) -> AuthContext:
|
async def _authenticate(request: Request) -> AuthContext:
|
||||||
"""Authenticate request and return AuthContext.
|
"""Authenticate request and return AuthContext.
|
||||||
|
|
||||||
@@ -145,11 +134,7 @@ async def _authenticate(request: Request) -> AuthContext:
|
|||||||
|
|
||||||
|
|
||||||
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||||
"""Decorator that authenticates the request and enforces authentication.
|
"""Decorator that authenticates the request and sets AuthContext.
|
||||||
|
|
||||||
Independently raises HTTP 401 for unauthenticated requests, regardless of
|
|
||||||
whether ``AuthMiddleware`` is present in the ASGI stack. Sets the resolved
|
|
||||||
``AuthContext`` on ``request.state.auth`` for downstream handlers.
|
|
||||||
|
|
||||||
Must be placed ABOVE other decorators (executes after them).
|
Must be placed ABOVE other decorators (executes after them).
|
||||||
|
|
||||||
@@ -162,33 +147,19 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
|||||||
...
|
...
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: 401 if the request is unauthenticated.
|
ValueError: If 'request' parameter is missing
|
||||||
ValueError: If 'request' parameter is missing.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
request = kwargs.get("request")
|
request = kwargs.get("request")
|
||||||
if request is None:
|
if request is None:
|
||||||
# Unit tests may call decorated handlers directly without a
|
raise ValueError("require_auth decorator requires 'request' parameter")
|
||||||
# FastAPI Request object. Inject a minimal request stub when
|
|
||||||
# the wrapped function declares `request`.
|
|
||||||
if "request" in inspect.signature(func).parameters:
|
|
||||||
kwargs["request"] = _make_test_request_stub()
|
|
||||||
else:
|
|
||||||
raise ValueError("require_auth decorator requires 'request' parameter")
|
|
||||||
request = kwargs["request"]
|
|
||||||
|
|
||||||
if getattr(request, "_deerflow_test_bypass_auth", False):
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
# Authenticate and set context
|
# Authenticate and set context
|
||||||
auth_context = await _authenticate(request)
|
auth_context = await _authenticate(request)
|
||||||
request.state.auth = auth_context
|
request.state.auth = auth_context
|
||||||
|
|
||||||
if not auth_context.is_authenticated:
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication required")
|
|
||||||
|
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -239,17 +210,7 @@ def require_permission(
|
|||||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
request = kwargs.get("request")
|
request = kwargs.get("request")
|
||||||
if request is None:
|
if request is None:
|
||||||
# Unit tests may call decorated route handlers directly without
|
raise ValueError("require_permission decorator requires 'request' parameter")
|
||||||
# constructing a FastAPI Request object. Inject a minimal stub
|
|
||||||
# when the wrapped function declares `request`.
|
|
||||||
if "request" in inspect.signature(func).parameters:
|
|
||||||
kwargs["request"] = _make_test_request_stub()
|
|
||||||
else:
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
request = kwargs["request"]
|
|
||||||
|
|
||||||
if getattr(request, "_deerflow_test_bypass_auth", False):
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
auth: AuthContext = getattr(request.state, "auth", None)
|
auth: AuthContext = getattr(request.state, "auth", None)
|
||||||
if auth is None:
|
if auth is None:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ class GatewayConfig(BaseModel):
|
|||||||
|
|
||||||
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
|
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
|
||||||
port: int = Field(default=8001, description="Port to bind the gateway server")
|
port: int = Field(default=8001, description="Port to bind the gateway server")
|
||||||
enable_docs: bool = Field(default=True, description="Enable Swagger/ReDoc/OpenAPI endpoints")
|
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
|
||||||
|
|
||||||
|
|
||||||
_gateway_config: GatewayConfig | None = None
|
_gateway_config: GatewayConfig | None = None
|
||||||
@@ -18,9 +18,10 @@ def get_gateway_config() -> GatewayConfig:
|
|||||||
"""Get gateway config, loading from environment if available."""
|
"""Get gateway config, loading from environment if available."""
|
||||||
global _gateway_config
|
global _gateway_config
|
||||||
if _gateway_config is None:
|
if _gateway_config is None:
|
||||||
|
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
|
||||||
_gateway_config = GatewayConfig(
|
_gateway_config = GatewayConfig(
|
||||||
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
|
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
|
||||||
port=int(os.getenv("GATEWAY_PORT", "8001")),
|
port=int(os.getenv("GATEWAY_PORT", "8001")),
|
||||||
enable_docs=os.getenv("GATEWAY_ENABLE_DOCS", "true").lower() == "true",
|
cors_origins=cors_origins_str.split(","),
|
||||||
)
|
)
|
||||||
return _gateway_config
|
return _gateway_config
|
||||||
|
|||||||
@@ -4,10 +4,8 @@ Per RFC-001:
|
|||||||
State-changing operations require CSRF protection.
|
State-changing operations require CSRF protection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import secrets
|
import secrets
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Callable
|
||||||
from urllib.parse import urlsplit
|
|
||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
@@ -21,7 +19,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
|
|||||||
|
|
||||||
def is_secure_request(request: Request) -> bool:
|
def is_secure_request(request: Request) -> bool:
|
||||||
"""Detect whether the original client request was made over HTTPS."""
|
"""Detect whether the original client request was made over HTTPS."""
|
||||||
return _request_scheme(request) == "https"
|
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||||
|
|
||||||
|
|
||||||
def generate_csrf_token() -> str:
|
def generate_csrf_token() -> str:
|
||||||
@@ -63,129 +61,15 @@ def is_auth_endpoint(request: Request) -> bool:
|
|||||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||||
|
|
||||||
|
|
||||||
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
|
|
||||||
"""Return normalized host[:port], omitting default ports."""
|
|
||||||
host = hostname.lower()
|
|
||||||
if ":" in host and not host.startswith("["):
|
|
||||||
host = f"[{host}]"
|
|
||||||
|
|
||||||
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
|
|
||||||
return host
|
|
||||||
return f"{host}:{port}"
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_origin(origin: str) -> str | None:
|
|
||||||
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
|
|
||||||
try:
|
|
||||||
parsed = urlsplit(origin.strip())
|
|
||||||
port = parsed.port
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
scheme = parsed.scheme.lower()
|
|
||||||
if scheme not in {"http", "https"} or not parsed.hostname:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
|
|
||||||
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
|
|
||||||
|
|
||||||
|
|
||||||
def _configured_cors_origins() -> set[str]:
|
|
||||||
"""Return explicit configured browser origins that may call auth routes."""
|
|
||||||
origins = set()
|
|
||||||
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
|
|
||||||
origin = raw_origin.strip()
|
|
||||||
if not origin or origin == "*":
|
|
||||||
continue
|
|
||||||
normalized = _normalize_origin(origin)
|
|
||||||
if normalized:
|
|
||||||
origins.add(normalized)
|
|
||||||
return origins
|
|
||||||
|
|
||||||
|
|
||||||
def get_configured_cors_origins() -> set[str]:
|
|
||||||
"""Return normalized explicit browser origins from GATEWAY_CORS_ORIGINS."""
|
|
||||||
return _configured_cors_origins()
|
|
||||||
|
|
||||||
|
|
||||||
def _first_header_value(value: str | None) -> str | None:
|
|
||||||
"""Return the first value from a comma-separated proxy header."""
|
|
||||||
if not value:
|
|
||||||
return None
|
|
||||||
first = value.split(",", 1)[0].strip()
|
|
||||||
return first or None
|
|
||||||
|
|
||||||
|
|
||||||
def _forwarded_param(request: Request, name: str) -> str | None:
|
|
||||||
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
|
|
||||||
forwarded = _first_header_value(request.headers.get("forwarded"))
|
|
||||||
if not forwarded:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for part in forwarded.split(";"):
|
|
||||||
key, sep, value = part.strip().partition("=")
|
|
||||||
if sep and key.lower() == name:
|
|
||||||
return value.strip().strip('"') or None
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _request_scheme(request: Request) -> str:
|
|
||||||
"""Resolve the original request scheme from trusted proxy headers."""
|
|
||||||
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
|
|
||||||
return scheme.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def _request_origin(request: Request) -> str | None:
|
|
||||||
"""Build the origin for the URL the browser is targeting."""
|
|
||||||
scheme = _request_scheme(request)
|
|
||||||
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
|
|
||||||
|
|
||||||
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
|
|
||||||
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
|
|
||||||
host = f"{host}:{forwarded_port}"
|
|
||||||
|
|
||||||
return _normalize_origin(f"{scheme}://{host}")
|
|
||||||
|
|
||||||
|
|
||||||
def is_allowed_auth_origin(request: Request) -> bool:
|
|
||||||
"""Allow auth POSTs only from the same origin or explicit configured origins.
|
|
||||||
|
|
||||||
Login/register/initialize are exempt from the double-submit token because
|
|
||||||
first-time browser clients do not have a CSRF token yet. They still create
|
|
||||||
a session cookie, so browser requests with a hostile Origin header must be
|
|
||||||
rejected to prevent login CSRF / session fixation. Requests without Origin
|
|
||||||
are allowed for non-browser clients such as curl and mobile integrations.
|
|
||||||
"""
|
|
||||||
origin = request.headers.get("origin")
|
|
||||||
if not origin:
|
|
||||||
return True
|
|
||||||
|
|
||||||
normalized_origin = _normalize_origin(origin)
|
|
||||||
if normalized_origin is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
request_origin = _request_origin(request)
|
|
||||||
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
|
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp) -> None:
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
_is_auth = is_auth_endpoint(request)
|
_is_auth = is_auth_endpoint(request)
|
||||||
|
|
||||||
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=403,
|
|
||||||
content={"detail": "Cross-site auth request denied."},
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_check_csrf(request) and not _is_auth:
|
if should_check_csrf(request) and not _is_auth:
|
||||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, TypeVar, cast
|
|||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
from deerflow.persistence.feedback import FeedbackRepository
|
||||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
@@ -30,14 +29,6 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def get_config(request: Request) -> AppConfig:
|
|
||||||
"""Return the app-scoped ``AppConfig`` stored on ``app.state``."""
|
|
||||||
config = getattr(request.app.state, "config", None)
|
|
||||||
if config is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Configuration not available")
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||||
@@ -47,24 +38,22 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
async with langgraph_runtime(app):
|
async with langgraph_runtime(app):
|
||||||
yield
|
yield
|
||||||
"""
|
"""
|
||||||
|
from deerflow.config import get_app_config
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||||
from deerflow.runtime import make_store, make_stream_bridge
|
from deerflow.runtime import make_store, make_stream_bridge
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||||
from deerflow.runtime.events.store import make_run_event_store
|
from deerflow.runtime.events.store import make_run_event_store
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
config = getattr(app.state, "config", None)
|
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||||
if config is None:
|
|
||||||
raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized")
|
|
||||||
|
|
||||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
|
|
||||||
|
|
||||||
# Initialize persistence engine BEFORE checkpointer so that
|
# Initialize persistence engine BEFORE checkpointer so that
|
||||||
# auto-create-database logic runs first (postgres backend).
|
# auto-create-database logic runs first (postgres backend).
|
||||||
|
config = get_app_config()
|
||||||
await init_engine_from_config(config.database)
|
await init_engine_from_config(config.database)
|
||||||
|
|
||||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
|
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||||
app.state.store = await stack.enter_async_context(make_store(config))
|
app.state.store = await stack.enter_async_context(make_store())
|
||||||
|
|
||||||
# Initialize repositories — one get_session_factory() call for all.
|
# Initialize repositories — one get_session_factory() call for all.
|
||||||
sf = get_session_factory()
|
sf = get_session_factory()
|
||||||
@@ -141,14 +130,14 @@ def get_run_context(request: Request) -> RunContext:
|
|||||||
|
|
||||||
Returns a *base* context with infrastructure dependencies.
|
Returns a *base* context with infrastructure dependencies.
|
||||||
"""
|
"""
|
||||||
config = get_config(request)
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
return RunContext(
|
return RunContext(
|
||||||
checkpointer=get_checkpointer(request),
|
checkpointer=get_checkpointer(request),
|
||||||
store=get_store(request),
|
store=get_store(request),
|
||||||
event_store=get_run_event_store(request),
|
event_store=get_run_event_store(request),
|
||||||
run_events_config=getattr(config, "run_events", None),
|
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||||
thread_store=get_thread_store(request),
|
thread_store=get_thread_store(request),
|
||||||
app_config=config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
"""Process-local authentication for Gateway internal callers."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
|
||||||
|
|
||||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
|
||||||
_INTERNAL_AUTH_TOKEN = secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
|
|
||||||
def create_internal_auth_headers() -> dict[str, str]:
|
|
||||||
"""Return headers that authenticate same-process Gateway internal calls."""
|
|
||||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
|
||||||
"""Return True when *token* matches the process-local internal token."""
|
|
||||||
return bool(token) and secrets.compare_digest(token, _INTERNAL_AUTH_TOKEN)
|
|
||||||
|
|
||||||
|
|
||||||
def get_internal_user():
|
|
||||||
"""Return the synthetic user used for trusted internal channel calls."""
|
|
||||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
|
|
||||||
@@ -1,12 +1,8 @@
|
|||||||
"""LangGraph compatibility auth handler — shares JWT logic with Gateway.
|
"""LangGraph Server auth handler — shares JWT logic with Gateway.
|
||||||
|
|
||||||
The default DeerFlow runtime is embedded in the FastAPI Gateway; scripts and
|
Loaded by LangGraph Server via langgraph.json ``auth.path``.
|
||||||
Docker deployments do not load this module. It is retained for LangGraph
|
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
|
||||||
tooling, Studio, or direct LangGraph Server compatibility through
|
so both modes validate tokens with the same secret and rules.
|
||||||
``langgraph.json``'s ``auth.path``.
|
|
||||||
|
|
||||||
When that compatibility path is used, this module reuses the same JWT and CSRF
|
|
||||||
rules as Gateway so both modes validate sessions consistently.
|
|
||||||
|
|
||||||
Two layers:
|
Two layers:
|
||||||
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
||||||
@@ -77,7 +73,7 @@ async def authenticate(request):
|
|||||||
if isinstance(payload, TokenError):
|
if isinstance(payload, TokenError):
|
||||||
raise Auth.exceptions.HTTPException(
|
raise Auth.exceptions.HTTPException(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
detail="Invalid token",
|
detail=f"Token error: {payload.value}",
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await get_local_provider().get_user(payload.sub)
|
user = await get_local_provider().get_user(payload.sub)
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ import yaml
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from deerflow.config.agents_api_config import get_agents_api_config
|
|
||||||
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api", tags=["agents"])
|
router = APIRouter(prefix="/api", tags=["agents"])
|
||||||
@@ -26,7 +24,6 @@ class AgentResponse(BaseModel):
|
|||||||
description: str = Field(default="", description="Agent description")
|
description: str = Field(default="", description="Agent description")
|
||||||
model: str | None = Field(default=None, description="Optional model override")
|
model: str | None = Field(default=None, description="Optional model override")
|
||||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||||
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all, []=none)")
|
|
||||||
soul: str | None = Field(default=None, description="SOUL.md content")
|
soul: str | None = Field(default=None, description="SOUL.md content")
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +40,6 @@ class AgentCreateRequest(BaseModel):
|
|||||||
description: str = Field(default="", description="Agent description")
|
description: str = Field(default="", description="Agent description")
|
||||||
model: str | None = Field(default=None, description="Optional model override")
|
model: str | None = Field(default=None, description="Optional model override")
|
||||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||||
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all enabled, []=none)")
|
|
||||||
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
|
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +49,6 @@ class AgentUpdateRequest(BaseModel):
|
|||||||
description: str | None = Field(default=None, description="Updated description")
|
description: str | None = Field(default=None, description="Updated description")
|
||||||
model: str | None = Field(default=None, description="Updated model override")
|
model: str | None = Field(default=None, description="Updated model override")
|
||||||
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
|
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
|
||||||
skills: list[str] | None = Field(default=None, description="Updated skill whitelist (None=all, []=none)")
|
|
||||||
soul: str | None = Field(default=None, description="Updated SOUL.md content")
|
soul: str | None = Field(default=None, description="Updated SOUL.md content")
|
||||||
|
|
||||||
|
|
||||||
@@ -78,27 +73,17 @@ def _normalize_agent_name(name: str) -> str:
|
|||||||
return name.lower()
|
return name.lower()
|
||||||
|
|
||||||
|
|
||||||
def _require_agents_api_enabled() -> None:
|
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
|
||||||
"""Reject access unless the custom-agent management API is explicitly enabled."""
|
|
||||||
if not get_agents_api_config().enabled:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
|
|
||||||
"""Convert AgentConfig to AgentResponse."""
|
"""Convert AgentConfig to AgentResponse."""
|
||||||
soul: str | None = None
|
soul: str | None = None
|
||||||
if include_soul:
|
if include_soul:
|
||||||
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
|
soul = load_agent_soul(agent_cfg.name) or ""
|
||||||
|
|
||||||
return AgentResponse(
|
return AgentResponse(
|
||||||
name=agent_cfg.name,
|
name=agent_cfg.name,
|
||||||
description=agent_cfg.description,
|
description=agent_cfg.description,
|
||||||
model=agent_cfg.model,
|
model=agent_cfg.model,
|
||||||
tool_groups=agent_cfg.tool_groups,
|
tool_groups=agent_cfg.tool_groups,
|
||||||
skills=agent_cfg.skills,
|
|
||||||
soul=soul,
|
soul=soul,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,12 +100,9 @@ async def list_agents() -> AgentsListResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
List of all custom agents with their metadata and soul content.
|
List of all custom agents with their metadata and soul content.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
|
|
||||||
user_id = get_effective_user_id()
|
|
||||||
try:
|
try:
|
||||||
agents = list_custom_agents(user_id=user_id)
|
agents = list_custom_agents()
|
||||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
|
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
||||||
@@ -143,15 +125,9 @@ async def check_agent_name(name: str) -> dict:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 422 if the name is invalid.
|
HTTPException: 422 if the name is invalid.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
normalized = _normalize_agent_name(name)
|
normalized = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
available = not get_paths().agent_dir(normalized).exists()
|
||||||
paths = get_paths()
|
|
||||||
# Treat the name as taken if either the per-user path or the legacy shared
|
|
||||||
# path holds an agent — picking a name that collides with an unmigrated
|
|
||||||
# legacy agent would shadow the legacy entry once migration runs.
|
|
||||||
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
|
|
||||||
return {"available": available, "name": normalized}
|
return {"available": available, "name": normalized}
|
||||||
|
|
||||||
|
|
||||||
@@ -173,14 +149,12 @@ async def get_agent(name: str) -> AgentResponse:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 404 if agent not found.
|
HTTPException: 404 if agent not found.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
agent_cfg = load_agent_config(name)
|
||||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -207,16 +181,12 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 409 if agent already exists, 422 if name is invalid.
|
HTTPException: 409 if agent already exists, 422 if name is invalid.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(request.name)
|
_validate_agent_name(request.name)
|
||||||
normalized_name = _normalize_agent_name(request.name)
|
normalized_name = _normalize_agent_name(request.name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
paths = get_paths()
|
|
||||||
|
|
||||||
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
agent_dir = get_paths().agent_dir(normalized_name)
|
||||||
legacy_dir = paths.agent_dir(normalized_name)
|
|
||||||
|
|
||||||
if agent_dir.exists() or legacy_dir.exists():
|
if agent_dir.exists():
|
||||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -230,8 +200,6 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
config_data["model"] = request.model
|
config_data["model"] = request.model
|
||||||
if request.tool_groups is not None:
|
if request.tool_groups is not None:
|
||||||
config_data["tool_groups"] = request.tool_groups
|
config_data["tool_groups"] = request.tool_groups
|
||||||
if request.skills is not None:
|
|
||||||
config_data["skills"] = request.skills
|
|
||||||
|
|
||||||
config_file = agent_dir / "config.yaml"
|
config_file = agent_dir / "config.yaml"
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
@@ -243,8 +211,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
|
|
||||||
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
||||||
|
|
||||||
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
agent_cfg = load_agent_config(normalized_name)
|
||||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -275,52 +243,33 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 404 if agent not found.
|
HTTPException: 404 if agent not found.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
agent_cfg = load_agent_config(name)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
|
|
||||||
paths = get_paths()
|
agent_dir = get_paths().agent_dir(name)
|
||||||
agent_dir = paths.user_agent_dir(user_id, name)
|
|
||||||
if not agent_dir.exists() and paths.agent_dir(name).exists():
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update config if any config fields changed
|
# Update config if any config fields changed
|
||||||
# Use model_fields_set to distinguish "field omitted" from "explicitly set to null".
|
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
|
||||||
# This is critical for skills where None means "inherit all" (not "don't change").
|
|
||||||
fields_set = request.model_fields_set
|
|
||||||
config_changed = bool(fields_set & {"description", "model", "tool_groups", "skills"})
|
|
||||||
|
|
||||||
if config_changed:
|
if config_changed:
|
||||||
updated: dict = {
|
updated: dict = {
|
||||||
"name": agent_cfg.name,
|
"name": agent_cfg.name,
|
||||||
"description": request.description if "description" in fields_set else agent_cfg.description,
|
"description": request.description if request.description is not None else agent_cfg.description,
|
||||||
}
|
}
|
||||||
new_model = request.model if "model" in fields_set else agent_cfg.model
|
new_model = request.model if request.model is not None else agent_cfg.model
|
||||||
if new_model is not None:
|
if new_model is not None:
|
||||||
updated["model"] = new_model
|
updated["model"] = new_model
|
||||||
|
|
||||||
new_tool_groups = request.tool_groups if "tool_groups" in fields_set else agent_cfg.tool_groups
|
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
|
||||||
if new_tool_groups is not None:
|
if new_tool_groups is not None:
|
||||||
updated["tool_groups"] = new_tool_groups
|
updated["tool_groups"] = new_tool_groups
|
||||||
|
|
||||||
# skills: None = inherit all, [] = no skills, ["a","b"] = whitelist
|
|
||||||
if "skills" in fields_set:
|
|
||||||
new_skills = request.skills
|
|
||||||
else:
|
|
||||||
new_skills = agent_cfg.skills
|
|
||||||
if new_skills is not None:
|
|
||||||
updated["skills"] = new_skills
|
|
||||||
|
|
||||||
config_file = agent_dir / "config.yaml"
|
config_file = agent_dir / "config.yaml"
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
|
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
|
||||||
@@ -332,8 +281,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
|||||||
|
|
||||||
logger.info(f"Updated agent '{name}'")
|
logger.info(f"Updated agent '{name}'")
|
||||||
|
|
||||||
refreshed_cfg = load_agent_config(name, user_id=user_id)
|
refreshed_cfg = load_agent_config(name)
|
||||||
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(refreshed_cfg, include_soul=True)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -366,8 +315,6 @@ async def get_user_profile() -> UserProfileResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
UserProfileResponse with content=None if USER.md does not exist yet.
|
UserProfileResponse with content=None if USER.md does not exist yet.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_md_path = get_paths().user_md_file
|
user_md_path = get_paths().user_md_file
|
||||||
if not user_md_path.exists():
|
if not user_md_path.exists():
|
||||||
@@ -394,8 +341,6 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
|
|||||||
Returns:
|
Returns:
|
||||||
UserProfileResponse with the saved content.
|
UserProfileResponse with the saved content.
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.base_dir.mkdir(parents=True, exist_ok=True)
|
paths.base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -420,22 +365,14 @@ async def delete_agent(name: str) -> None:
|
|||||||
name: The agent name.
|
name: The agent name.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
|
HTTPException: 404 if agent not found.
|
||||||
shared copy exists (suggesting the migration script).
|
|
||||||
"""
|
"""
|
||||||
_require_agents_api_enabled()
|
|
||||||
_validate_agent_name(name)
|
_validate_agent_name(name)
|
||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
|
||||||
paths = get_paths()
|
agent_dir = get_paths().agent_dir(name)
|
||||||
agent_dir = paths.user_agent_dir(user_id, name)
|
|
||||||
|
|
||||||
if not agent_dir.exists():
|
if not agent_dir.exists():
|
||||||
if paths.agent_dir(name).exists():
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
|
||||||
)
|
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -20,9 +20,6 @@ ACTIVE_CONTENT_MIME_TYPES = {
|
|||||||
"image/svg+xml",
|
"image/svg+xml",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
|
|
||||||
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
||||||
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
||||||
@@ -47,22 +44,6 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
|
|
||||||
"""Read a .skill archive member while enforcing an uncompressed size cap."""
|
|
||||||
if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
|
|
||||||
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
|
|
||||||
|
|
||||||
chunks: list[bytes] = []
|
|
||||||
total_read = 0
|
|
||||||
with zip_ref.open(info, "r") as src:
|
|
||||||
while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE):
|
|
||||||
total_read += len(chunk)
|
|
||||||
if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
|
|
||||||
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
|
|
||||||
chunks.append(chunk)
|
|
||||||
return b"".join(chunks)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
||||||
"""Extract a file from a .skill ZIP archive.
|
"""Extract a file from a .skill ZIP archive.
|
||||||
|
|
||||||
@@ -79,16 +60,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
|||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
# List all files in the archive
|
# List all files in the archive
|
||||||
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
|
namelist = zip_ref.namelist()
|
||||||
|
|
||||||
# Try direct path first
|
# Try direct path first
|
||||||
if internal_path in infos_by_name:
|
if internal_path in namelist:
|
||||||
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
|
return zip_ref.read(internal_path)
|
||||||
|
|
||||||
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
||||||
for name, info in infos_by_name.items():
|
for name in namelist:
|
||||||
if name.endswith("/" + internal_path) or name == internal_path:
|
if name.endswith("/" + internal_path) or name == internal_path:
|
||||||
return _read_skill_archive_member(zip_ref, info)
|
return zip_ref.read(name)
|
||||||
|
|
||||||
# Not found
|
# Not found
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Authentication endpoints."""
|
"""Authentication endpoints."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@@ -147,13 +146,7 @@ def _set_session_cookie(response: Response, token: str, request: Request) -> Non
|
|||||||
|
|
||||||
|
|
||||||
# ── Rate Limiting ────────────────────────────────────────────────────────
|
# ── Rate Limiting ────────────────────────────────────────────────────────
|
||||||
# In-process dict — not shared across workers.
|
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
|
||||||
#
|
|
||||||
# **Limitation**: with multi-worker deployments (e.g., gunicorn -w N), each
|
|
||||||
# worker maintains its own lockout table, so an attacker effectively gets
|
|
||||||
# N × _MAX_LOGIN_ATTEMPTS guesses before being locked out everywhere. For
|
|
||||||
# production multi-worker setups, replace this with a shared store (Redis,
|
|
||||||
# database-backed counter) to enforce a true per-IP limit.
|
|
||||||
|
|
||||||
_MAX_LOGIN_ATTEMPTS = 5
|
_MAX_LOGIN_ATTEMPTS = 5
|
||||||
_LOCKOUT_SECONDS = 300 # 5 minutes
|
_LOCKOUT_SECONDS = 300 # 5 minutes
|
||||||
@@ -306,7 +299,7 @@ async def login_local(
|
|||||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
async def register(request: Request, response: Response, body: RegisterRequest):
|
||||||
"""Register a new user account (always 'user' role).
|
"""Register a new user account (always 'user' role).
|
||||||
|
|
||||||
The first admin is created explicitly through /initialize. This endpoint creates regular users.
|
Admin is auto-created on first boot. This endpoint creates regular users.
|
||||||
Auto-login by setting the session cookie.
|
Auto-login by setting the session cookie.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -383,72 +376,11 @@ async def get_me(request: Request):
|
|||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||||
|
|
||||||
|
|
||||||
# Per-IP cache: ip → (timestamp, result_dict).
|
|
||||||
# Returns the cached result within the TTL instead of 429, because
|
|
||||||
# the answer (whether an admin exists) rarely changes and returning
|
|
||||||
# 429 breaks multi-tab / post-restart reconnection storms.
|
|
||||||
_SETUP_STATUS_CACHE: dict[str, tuple[float, dict]] = {}
|
|
||||||
_SETUP_STATUS_CACHE_TTL_SECONDS = 60
|
|
||||||
_MAX_TRACKED_SETUP_STATUS_IPS = 10000
|
|
||||||
_SETUP_STATUS_INFLIGHT: dict[str, asyncio.Task[dict]] = {}
|
|
||||||
_SETUP_STATUS_INFLIGHT_GUARD = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/setup-status")
|
@router.get("/setup-status")
|
||||||
async def setup_status(request: Request):
|
async def setup_status():
|
||||||
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
||||||
client_ip = _get_client_ip(request)
|
admin_count = await get_local_provider().count_admin_users()
|
||||||
now = time.time()
|
return {"needs_setup": admin_count == 0}
|
||||||
|
|
||||||
# Return cached result when within TTL — avoids 429 on multi-tab reconnection.
|
|
||||||
cached = _SETUP_STATUS_CACHE.get(client_ip)
|
|
||||||
if cached is not None:
|
|
||||||
cached_time, cached_result = cached
|
|
||||||
if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS:
|
|
||||||
return cached_result
|
|
||||||
|
|
||||||
async with _SETUP_STATUS_INFLIGHT_GUARD:
|
|
||||||
# Recheck cache after waiting for the inflight guard.
|
|
||||||
now = time.time()
|
|
||||||
cached = _SETUP_STATUS_CACHE.get(client_ip)
|
|
||||||
if cached is not None:
|
|
||||||
cached_time, cached_result = cached
|
|
||||||
if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS:
|
|
||||||
return cached_result
|
|
||||||
|
|
||||||
task = _SETUP_STATUS_INFLIGHT.get(client_ip)
|
|
||||||
if task is None:
|
|
||||||
# Evict stale entries when dict grows too large to bound memory usage.
|
|
||||||
if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS:
|
|
||||||
cutoff = now - _SETUP_STATUS_CACHE_TTL_SECONDS
|
|
||||||
stale = [k for k, (t, _) in _SETUP_STATUS_CACHE.items() if t < cutoff]
|
|
||||||
for k in stale:
|
|
||||||
del _SETUP_STATUS_CACHE[k]
|
|
||||||
if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS:
|
|
||||||
by_time = sorted(_SETUP_STATUS_CACHE.items(), key=lambda entry: entry[1][0])
|
|
||||||
for k, _ in by_time[: len(by_time) // 2]:
|
|
||||||
del _SETUP_STATUS_CACHE[k]
|
|
||||||
|
|
||||||
async def _compute_setup_status() -> dict:
|
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
|
||||||
return {"needs_setup": admin_count == 0}
|
|
||||||
|
|
||||||
task = asyncio.create_task(_compute_setup_status())
|
|
||||||
_SETUP_STATUS_INFLIGHT[client_ip] = task
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await task
|
|
||||||
finally:
|
|
||||||
async with _SETUP_STATUS_INFLIGHT_GUARD:
|
|
||||||
if _SETUP_STATUS_INFLIGHT.get(client_ip) is task:
|
|
||||||
del _SETUP_STATUS_INFLIGHT[client_ip]
|
|
||||||
|
|
||||||
# Cache only the stable "initialized" result to avoid stale setup redirects.
|
|
||||||
if result["needs_setup"] is False:
|
|
||||||
_SETUP_STATUS_CACHE[client_ip] = (time.time(), result)
|
|
||||||
else:
|
|
||||||
_SETUP_STATUS_CACHE.pop(client_ip, None)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class InitializeAdminRequest(BaseModel):
|
class InitializeAdminRequest(BaseModel):
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.deps import get_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["models"])
|
router = APIRouter(prefix="/api", tags=["models"])
|
||||||
|
|
||||||
@@ -18,17 +17,10 @@ class ModelResponse(BaseModel):
|
|||||||
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
|
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
|
||||||
|
|
||||||
|
|
||||||
class TokenUsageResponse(BaseModel):
|
|
||||||
"""Token usage display configuration."""
|
|
||||||
|
|
||||||
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsListResponse(BaseModel):
|
class ModelsListResponse(BaseModel):
|
||||||
"""Response model for listing all models."""
|
"""Response model for listing all models."""
|
||||||
|
|
||||||
models: list[ModelResponse]
|
models: list[ModelResponse]
|
||||||
token_usage: TokenUsageResponse
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -37,14 +29,14 @@ class ModelsListResponse(BaseModel):
|
|||||||
summary="List All Models",
|
summary="List All Models",
|
||||||
description="Retrieve a list of all available AI models configured in the system.",
|
description="Retrieve a list of all available AI models configured in the system.",
|
||||||
)
|
)
|
||||||
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
|
async def list_models() -> ModelsListResponse:
|
||||||
"""List all available models from configuration.
|
"""List all available models from configuration.
|
||||||
|
|
||||||
Returns model information suitable for frontend display,
|
Returns model information suitable for frontend display,
|
||||||
excluding sensitive fields like API keys and internal configuration.
|
excluding sensitive fields like API keys and internal configuration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of all configured models with their metadata and token usage display settings.
|
A list of all configured models with their metadata.
|
||||||
|
|
||||||
Example Response:
|
Example Response:
|
||||||
```json
|
```json
|
||||||
@@ -52,27 +44,21 @@ async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResp
|
|||||||
"models": [
|
"models": [
|
||||||
{
|
{
|
||||||
"name": "gpt-4",
|
"name": "gpt-4",
|
||||||
"model": "gpt-4",
|
|
||||||
"display_name": "GPT-4",
|
"display_name": "GPT-4",
|
||||||
"description": "OpenAI GPT-4 model",
|
"description": "OpenAI GPT-4 model",
|
||||||
"supports_thinking": false,
|
"supports_thinking": false
|
||||||
"supports_reasoning_effort": false
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "claude-3-opus",
|
"name": "claude-3-opus",
|
||||||
"model": "claude-3-opus",
|
|
||||||
"display_name": "Claude 3 Opus",
|
"display_name": "Claude 3 Opus",
|
||||||
"description": "Anthropic Claude 3 Opus model",
|
"description": "Anthropic Claude 3 Opus model",
|
||||||
"supports_thinking": true,
|
"supports_thinking": true
|
||||||
"supports_reasoning_effort": false
|
|
||||||
}
|
}
|
||||||
],
|
]
|
||||||
"token_usage": {
|
|
||||||
"enabled": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
config = get_app_config()
|
||||||
models = [
|
models = [
|
||||||
ModelResponse(
|
ModelResponse(
|
||||||
name=model.name,
|
name=model.name,
|
||||||
@@ -84,10 +70,7 @@ async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResp
|
|||||||
)
|
)
|
||||||
for model in config.models
|
for model in config.models
|
||||||
]
|
]
|
||||||
return ModelsListResponse(
|
return ModelsListResponse(models=models)
|
||||||
models=models,
|
|
||||||
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -96,7 +79,7 @@ async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResp
|
|||||||
summary="Get Model Details",
|
summary="Get Model Details",
|
||||||
description="Retrieve detailed information about a specific AI model by its name.",
|
description="Retrieve detailed information about a specific AI model by its name.",
|
||||||
)
|
)
|
||||||
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
|
async def get_model(model_name: str) -> ModelResponse:
|
||||||
"""Get a specific model by name.
|
"""Get a specific model by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -118,6 +101,7 @@ async def get_model(model_name: str, config: AppConfig = Depends(get_config)) ->
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
config = get_app_config()
|
||||||
model = config.get_model_config(model_name)
|
model = config.get_model_config(model_name)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|||||||
@@ -1,20 +1,29 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.deps import get_config
|
|
||||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||||
from deerflow.skills import Skill
|
from deerflow.skills import Skill, load_skills
|
||||||
from deerflow.skills.installer import SkillAlreadyExistsError
|
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
|
||||||
|
from deerflow.skills.manager import (
|
||||||
|
append_history,
|
||||||
|
atomic_write,
|
||||||
|
custom_skill_exists,
|
||||||
|
ensure_custom_skill_is_editable,
|
||||||
|
get_custom_skill_dir,
|
||||||
|
get_custom_skill_file,
|
||||||
|
get_skill_history_file,
|
||||||
|
read_custom_skill_content,
|
||||||
|
read_history,
|
||||||
|
validate_skill_markdown_content,
|
||||||
|
)
|
||||||
from deerflow.skills.security_scanner import scan_skill_content
|
from deerflow.skills.security_scanner import scan_skill_content
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
|
||||||
from deerflow.skills.types import SKILL_MD_FILE, SkillCategory
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -27,7 +36,7 @@ class SkillResponse(BaseModel):
|
|||||||
name: str = Field(..., description="Name of the skill")
|
name: str = Field(..., description="Name of the skill")
|
||||||
description: str = Field(..., description="Description of what the skill does")
|
description: str = Field(..., description="Description of what the skill does")
|
||||||
license: str | None = Field(None, description="License information")
|
license: str | None = Field(None, description="License information")
|
||||||
category: SkillCategory = Field(..., description="Category of the skill (public or custom)")
|
category: str = Field(..., description="Category of the skill (public or custom)")
|
||||||
enabled: bool = Field(default=True, description="Whether this skill is enabled")
|
enabled: bool = Field(default=True, description="Whether this skill is enabled")
|
||||||
|
|
||||||
|
|
||||||
@@ -91,9 +100,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
|
|||||||
summary="List All Skills",
|
summary="List All Skills",
|
||||||
description="Retrieve a list of all available skills from both public and custom directories.",
|
description="Retrieve a list of all available skills from both public and custom directories.",
|
||||||
)
|
)
|
||||||
async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
|
async def list_skills() -> SkillsListResponse:
|
||||||
try:
|
try:
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
skills = load_skills(enabled_only=False)
|
||||||
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load skills: {e}", exc_info=True)
|
logger.error(f"Failed to load skills: {e}", exc_info=True)
|
||||||
@@ -106,10 +115,10 @@ async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResp
|
|||||||
summary="Install Skill",
|
summary="Install Skill",
|
||||||
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
|
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
|
||||||
)
|
)
|
||||||
async def install_skill(request: SkillInstallRequest, config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
|
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
|
||||||
try:
|
try:
|
||||||
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
|
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
|
||||||
result = await get_or_new_skill_storage(app_config=config).ainstall_skill_from_archive(skill_file_path)
|
result = install_skill_from_archive(skill_file_path)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return SkillInstallResponse(**result)
|
return SkillInstallResponse(**result)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -126,9 +135,9 @@ async def install_skill(request: SkillInstallRequest, config: AppConfig = Depend
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
|
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
|
||||||
async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
|
async def list_custom_skills() -> SkillsListResponse:
|
||||||
try:
|
try:
|
||||||
skills = [skill for skill in get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False) if skill.category == SkillCategory.CUSTOM]
|
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
|
||||||
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to list custom skills: %s", e, exc_info=True)
|
logger.error("Failed to list custom skills: %s", e, exc_info=True)
|
||||||
@@ -136,14 +145,13 @@ async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsL
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
|
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
|
||||||
async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
|
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
skills = load_skills(enabled_only=False)
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
|
||||||
skill = next((s for s in skills if s.name == skill_name and s.category == SkillCategory.CUSTOM), None)
|
|
||||||
if skill is None:
|
if skill is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||||
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=get_or_new_skill_storage(app_config=config).read_custom_skill(skill_name))
|
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -152,31 +160,30 @@ async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_conf
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
|
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
|
||||||
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
|
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
ensure_custom_skill_is_editable(skill_name)
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
validate_skill_markdown_content(skill_name, request.content)
|
||||||
storage.ensure_custom_skill_is_editable(skill_name)
|
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
|
||||||
storage.validate_skill_markdown_content(skill_name, request.content)
|
|
||||||
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
|
|
||||||
if scan.decision == "block":
|
if scan.decision == "block":
|
||||||
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
|
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
|
||||||
prev_content = storage.read_custom_skill(skill_name)
|
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
|
||||||
storage.write_custom_skill(skill_name, SKILL_MD_FILE, request.content)
|
prev_content = skill_file.read_text(encoding="utf-8")
|
||||||
storage.append_history(
|
atomic_write(skill_file, request.content)
|
||||||
|
append_history(
|
||||||
skill_name,
|
skill_name,
|
||||||
{
|
{
|
||||||
"action": "human_edit",
|
"action": "human_edit",
|
||||||
"author": "human",
|
"author": "human",
|
||||||
"thread_id": None,
|
"thread_id": None,
|
||||||
"file_path": SKILL_MD_FILE,
|
"file_path": "SKILL.md",
|
||||||
"prev_content": prev_content,
|
"prev_content": prev_content,
|
||||||
"new_content": request.content,
|
"new_content": request.content,
|
||||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return await get_custom_skill(skill_name, config)
|
return await get_custom_skill(skill_name)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -189,22 +196,24 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
|
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
|
||||||
async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
|
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
ensure_custom_skill_is_editable(skill_name)
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
skill_dir = get_custom_skill_dir(skill_name)
|
||||||
storage.delete_custom_skill(
|
prev_content = read_custom_skill_content(skill_name)
|
||||||
|
append_history(
|
||||||
skill_name,
|
skill_name,
|
||||||
history_meta={
|
{
|
||||||
"action": "human_delete",
|
"action": "human_delete",
|
||||||
"author": "human",
|
"author": "human",
|
||||||
"thread_id": None,
|
"thread_id": None,
|
||||||
"file_path": SKILL_MD_FILE,
|
"file_path": "SKILL.md",
|
||||||
"prev_content": None,
|
"prev_content": prev_content,
|
||||||
"new_content": None,
|
"new_content": None,
|
||||||
"scanner": {"decision": "allow", "reason": "Deletion requested."},
|
"scanner": {"decision": "allow", "reason": "Deletion requested."},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
shutil.rmtree(skill_dir)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -217,13 +226,11 @@ async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_c
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
|
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
|
||||||
async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
|
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
|
||||||
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
|
|
||||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||||
return CustomSkillHistoryResponse(history=storage.read_history(skill_name))
|
return CustomSkillHistoryResponse(history=read_history(skill_name))
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -232,39 +239,38 @@ async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
|
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
|
||||||
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
|
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
|
||||||
try:
|
try:
|
||||||
storage = get_or_new_skill_storage(app_config=config)
|
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
|
||||||
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
|
|
||||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||||
history = storage.read_history(skill_name)
|
history = read_history(skill_name)
|
||||||
if not history:
|
if not history:
|
||||||
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
|
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
|
||||||
record = history[request.history_index]
|
record = history[request.history_index]
|
||||||
target_content = record.get("prev_content")
|
target_content = record.get("prev_content")
|
||||||
if target_content is None:
|
if target_content is None:
|
||||||
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
|
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
|
||||||
storage.validate_skill_markdown_content(skill_name, target_content)
|
validate_skill_markdown_content(skill_name, target_content)
|
||||||
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
|
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
|
||||||
skill_file = storage.get_custom_skill_file(skill_name)
|
skill_file = get_custom_skill_file(skill_name)
|
||||||
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
|
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
|
||||||
history_entry = {
|
history_entry = {
|
||||||
"action": "rollback",
|
"action": "rollback",
|
||||||
"author": "human",
|
"author": "human",
|
||||||
"thread_id": None,
|
"thread_id": None,
|
||||||
"file_path": SKILL_MD_FILE,
|
"file_path": "SKILL.md",
|
||||||
"prev_content": current_content,
|
"prev_content": current_content,
|
||||||
"new_content": target_content,
|
"new_content": target_content,
|
||||||
"rollback_from_ts": record.get("ts"),
|
"rollback_from_ts": record.get("ts"),
|
||||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||||
}
|
}
|
||||||
if scan.decision == "block":
|
if scan.decision == "block":
|
||||||
storage.append_history(skill_name, history_entry)
|
append_history(skill_name, history_entry)
|
||||||
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
|
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
|
||||||
storage.write_custom_skill(skill_name, SKILL_MD_FILE, target_content)
|
atomic_write(skill_file, target_content)
|
||||||
storage.append_history(skill_name, history_entry)
|
append_history(skill_name, history_entry)
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
return await get_custom_skill(skill_name, config)
|
return await get_custom_skill(skill_name)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except IndexError:
|
except IndexError:
|
||||||
@@ -284,10 +290,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest,
|
|||||||
summary="Get Skill Details",
|
summary="Get Skill Details",
|
||||||
description="Retrieve detailed information about a specific skill by its name.",
|
description="Retrieve detailed information about a specific skill by its name.",
|
||||||
)
|
)
|
||||||
async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
|
async def get_skill(skill_name: str) -> SkillResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
skills = load_skills(enabled_only=False)
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
|
||||||
skill = next((s for s in skills if s.name == skill_name), None)
|
skill = next((s for s in skills if s.name == skill_name), None)
|
||||||
|
|
||||||
if skill is None:
|
if skill is None:
|
||||||
@@ -307,10 +312,9 @@ async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) ->
|
|||||||
summary="Update Skill",
|
summary="Update Skill",
|
||||||
description="Update a skill's enabled status by modifying the extensions_config.json file.",
|
description="Update a skill's enabled status by modifying the extensions_config.json file.",
|
||||||
)
|
)
|
||||||
async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
|
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
|
||||||
try:
|
try:
|
||||||
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
|
skills = load_skills(enabled_only=False)
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
|
||||||
skill = next((s for s in skills if s.name == skill_name), None)
|
skill = next((s for s in skills if s.name == skill_name), None)
|
||||||
|
|
||||||
if skill is None:
|
if skill is None:
|
||||||
@@ -336,7 +340,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest, config: App
|
|||||||
reload_extensions_config()
|
reload_extensions_config()
|
||||||
await refresh_skills_system_prompt_cache_async()
|
await refresh_skills_system_prompt_cache_async()
|
||||||
|
|
||||||
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
|
skills = load_skills(enabled_only=False)
|
||||||
updated_skill = next((s for s in skills if s.name == skill_name), None)
|
updated_skill = next((s for s in skills if s.name == skill_name), None)
|
||||||
|
|
||||||
if updated_skill is None:
|
if updated_skill is None:
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Request
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_config
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -102,12 +100,7 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
|||||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||||
)
|
)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def generate_suggestions(
|
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
|
||||||
thread_id: str,
|
|
||||||
body: SuggestionsRequest,
|
|
||||||
request: Request,
|
|
||||||
config: AppConfig = Depends(get_config),
|
|
||||||
) -> SuggestionsResponse:
|
|
||||||
if not body.messages:
|
if not body.messages:
|
||||||
return SuggestionsResponse(suggestions=[])
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
@@ -129,8 +122,8 @@ async def generate_suggestions(
|
|||||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
|
model = create_chat_model(name=body.model_name, thinking_enabled=False)
|
||||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
|
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||||
raw = _extract_response_text(response.content)
|
raw = _extract_response_text(response.content)
|
||||||
suggestions = _parse_json_string_list(raw) or []
|
suggestions = _parse_json_string_list(raw) or []
|
||||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from pydantic import BaseModel, Field
|
|||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.services import sse_consumer, start_run
|
from app.gateway.services import sse_consumer, start_run
|
||||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
||||||
@@ -68,38 +68,11 @@ class RunResponse(BaseModel):
|
|||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
|
||||||
tokens: int = 0
|
|
||||||
runs: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageCallerBreakdown(BaseModel):
|
|
||||||
lead_agent: int = 0
|
|
||||||
subagent: int = 0
|
|
||||||
middleware: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageResponse(BaseModel):
|
|
||||||
thread_id: str
|
|
||||||
total_tokens: int = 0
|
|
||||||
total_input_tokens: int = 0
|
|
||||||
total_output_tokens: int = 0
|
|
||||||
total_runs: int = 0
|
|
||||||
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
|
|
||||||
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _cancel_conflict_detail(run_id: str, record: RunRecord) -> str:
|
|
||||||
if record.status in (RunStatus.pending, RunStatus.running):
|
|
||||||
return f"Run {run_id} is not active on this worker and cannot be cancelled"
|
|
||||||
return f"Run {run_id} is not cancellable (status: {record.status.value})"
|
|
||||||
|
|
||||||
|
|
||||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||||
return RunResponse(
|
return RunResponse(
|
||||||
run_id=record.run_id,
|
run_id=record.run_id,
|
||||||
@@ -186,8 +159,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
|
|||||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
||||||
"""List all runs for a thread."""
|
"""List all runs for a thread."""
|
||||||
run_mgr = get_run_manager(request)
|
run_mgr = get_run_manager(request)
|
||||||
user_id = await get_current_user(request)
|
records = await run_mgr.list_by_thread(thread_id)
|
||||||
records = await run_mgr.list_by_thread(thread_id, user_id=user_id)
|
|
||||||
return [_record_to_response(r) for r in records]
|
return [_record_to_response(r) for r in records]
|
||||||
|
|
||||||
|
|
||||||
@@ -196,8 +168,7 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
|||||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||||
"""Get details of a specific run."""
|
"""Get details of a specific run."""
|
||||||
run_mgr = get_run_manager(request)
|
run_mgr = get_run_manager(request)
|
||||||
user_id = await get_current_user(request)
|
record = run_mgr.get(run_id)
|
||||||
record = await run_mgr.get(run_id, user_id=user_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
if record is None or record.thread_id != thread_id:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
return _record_to_response(record)
|
return _record_to_response(record)
|
||||||
@@ -220,13 +191,16 @@ async def cancel_run(
|
|||||||
- wait=false: Return immediately with 202
|
- wait=false: Return immediately with 202
|
||||||
"""
|
"""
|
||||||
run_mgr = get_run_manager(request)
|
run_mgr = get_run_manager(request)
|
||||||
record = await run_mgr.get(run_id)
|
record = run_mgr.get(run_id)
|
||||||
if record is None or record.thread_id != thread_id:
|
if record is None or record.thread_id != thread_id:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
cancelled = await run_mgr.cancel(run_id, action=action)
|
||||||
if not cancelled:
|
if not cancelled:
|
||||||
raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record))
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
|
||||||
|
)
|
||||||
|
|
||||||
if wait and record.task is not None:
|
if wait and record.task is not None:
|
||||||
try:
|
try:
|
||||||
@@ -242,14 +216,12 @@ async def cancel_run(
|
|||||||
@require_permission("runs", "read", owner_check=True)
|
@require_permission("runs", "read", owner_check=True)
|
||||||
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
||||||
"""Join an existing run's SSE stream."""
|
"""Join an existing run's SSE stream."""
|
||||||
|
bridge = get_stream_bridge(request)
|
||||||
run_mgr = get_run_manager(request)
|
run_mgr = get_run_manager(request)
|
||||||
record = await run_mgr.get(run_id)
|
record = run_mgr.get(run_id)
|
||||||
if record is None or record.thread_id != thread_id:
|
if record is None or record.thread_id != thread_id:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
if record.store_only:
|
|
||||||
raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")
|
|
||||||
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
sse_consumer(bridge, record, request, run_mgr),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
@@ -278,18 +250,14 @@ async def stream_existing_run(
|
|||||||
remaining buffered events so the client observes a clean shutdown.
|
remaining buffered events so the client observes a clean shutdown.
|
||||||
"""
|
"""
|
||||||
run_mgr = get_run_manager(request)
|
run_mgr = get_run_manager(request)
|
||||||
record = await run_mgr.get(run_id)
|
record = run_mgr.get(run_id)
|
||||||
if record is None or record.thread_id != thread_id:
|
if record is None or record.thread_id != thread_id:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
if record.store_only and action is None:
|
|
||||||
raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")
|
|
||||||
|
|
||||||
# Cancel if an action was requested (stop-button / interrupt flow)
|
# Cancel if an action was requested (stop-button / interrupt flow)
|
||||||
if action is not None:
|
if action is not None:
|
||||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
cancelled = await run_mgr.cancel(run_id, action=action)
|
||||||
if not cancelled:
|
if cancelled and wait and record.task is not None:
|
||||||
raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record))
|
|
||||||
if wait and record.task is not None:
|
|
||||||
try:
|
try:
|
||||||
await record.task
|
await record.task
|
||||||
except (asyncio.CancelledError, Exception):
|
except (asyncio.CancelledError, Exception):
|
||||||
@@ -400,10 +368,10 @@ async def list_run_events(
|
|||||||
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
@router.get("/{thread_id}/token-usage")
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
|
async def thread_token_usage(thread_id: str, request: Request) -> dict:
|
||||||
"""Thread-level token usage aggregation."""
|
"""Thread-level token usage aggregation."""
|
||||||
run_store = get_run_store(request)
|
run_store = get_run_store(request)
|
||||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||||
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
return {"thread_id": thread_id, **agg}
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ matching the LangGraph Platform wire format expected by the
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
@@ -26,7 +27,6 @@ from app.gateway.utils import sanitize_log_param
|
|||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.utils.time import coerce_iso, now_iso
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||||
@@ -90,28 +90,6 @@ class ThreadSearchRequest(BaseModel):
|
|||||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||||
status: str | None = Field(default=None, description="Filter by thread status")
|
status: str | None = Field(default=None, description="Filter by thread status")
|
||||||
|
|
||||||
@field_validator("metadata")
|
|
||||||
@classmethod
|
|
||||||
def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Reject filter entries the SQL backend cannot compile.
|
|
||||||
|
|
||||||
Enforces consistent behaviour across SQL and memory backends.
|
|
||||||
See ``deerflow.persistence.json_compat`` for the shared validators.
|
|
||||||
"""
|
|
||||||
if not v:
|
|
||||||
return v
|
|
||||||
from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value
|
|
||||||
|
|
||||||
bad_entries: list[str] = []
|
|
||||||
for key, value in v.items():
|
|
||||||
if not validate_metadata_filter_key(key):
|
|
||||||
bad_entries.append(f"{key!r} (unsafe key)")
|
|
||||||
elif not validate_metadata_filter_value(value):
|
|
||||||
bad_entries.append(f"{key!r} (unsupported value type {type(value).__name__})")
|
|
||||||
if bad_entries:
|
|
||||||
raise ValueError(f"Invalid metadata filter entries: {', '.join(bad_entries)}")
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateResponse(BaseModel):
|
class ThreadStateResponse(BaseModel):
|
||||||
"""Response model for thread state."""
|
"""Response model for thread state."""
|
||||||
@@ -256,7 +234,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_store = get_thread_store(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = now_iso()
|
now = time.time()
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
@@ -266,8 +244,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status=existing_record.get("status", "idle"),
|
status=existing_record.get("status", "idle"),
|
||||||
created_at=coerce_iso(existing_record.get("created_at", "")),
|
created_at=str(existing_record.get("created_at", "")),
|
||||||
updated_at=coerce_iso(existing_record.get("updated_at", "")),
|
updated_at=str(existing_record.get("updated_at", "")),
|
||||||
metadata=existing_record.get("metadata", {}),
|
metadata=existing_record.get("metadata", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -285,6 +263,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
# Write an empty checkpoint so state endpoints work immediately
|
# Write an empty checkpoint so state endpoints work immediately
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
try:
|
try:
|
||||||
|
from langgraph.checkpoint.base import empty_checkpoint
|
||||||
|
|
||||||
ckpt_metadata = {
|
ckpt_metadata = {
|
||||||
"step": -1,
|
"step": -1,
|
||||||
"source": "input",
|
"source": "input",
|
||||||
@@ -302,8 +282,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status="idle",
|
status="idle",
|
||||||
created_at=now,
|
created_at=str(now),
|
||||||
updated_at=now,
|
updated_at=str(now),
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -316,27 +296,20 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_store
|
||||||
from deerflow.persistence.thread_meta import InvalidMetadataFilterError
|
|
||||||
|
|
||||||
repo = get_thread_store(request)
|
repo = get_thread_store(request)
|
||||||
try:
|
rows = await repo.search(
|
||||||
rows = await repo.search(
|
metadata=body.metadata or None,
|
||||||
metadata=body.metadata or None,
|
status=body.status,
|
||||||
status=body.status,
|
limit=body.limit,
|
||||||
limit=body.limit,
|
offset=body.offset,
|
||||||
offset=body.offset,
|
)
|
||||||
)
|
|
||||||
except InvalidMetadataFilterError as exc:
|
|
||||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
||||||
return [
|
return [
|
||||||
ThreadResponse(
|
ThreadResponse(
|
||||||
thread_id=r["thread_id"],
|
thread_id=r["thread_id"],
|
||||||
status=r.get("status", "idle"),
|
status=r.get("status", "idle"),
|
||||||
# ``coerce_iso`` heals legacy unix-second values that
|
created_at=r.get("created_at", ""),
|
||||||
# ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
|
updated_at=r.get("updated_at", ""),
|
||||||
# SQL-backed rows already arrive as ISO strings and pass through.
|
|
||||||
created_at=coerce_iso(r.get("created_at", "")),
|
|
||||||
updated_at=coerce_iso(r.get("updated_at", "")),
|
|
||||||
metadata=r.get("metadata", {}),
|
metadata=r.get("metadata", {}),
|
||||||
values={"title": r["display_name"]} if r.get("display_name") else {},
|
values={"title": r["display_name"]} if r.get("display_name") else {},
|
||||||
interrupts={},
|
interrupts={},
|
||||||
@@ -368,8 +341,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
|
|||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status=record.get("status", "idle"),
|
status=record.get("status", "idle"),
|
||||||
created_at=coerce_iso(record.get("created_at", "")),
|
created_at=str(record.get("created_at", "")),
|
||||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
updated_at=str(record.get("updated_at", "")),
|
||||||
metadata=record.get("metadata", {}),
|
metadata=record.get("metadata", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -409,8 +382,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
record = {
|
record = {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"status": "idle",
|
"status": "idle",
|
||||||
"created_at": coerce_iso(ckpt_meta.get("created_at", "")),
|
"created_at": ckpt_meta.get("created_at", ""),
|
||||||
"updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
||||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,8 +397,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status=status,
|
status=status,
|
||||||
created_at=coerce_iso(record.get("created_at", "")),
|
created_at=str(record.get("created_at", "")),
|
||||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
updated_at=str(record.get("updated_at", "")),
|
||||||
metadata=record.get("metadata", {}),
|
metadata=record.get("metadata", {}),
|
||||||
values=serialize_channel_values(channel_values),
|
values=serialize_channel_values(channel_values),
|
||||||
)
|
)
|
||||||
@@ -476,10 +449,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
|||||||
values=values,
|
values=values,
|
||||||
next=next_tasks,
|
next=next_tasks,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
|
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||||
checkpoint_id=checkpoint_id,
|
checkpoint_id=checkpoint_id,
|
||||||
parent_checkpoint_id=parent_checkpoint_id,
|
parent_checkpoint_id=parent_checkpoint_id,
|
||||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
created_at=str(metadata.get("created_at", "")),
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -529,7 +502,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
channel_values.update(body.values)
|
channel_values.update(body.values)
|
||||||
|
|
||||||
checkpoint["channel_values"] = channel_values
|
checkpoint["channel_values"] = channel_values
|
||||||
metadata["updated_at"] = now_iso()
|
metadata["updated_at"] = time.time()
|
||||||
|
|
||||||
if body.as_node:
|
if body.as_node:
|
||||||
metadata["source"] = "update"
|
metadata["source"] = "update"
|
||||||
@@ -570,7 +543,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
next=[],
|
next=[],
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
checkpoint_id=new_checkpoint_id,
|
checkpoint_id=new_checkpoint_id,
|
||||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
created_at=str(metadata.get("created_at", "")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -637,7 +610,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
parent_checkpoint_id=parent_id,
|
parent_checkpoint_id=parent_id,
|
||||||
metadata=user_meta,
|
metadata=user_meta,
|
||||||
values=values,
|
values=values,
|
||||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
created_at=str(metadata.get("created_at", "")),
|
||||||
next=next_tasks,
|
next=next_tasks,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,26 +4,21 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import stat
|
import stat
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_config
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
UnsafeUploadPathError,
|
|
||||||
claim_unique_filename,
|
|
||||||
delete_file_safe,
|
delete_file_safe,
|
||||||
enrich_file_listing,
|
enrich_file_listing,
|
||||||
ensure_uploads_dir,
|
ensure_uploads_dir,
|
||||||
get_uploads_dir,
|
get_uploads_dir,
|
||||||
list_files_in_dir,
|
list_files_in_dir,
|
||||||
normalize_filename,
|
normalize_filename,
|
||||||
open_upload_file_no_symlink,
|
|
||||||
upload_artifact_url,
|
upload_artifact_url,
|
||||||
upload_virtual_path,
|
upload_virtual_path,
|
||||||
)
|
)
|
||||||
@@ -33,11 +28,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
||||||
|
|
||||||
UPLOAD_CHUNK_SIZE = 8192
|
|
||||||
DEFAULT_MAX_FILES = 10
|
|
||||||
DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
|
|
||||||
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
class UploadResponse(BaseModel):
|
class UploadResponse(BaseModel):
|
||||||
"""Response model for file upload."""
|
"""Response model for file upload."""
|
||||||
@@ -45,15 +35,6 @@ class UploadResponse(BaseModel):
|
|||||||
success: bool
|
success: bool
|
||||||
files: list[dict[str, str]]
|
files: list[dict[str, str]]
|
||||||
message: str
|
message: str
|
||||||
skipped_files: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class UploadLimits(BaseModel):
|
|
||||||
"""Application-level upload limits exposed to clients."""
|
|
||||||
|
|
||||||
max_files: int
|
|
||||||
max_file_size: int
|
|
||||||
max_total_size: int
|
|
||||||
|
|
||||||
|
|
||||||
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||||
@@ -74,188 +55,68 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
|
|
||||||
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
|
|
||||||
|
|
||||||
|
|
||||||
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
|
|
||||||
"""Read a value from the uploads config, supporting dict and attribute access."""
|
|
||||||
uploads_cfg = getattr(app_config, "uploads", None)
|
|
||||||
if isinstance(uploads_cfg, dict):
|
|
||||||
return uploads_cfg.get(key, default)
|
|
||||||
return getattr(uploads_cfg, key, default)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int:
|
|
||||||
try:
|
|
||||||
value = _get_uploads_config_value(app_config, key, None)
|
|
||||||
if value is None and legacy_key is not None:
|
|
||||||
value = _get_uploads_config_value(app_config, legacy_key, None)
|
|
||||||
if value is None:
|
|
||||||
value = default
|
|
||||||
limit = int(value)
|
|
||||||
if limit <= 0:
|
|
||||||
raise ValueError
|
|
||||||
return limit
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Invalid uploads.%s value; falling back to %d", key, default)
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def _get_upload_limits(app_config: AppConfig) -> UploadLimits:
|
|
||||||
return UploadLimits(
|
|
||||||
max_files=_get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count"),
|
|
||||||
max_file_size=_get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size"),
|
|
||||||
max_total_size=_get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
|
|
||||||
for path in reversed(paths):
|
|
||||||
try:
|
|
||||||
os.unlink(path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def _write_upload_file_with_limits(
|
|
||||||
file: UploadFile,
|
|
||||||
*,
|
|
||||||
uploads_dir: os.PathLike[str] | str,
|
|
||||||
display_filename: str,
|
|
||||||
max_single_file_size: int,
|
|
||||||
max_total_size: int,
|
|
||||||
total_size: int,
|
|
||||||
) -> tuple[os.PathLike[str] | str, int, int]:
|
|
||||||
file_size = 0
|
|
||||||
file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
|
|
||||||
try:
|
|
||||||
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
|
|
||||||
file_size += len(chunk)
|
|
||||||
total_size += len(chunk)
|
|
||||||
if file_size > max_single_file_size:
|
|
||||||
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
|
|
||||||
if total_size > max_total_size:
|
|
||||||
raise HTTPException(status_code=413, detail="Total upload size too large")
|
|
||||||
fh.write(chunk)
|
|
||||||
except Exception:
|
|
||||||
fh.close()
|
|
||||||
try:
|
|
||||||
os.unlink(file_path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
fh.close()
|
|
||||||
return file_path, file_size, total_size
|
|
||||||
|
|
||||||
|
|
||||||
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
|
|
||||||
"""Return whether automatic host-side document conversion is enabled.
|
|
||||||
|
|
||||||
The secure default is disabled unless an operator explicitly opts in via
|
|
||||||
uploads.auto_convert_documents in config.yaml.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
|
|
||||||
if isinstance(raw, str):
|
|
||||||
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
|
||||||
return bool(raw)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
files: list[UploadFile] = File(...),
|
files: list[UploadFile] = File(...),
|
||||||
config: AppConfig = Depends(get_config),
|
|
||||||
) -> UploadResponse:
|
) -> UploadResponse:
|
||||||
"""Upload multiple files to a thread's uploads directory."""
|
"""Upload multiple files to a thread's uploads directory."""
|
||||||
if not files:
|
if not files:
|
||||||
raise HTTPException(status_code=400, detail="No files provided")
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
limits = _get_upload_limits(config)
|
|
||||||
if len(files) > limits.max_files:
|
|
||||||
raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
uploaded_files = []
|
uploaded_files = []
|
||||||
written_paths = []
|
|
||||||
sandbox_sync_targets = []
|
|
||||||
skipped_files = []
|
|
||||||
total_size = 0
|
|
||||||
# Track filenames within this request so duplicate form parts do not
|
|
||||||
# silently truncate each other. Existing uploads keep the historical
|
|
||||||
# overwrite behavior for a single replacement upload.
|
|
||||||
seen_filenames: set[str] = set()
|
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
sandbox = None
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
if sync_to_sandbox:
|
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
|
||||||
if sandbox is None:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
|
|
||||||
auto_convert_documents = _auto_convert_documents_enabled(config)
|
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
original_filename = normalize_filename(file.filename)
|
safe_filename = normalize_filename(file.filename)
|
||||||
safe_filename = claim_unique_filename(original_filename, seen_filenames)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_path, file_size, total_size = await _write_upload_file_with_limits(
|
content = await file.read()
|
||||||
file,
|
file_path = uploads_dir / safe_filename
|
||||||
uploads_dir=uploads_dir,
|
file_path.write_bytes(content)
|
||||||
display_filename=safe_filename,
|
|
||||||
max_single_file_size=limits.max_file_size,
|
|
||||||
max_total_size=limits.max_total_size,
|
|
||||||
total_size=total_size,
|
|
||||||
)
|
|
||||||
written_paths.append(file_path)
|
|
||||||
|
|
||||||
virtual_path = upload_virtual_path(safe_filename)
|
virtual_path = upload_virtual_path(safe_filename)
|
||||||
|
|
||||||
if sync_to_sandbox:
|
if sandbox_id != "local":
|
||||||
sandbox_sync_targets.append((file_path, virtual_path))
|
_make_file_sandbox_writable(file_path)
|
||||||
|
sandbox.update_file(virtual_path, content)
|
||||||
|
|
||||||
file_info = {
|
file_info = {
|
||||||
"filename": safe_filename,
|
"filename": safe_filename,
|
||||||
"size": str(file_size),
|
"size": str(len(content)),
|
||||||
"path": str(sandbox_uploads / safe_filename),
|
"path": str(sandbox_uploads / safe_filename),
|
||||||
"virtual_path": virtual_path,
|
"virtual_path": virtual_path,
|
||||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||||
}
|
}
|
||||||
if safe_filename != original_filename:
|
|
||||||
file_info["original_filename"] = original_filename
|
|
||||||
|
|
||||||
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
|
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||||
|
|
||||||
file_ext = file_path.suffix.lower()
|
file_ext = file_path.suffix.lower()
|
||||||
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
|
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||||
md_path = await convert_file_to_markdown(file_path)
|
md_path = await convert_file_to_markdown(file_path)
|
||||||
if md_path:
|
if md_path:
|
||||||
written_paths.append(md_path)
|
|
||||||
md_virtual_path = upload_virtual_path(md_path.name)
|
md_virtual_path = upload_virtual_path(md_path.name)
|
||||||
|
|
||||||
if sync_to_sandbox:
|
if sandbox_id != "local":
|
||||||
sandbox_sync_targets.append((md_path, md_virtual_path))
|
_make_file_sandbox_writable(md_path)
|
||||||
|
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||||
|
|
||||||
file_info["markdown_file"] = md_path.name
|
file_info["markdown_file"] = md_path.name
|
||||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||||
@@ -264,46 +125,17 @@ async def upload_files(
|
|||||||
|
|
||||||
uploaded_files.append(file_info)
|
uploaded_files.append(file_info)
|
||||||
|
|
||||||
except HTTPException as e:
|
|
||||||
_cleanup_uploaded_paths(written_paths)
|
|
||||||
raise e
|
|
||||||
except UnsafeUploadPathError as e:
|
|
||||||
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
|
|
||||||
skipped_files.append(safe_filename)
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||||
_cleanup_uploaded_paths(written_paths)
|
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||||
|
|
||||||
if sync_to_sandbox:
|
|
||||||
for file_path, virtual_path in sandbox_sync_targets:
|
|
||||||
_make_file_sandbox_writable(file_path)
|
|
||||||
sandbox.update_file(virtual_path, file_path.read_bytes())
|
|
||||||
|
|
||||||
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
|
|
||||||
if skipped_files:
|
|
||||||
message += f"; skipped {len(skipped_files)} unsafe file(s)"
|
|
||||||
|
|
||||||
return UploadResponse(
|
return UploadResponse(
|
||||||
success=not skipped_files,
|
success=True,
|
||||||
files=uploaded_files,
|
files=uploaded_files,
|
||||||
message=message,
|
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
|
||||||
skipped_files=skipped_files,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/limits", response_model=UploadLimits)
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_upload_limits(
|
|
||||||
thread_id: str,
|
|
||||||
request: Request,
|
|
||||||
config: AppConfig = Depends(get_config),
|
|
||||||
) -> UploadLimits:
|
|
||||||
"""Return upload limits used by the gateway for this thread."""
|
|
||||||
return _get_upload_limits(config)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=dict)
|
@router.get("/list", response_model=dict)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||||
|
|||||||
+33
-111
@@ -8,18 +8,17 @@ frames, and consuming stream bridge events. Router modules
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
END_SENTINEL,
|
END_SENTINEL,
|
||||||
HEARTBEAT_SENTINEL,
|
HEARTBEAT_SENTINEL,
|
||||||
@@ -99,70 +98,13 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
|||||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
|
|
||||||
|
|
||||||
# Whitelist of run-context keys that the langgraph-compat layer forwards from
|
|
||||||
# ``body.context`` into the run config. ``config["context"]`` exists in
|
|
||||||
# LangGraph >=0.6, but these values must be written to both ``configurable``
|
|
||||||
# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
|
|
||||||
# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
|
|
||||||
# ``configurable`` for consumers like ``setup_agent``.
|
|
||||||
_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"model_name",
|
|
||||||
"mode",
|
|
||||||
"thinking_enabled",
|
|
||||||
"reasoning_effort",
|
|
||||||
"is_plan_mode",
|
|
||||||
"subagent_enabled",
|
|
||||||
"max_concurrent_subagents",
|
|
||||||
"agent_name",
|
|
||||||
"is_bootstrap",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
|
|
||||||
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
|
||||||
and ``config['context']`` so they are visible to legacy configurable readers and
|
|
||||||
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
|
||||||
see issue #2677)."""
|
|
||||||
if not context:
|
|
||||||
return
|
|
||||||
configurable = config.setdefault("configurable", {})
|
|
||||||
runtime_context = config.setdefault("context", {})
|
|
||||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
|
||||||
if key in context:
|
|
||||||
if isinstance(configurable, dict):
|
|
||||||
configurable.setdefault(key, context[key])
|
|
||||||
if isinstance(runtime_context, dict):
|
|
||||||
runtime_context.setdefault(key, context[key])
|
|
||||||
|
|
||||||
|
|
||||||
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
|
||||||
"""Stamp the authenticated user into the run context for background tools.
|
|
||||||
|
|
||||||
Tool execution may happen after the request handler has returned, so tools
|
|
||||||
that persist user-scoped files should not rely only on ambient ContextVars.
|
|
||||||
The value comes from server-side auth state, never from client context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = getattr(request.state, "user", None)
|
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
if user_id is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
runtime_context = config.setdefault("context", {})
|
|
||||||
if isinstance(runtime_context, dict):
|
|
||||||
runtime_context["user_id"] = str(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_agent_factory(assistant_id: str | None):
|
def resolve_agent_factory(assistant_id: str | None):
|
||||||
"""Resolve the agent factory callable from config.
|
"""Resolve the agent factory callable from config.
|
||||||
|
|
||||||
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
||||||
injected into ``configurable`` or ``context`` — see
|
injected into ``configurable`` — see :func:`build_run_config`. All
|
||||||
:func:`build_run_config`. All ``assistant_id`` values therefore map to the
|
``assistant_id`` values therefore map to the same factory; the routing
|
||||||
same factory; the routing happens inside ``make_lead_agent`` when it reads
|
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
|
||||||
``cfg["agent_name"]``.
|
|
||||||
"""
|
"""
|
||||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||||
|
|
||||||
@@ -179,12 +121,10 @@ def build_run_config(
|
|||||||
"""Build a RunnableConfig dict for the agent.
|
"""Build a RunnableConfig dict for the agent.
|
||||||
|
|
||||||
When *assistant_id* refers to a custom agent (anything other than
|
When *assistant_id* refers to a custom agent (anything other than
|
||||||
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
|
``"lead_agent"`` / ``None``), the name is forwarded as
|
||||||
whichever runtime options container is active: ``context`` for
|
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
|
||||||
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
|
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
||||||
``make_lead_agent`` reads this key to load the matching
|
without it the agent silently runs as the default lead agent.
|
||||||
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
|
|
||||||
silently runs as the default lead agent.
|
|
||||||
|
|
||||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||||
@@ -203,14 +143,7 @@ def build_run_config(
|
|||||||
thread_id,
|
thread_id,
|
||||||
list(request_config.get("configurable", {}).keys()),
|
list(request_config.get("configurable", {}).keys()),
|
||||||
)
|
)
|
||||||
context_value = request_config["context"]
|
config["context"] = request_config["context"]
|
||||||
if context_value is None:
|
|
||||||
context = {}
|
|
||||||
elif isinstance(context_value, Mapping):
|
|
||||||
context = dict(context_value)
|
|
||||||
else:
|
|
||||||
raise ValueError("request config 'context' must be a mapping or null.")
|
|
||||||
config["context"] = context
|
|
||||||
else:
|
else:
|
||||||
configurable = {"thread_id": thread_id}
|
configurable = {"thread_id": thread_id}
|
||||||
configurable.update(request_config.get("configurable", {}))
|
configurable.update(request_config.get("configurable", {}))
|
||||||
@@ -222,19 +155,13 @@ def build_run_config(
|
|||||||
config["configurable"] = {"thread_id": thread_id}
|
config["configurable"] = {"thread_id": thread_id}
|
||||||
|
|
||||||
# Inject custom agent name when the caller specified a non-default assistant.
|
# Inject custom agent name when the caller specified a non-default assistant.
|
||||||
# Honour an explicit agent_name in the active runtime options container.
|
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
|
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
||||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
if "agent_name" not in config["configurable"]:
|
||||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||||
if "configurable" in config:
|
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||||
target = config["configurable"]
|
config["configurable"]["agent_name"] = normalized
|
||||||
elif "context" in config:
|
|
||||||
target = config["context"]
|
|
||||||
else:
|
|
||||||
target = config.setdefault("configurable", {})
|
|
||||||
if target is not None and "agent_name" not in target:
|
|
||||||
target["agent_name"] = normalized
|
|
||||||
if metadata:
|
if metadata:
|
||||||
config.setdefault("metadata", {}).update(metadata)
|
config.setdefault("metadata", {}).update(metadata)
|
||||||
return config
|
return config
|
||||||
@@ -268,23 +195,6 @@ async def start_run(
|
|||||||
|
|
||||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||||
|
|
||||||
body_context = getattr(body, "context", None) or {}
|
|
||||||
model_name = body_context.get("model_name")
|
|
||||||
|
|
||||||
# Coerce non-string model_name values to str before truncation.
|
|
||||||
if model_name is not None and not isinstance(model_name, str):
|
|
||||||
model_name = str(model_name)
|
|
||||||
|
|
||||||
# Validate model against the allowlist when a model_name is provided.
|
|
||||||
if model_name:
|
|
||||||
app_config = get_app_config()
|
|
||||||
resolved = app_config.get_model_config(model_name)
|
|
||||||
if resolved is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -293,7 +203,6 @@ async def start_run(
|
|||||||
metadata=body.metadata or {},
|
metadata=body.metadata or {},
|
||||||
kwargs={"input": body.input, "config": body.config},
|
kwargs={"input": body.input, "config": body.config},
|
||||||
multitask_strategy=body.multitask_strategy,
|
multitask_strategy=body.multitask_strategy,
|
||||||
model_name=model_name,
|
|
||||||
)
|
)
|
||||||
except ConflictError as exc:
|
except ConflictError as exc:
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
@@ -320,12 +229,25 @@ async def start_run(
|
|||||||
graph_input = normalize_input(body.input)
|
graph_input = normalize_input(body.input)
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
# Merge DeerFlow-specific context overrides into configurable.
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
context = getattr(body, "context", None)
|
||||||
inject_authenticated_user_context(config, request)
|
if context:
|
||||||
|
_CONTEXT_CONFIGURABLE_KEYS = {
|
||||||
|
"model_name",
|
||||||
|
"mode",
|
||||||
|
"thinking_enabled",
|
||||||
|
"reasoning_effort",
|
||||||
|
"is_plan_mode",
|
||||||
|
"subagent_enabled",
|
||||||
|
"max_concurrent_subagents",
|
||||||
|
}
|
||||||
|
configurable = config.setdefault("configurable", {})
|
||||||
|
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||||
|
if key in context:
|
||||||
|
configurable.setdefault(key, context[key])
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
|
|
||||||
|
|||||||
+13
-90
@@ -19,72 +19,24 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
try:
|
from deerflow.agents import make_lead_agent
|
||||||
from prompt_toolkit import PromptSession
|
|
||||||
from prompt_toolkit.history import InMemoryHistory
|
|
||||||
|
|
||||||
_HAS_PROMPT_TOOLKIT = True
|
|
||||||
except ImportError:
|
|
||||||
_HAS_PROMPT_TOOLKIT = False
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
_LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
logging.basicConfig(
|
||||||
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
def _setup_logging(log_level: int = logging.INFO) -> None:
|
)
|
||||||
"""Route logs to ``debug.log`` using *log_level* for the initial root/file setup.
|
|
||||||
|
|
||||||
This configures the root logger and the ``debug.log`` file handler so logs do
|
|
||||||
not print on the interactive console. It is idempotent: any pre-existing
|
|
||||||
handlers on the root logger (e.g. installed by ``logging.basicConfig`` in
|
|
||||||
transitively imported modules) are removed so the debug session output only
|
|
||||||
lands in ``debug.log``.
|
|
||||||
|
|
||||||
Note: later config-driven logging adjustments may change named logger
|
|
||||||
verbosity without raising the root logger or file-handler thresholds set
|
|
||||||
here, so the eventual contents of ``debug.log`` may not be filtered solely by
|
|
||||||
this function's ``log_level`` argument.
|
|
||||||
"""
|
|
||||||
root = logging.root
|
|
||||||
for h in list(root.handlers):
|
|
||||||
root.removeHandler(h)
|
|
||||||
h.close()
|
|
||||||
root.setLevel(log_level)
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler("debug.log", mode="a", encoding="utf-8")
|
|
||||||
file_handler.setLevel(log_level)
|
|
||||||
file_handler.setFormatter(logging.Formatter(_LOG_FMT, datefmt=_LOG_DATEFMT))
|
|
||||||
root.addHandler(file_handler)
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Install file logging first so warnings emitted while loading config do not
|
|
||||||
# leak onto the interactive terminal via Python's lastResort handler.
|
|
||||||
_setup_logging()
|
|
||||||
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
from deerflow.config.app_config import apply_logging_level
|
|
||||||
|
|
||||||
app_config = get_app_config()
|
|
||||||
apply_logging_level(app_config.log_level)
|
|
||||||
|
|
||||||
# Delay the rest of the deerflow imports until *after* logging is installed
|
|
||||||
# so that any import-time side effects (e.g. deerflow.agents starts a
|
|
||||||
# background skill-loader thread on import) emit logs to debug.log instead
|
|
||||||
# of leaking onto the interactive terminal via Python's lastResort handler.
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
|
|
||||||
from deerflow.agents import make_lead_agent
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
from deerflow.mcp import initialize_mcp_tools
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
# Initialize MCP tools at startup
|
# Initialize MCP tools at startup
|
||||||
try:
|
try:
|
||||||
|
from deerflow.mcp import initialize_mcp_tools
|
||||||
|
|
||||||
await initialize_mcp_tools()
|
await initialize_mcp_tools()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to initialize MCP tools: {e}")
|
print(f"Warning: Failed to initialize MCP tools: {e}")
|
||||||
@@ -100,29 +52,16 @@ async def main():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime = Runtime(context={"thread_id": config["configurable"]["thread_id"]})
|
|
||||||
config["configurable"]["__pregel_runtime"] = runtime
|
|
||||||
|
|
||||||
agent = make_lead_agent(config)
|
agent = make_lead_agent(config)
|
||||||
|
|
||||||
session = PromptSession(history=InMemoryHistory()) if _HAS_PROMPT_TOOLKIT else None
|
|
||||||
|
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
print("Lead Agent Debug Mode")
|
print("Lead Agent Debug Mode")
|
||||||
print("Type 'quit' or 'exit' to stop")
|
print("Type 'quit' or 'exit' to stop")
|
||||||
print(f"Logs: debug.log (log_level={app_config.log_level})")
|
|
||||||
if not _HAS_PROMPT_TOOLKIT:
|
|
||||||
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
|
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
seen_artifacts: set[str] = set()
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if session:
|
user_input = input("\nYou: ").strip()
|
||||||
user_input = (await session.prompt_async("\nYou: ")).strip()
|
|
||||||
else:
|
|
||||||
user_input = input("\nYou: ").strip()
|
|
||||||
if not user_input:
|
if not user_input:
|
||||||
continue
|
continue
|
||||||
if user_input.lower() in ("quit", "exit"):
|
if user_input.lower() in ("quit", "exit"):
|
||||||
@@ -131,31 +70,15 @@ async def main():
|
|||||||
|
|
||||||
# Invoke the agent
|
# Invoke the agent
|
||||||
state = {"messages": [HumanMessage(content=user_input)]}
|
state = {"messages": [HumanMessage(content=user_input)]}
|
||||||
result = await agent.ainvoke(state, config=config)
|
result = await agent.ainvoke(state, config=config, context={"thread_id": "debug-thread-001"})
|
||||||
|
|
||||||
# Print the response
|
# Print the response
|
||||||
if result.get("messages"):
|
if result.get("messages"):
|
||||||
last_message = result["messages"][-1]
|
last_message = result["messages"][-1]
|
||||||
print(f"\nAgent: {last_message.content}")
|
print(f"\nAgent: {last_message.content}")
|
||||||
|
|
||||||
# Show files presented to the user this turn (new artifacts only)
|
except KeyboardInterrupt:
|
||||||
artifacts = result.get("artifacts") or []
|
print("\nInterrupted. Goodbye!")
|
||||||
new_artifacts = [p for p in artifacts if p not in seen_artifacts]
|
|
||||||
if new_artifacts:
|
|
||||||
thread_id = config["configurable"]["thread_id"]
|
|
||||||
user_id = get_effective_user_id()
|
|
||||||
paths = get_paths()
|
|
||||||
print("\n[Presented files]")
|
|
||||||
for virtual in new_artifacts:
|
|
||||||
try:
|
|
||||||
physical = paths.resolve_virtual_path(thread_id, virtual, user_id=user_id)
|
|
||||||
print(f" - {virtual}\n → {physical}")
|
|
||||||
except ValueError as exc:
|
|
||||||
print(f" - {virtual} (failed to resolve physical path: {exc})")
|
|
||||||
seen_artifacts.update(new_artifacts)
|
|
||||||
|
|
||||||
except (KeyboardInterrupt, EOFError):
|
|
||||||
print("\nGoodbye!")
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nError: {e}")
|
print(f"\nError: {e}")
|
||||||
|
|||||||
+35
-52
@@ -6,16 +6,16 @@ This document provides a complete reference for the DeerFlow backend APIs.
|
|||||||
|
|
||||||
DeerFlow backend exposes two sets of APIs:
|
DeerFlow backend exposes two sets of APIs:
|
||||||
|
|
||||||
1. **LangGraph-compatible API** - Agent interactions, threads, and streaming (`/api/langgraph/*`)
|
1. **LangGraph API** - Agent interactions, threads, and streaming (`/api/langgraph/*`)
|
||||||
2. **Gateway API** - Models, MCP, skills, uploads, and artifacts (`/api/*`)
|
2. **Gateway API** - Models, MCP, skills, uploads, and artifacts (`/api/*`)
|
||||||
|
|
||||||
All APIs are accessed through the Nginx reverse proxy at port 2026.
|
All APIs are accessed through the Nginx reverse proxy at port 2026.
|
||||||
|
|
||||||
## LangGraph-compatible API
|
## LangGraph API
|
||||||
|
|
||||||
Base URL: `/api/langgraph`
|
Base URL: `/api/langgraph`
|
||||||
|
|
||||||
The public LangGraph-compatible API follows LangGraph SDK conventions. In the unified nginx deployment, Gateway owns `/api/langgraph/*` and translates those paths to its native `/api/*` run, thread, and streaming routers.
|
The LangGraph API is provided by the LangGraph server and follows the LangGraph SDK conventions.
|
||||||
|
|
||||||
### Threads
|
### Threads
|
||||||
|
|
||||||
@@ -104,11 +104,17 @@ Content-Type: application/json
|
|||||||
**Recursion Limit:**
|
**Recursion Limit:**
|
||||||
|
|
||||||
`config.recursion_limit` caps the number of graph steps LangGraph will execute
|
`config.recursion_limit` caps the number of graph steps LangGraph will execute
|
||||||
in a single run. The unified Gateway path defaults to `100` in
|
in a single run. The `/api/langgraph/*` endpoints go straight to the LangGraph
|
||||||
`build_run_config` (see `backend/app/gateway/services.py`), which is a safer
|
server and therefore inherit LangGraph's native default of **25**, which is
|
||||||
starting point for plan-mode or subagent-heavy runs. Clients can still set
|
too low for plan-mode or subagent-heavy runs — the agent typically errors out
|
||||||
`recursion_limit` explicitly in the request body; increase it if you run deeply
|
with `GraphRecursionError` after the first round of subagent results comes
|
||||||
nested subagent graphs.
|
back, before the lead agent can synthesize the final answer.
|
||||||
|
|
||||||
|
DeerFlow's own Gateway and IM-channel paths mitigate this by defaulting to
|
||||||
|
`100` in `build_run_config` (see `backend/app/gateway/services.py`), but
|
||||||
|
clients calling the LangGraph API directly must set `recursion_limit`
|
||||||
|
explicitly in the request body. `100` matches the Gateway default and is a
|
||||||
|
safe starting point; increase it if you run deeply nested subagent graphs.
|
||||||
|
|
||||||
**Configurable Options:**
|
**Configurable Options:**
|
||||||
- `model_name` (string): Override the default model
|
- `model_name` (string): Override the default model
|
||||||
@@ -535,28 +541,14 @@ All APIs return errors in a consistent format:
|
|||||||
|
|
||||||
## Authentication
|
## Authentication
|
||||||
|
|
||||||
DeerFlow enforces authentication for all non-public HTTP routes. Public routes are limited to health/docs metadata and these public auth endpoints:
|
Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials.
|
||||||
|
|
||||||
- `POST /api/v1/auth/initialize` creates the first admin account when no admin exists.
|
Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers.
|
||||||
- `POST /api/v1/auth/login/local` logs in with email/password and sets an HttpOnly `access_token` cookie.
|
|
||||||
- `POST /api/v1/auth/register` creates a regular `user` account and sets the session cookie.
|
|
||||||
- `POST /api/v1/auth/logout` clears the session cookie.
|
|
||||||
- `GET /api/v1/auth/setup-status` reports whether the first admin still needs to be created.
|
|
||||||
|
|
||||||
The authenticated auth endpoints are:
|
For production deployments, it is recommended to:
|
||||||
|
1. Use Nginx for basic auth or OAuth integration
|
||||||
- `GET /api/v1/auth/me` returns the current user.
|
2. Deploy behind a VPN or private network
|
||||||
- `POST /api/v1/auth/change-password` changes password, optionally changes email during setup, increments `token_version`, and reissues the cookie.
|
3. Implement custom authentication middleware
|
||||||
|
|
||||||
Protected state-changing requests also require the CSRF double-submit token: send the `csrf_token` cookie value as the `X-CSRF-Token` header. Login/register/initialize/logout are bootstrap auth endpoints: they are exempt from the double-submit token but still reject hostile browser `Origin` headers.
|
|
||||||
|
|
||||||
User isolation is enforced from the authenticated user context:
|
|
||||||
|
|
||||||
- Thread metadata is scoped by `threads_meta.user_id`; search/read/write/delete APIs only expose the current user's threads.
|
|
||||||
- Thread files live under `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/` and are exposed inside the sandbox as `/mnt/user-data/`.
|
|
||||||
- Memory and custom agents are stored under `{base_dir}/users/{user_id}/...`.
|
|
||||||
|
|
||||||
Note: MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers; that is separate from DeerFlow API authentication.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -575,13 +567,12 @@ location /api/ {
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Streaming Support
|
## WebSocket Support
|
||||||
|
|
||||||
Gateway's LangGraph-compatible API streams run events with Server-Sent Events (SSE):
|
The LangGraph server supports WebSocket connections for real-time streaming. Connect to:
|
||||||
|
|
||||||
```http
|
```
|
||||||
POST /api/langgraph/threads/{thread_id}/runs/stream
|
ws://localhost:2026/api/langgraph/threads/{thread_id}/runs/stream
|
||||||
Accept: text/event-stream
|
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -617,21 +608,13 @@ const response = await fetch('/api/models');
|
|||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
console.log(data.models);
|
console.log(data.models);
|
||||||
|
|
||||||
// Create a run and stream SSE events
|
// Using EventSource for streaming
|
||||||
const streamResponse = await fetch(`/api/langgraph/threads/${threadId}/runs/stream`, {
|
const eventSource = new EventSource(
|
||||||
method: "POST",
|
`/api/langgraph/threads/${threadId}/runs/stream`
|
||||||
headers: {
|
);
|
||||||
"Content-Type": "application/json",
|
eventSource.onmessage = (event) => {
|
||||||
Accept: "text/event-stream",
|
console.log(JSON.parse(event.data));
|
||||||
},
|
};
|
||||||
body: JSON.stringify({
|
|
||||||
input: { messages: [{ role: "user", content: "Hello" }] },
|
|
||||||
stream_mode: ["values", "messages-tuple", "custom"],
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const reader = streamResponse.body?.getReader();
|
|
||||||
// Decode and parse SSE frames from reader in your client code.
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### cURL Examples
|
### cURL Examples
|
||||||
@@ -666,7 +649,7 @@ curl -X POST http://localhost:2026/api/langgraph/threads/abc123/runs \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
> The unified Gateway path defaults `config.recursion_limit` to 100 for
|
> The `/api/langgraph/*` endpoints bypass DeerFlow's Gateway and inherit
|
||||||
> plan-mode and subagent-heavy runs. Clients may still set
|
> LangGraph's native `recursion_limit` default of 25, which is too low for
|
||||||
> `config.recursion_limit` explicitly — see the [Create Run](#create-run)
|
> plan-mode or subagent runs. Set `config.recursion_limit` explicitly — see
|
||||||
> section for details.
|
> the [Create Run](#create-run) section for details.
|
||||||
|
|||||||
@@ -14,28 +14,30 @@ This document provides a comprehensive overview of the DeerFlow backend architec
|
|||||||
│ Nginx (Port 2026) │
|
│ Nginx (Port 2026) │
|
||||||
│ Unified Reverse Proxy Entry Point │
|
│ Unified Reverse Proxy Entry Point │
|
||||||
│ ┌────────────────────────────────────────────────────────────────────┐ │
|
│ ┌────────────────────────────────────────────────────────────────────┐ │
|
||||||
│ │ /api/langgraph/* → Gateway LangGraph-compatible runtime (8001) │ │
|
│ │ /api/langgraph/* → LangGraph Server (2024) │ │
|
||||||
│ │ /api/* → Gateway REST APIs (8001) │ │
|
│ │ /api/* → Gateway API (8001) │ │
|
||||||
│ │ /* → Frontend (3000) │ │
|
│ │ /* → Frontend (3000) │ │
|
||||||
│ └────────────────────────────────────────────────────────────────────┘ │
|
│ └────────────────────────────────────────────────────────────────────┘ │
|
||||||
└─────────────────────────────────┬────────────────────────────────────────┘
|
└─────────────────────────────────┬────────────────────────────────────────┘
|
||||||
│
|
│
|
||||||
┌───────────────────────┴───────────────────────┐
|
┌───────────────────────┼───────────────────────┐
|
||||||
│ │
|
│ │ │
|
||||||
▼ ▼
|
▼ ▼ ▼
|
||||||
┌─────────────────────────────────────────────┐ ┌─────────────────────┐
|
┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐
|
||||||
│ Gateway API │ │ Frontend │
|
│ LangGraph Server │ │ Gateway API │ │ Frontend │
|
||||||
│ (Port 8001) │ │ (Port 3000) │
|
│ (Port 2024) │ │ (Port 8001) │ │ (Port 3000) │
|
||||||
│ │ │ │
|
│ │ │ │ │ │
|
||||||
│ - LangGraph-compatible runs/threads API │ │ - Next.js App │
|
│ - Agent Runtime │ │ - Models API │ │ - Next.js App │
|
||||||
│ - Embedded Agent Runtime │ │ - React UI │
|
│ - Thread Mgmt │ │ - MCP Config │ │ - React UI │
|
||||||
│ - SSE Streaming │ │ - Chat Interface │
|
│ - SSE Streaming │ │ - Skills Mgmt │ │ - Chat Interface │
|
||||||
│ - Checkpointing │ │ │
|
│ - Checkpointing │ │ - File Uploads │ │ │
|
||||||
│ - Models, MCP, Skills, Uploads, Artifacts │ │ │
|
│ │ │ - Thread Cleanup │ │ │
|
||||||
│ - Thread Cleanup │ │ │
|
│ │ │ - Artifacts │ │ │
|
||||||
└─────────────────────────────────────────────┘ └─────────────────────┘
|
└─────────────────────┘ └─────────────────────┘ └─────────────────────┘
|
||||||
│
|
│ │
|
||||||
▼
|
│ ┌─────────────────┘
|
||||||
|
│ │
|
||||||
|
▼ ▼
|
||||||
┌──────────────────────────────────────────────────────────────────────────┐
|
┌──────────────────────────────────────────────────────────────────────────┐
|
||||||
│ Shared Configuration │
|
│ Shared Configuration │
|
||||||
│ ┌─────────────────────────┐ ┌────────────────────────────────────────┐ │
|
│ ┌─────────────────────────┐ ┌────────────────────────────────────────┐ │
|
||||||
@@ -50,9 +52,9 @@ This document provides a comprehensive overview of the DeerFlow backend architec
|
|||||||
|
|
||||||
## Component Details
|
## Component Details
|
||||||
|
|
||||||
### Gateway Embedded Agent Runtime
|
### LangGraph Server
|
||||||
|
|
||||||
The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for robust multi-agent workflow orchestration. Nginx rewrites `/api/langgraph/*` to Gateway's native `/api/*` routes, so the public API remains compatible with LangGraph SDK clients without running a separate LangGraph server.
|
The LangGraph server is the core agent runtime, built on LangGraph for robust multi-agent workflow orchestration.
|
||||||
|
|
||||||
**Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent`
|
**Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent`
|
||||||
|
|
||||||
@@ -63,8 +65,7 @@ The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for
|
|||||||
- Tool execution orchestration
|
- Tool execution orchestration
|
||||||
- SSE streaming for real-time responses
|
- SSE streaming for real-time responses
|
||||||
|
|
||||||
**Graph registry**: `langgraph.json` remains available for tooling, Studio, or direct LangGraph Server compatibility.
|
**Configuration**: `langgraph.json`
|
||||||
It is not the default service entrypoint; scripts and Docker deployments run the Gateway embedded runtime.
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -77,13 +78,12 @@ It is not the default service entrypoint; scripts and Docker deployments run the
|
|||||||
|
|
||||||
### Gateway API
|
### Gateway API
|
||||||
|
|
||||||
FastAPI application providing REST endpoints plus the public LangGraph-compatible `/api/langgraph/*` runtime routes.
|
FastAPI application providing REST endpoints for non-agent operations.
|
||||||
|
|
||||||
**Entry Point**: `app/gateway/app.py`
|
**Entry Point**: `app/gateway/app.py`
|
||||||
|
|
||||||
**Routers**:
|
**Routers**:
|
||||||
- `models.py` - `/api/models` - Model listing and details
|
- `models.py` - `/api/models` - Model listing and details
|
||||||
- `thread_runs.py` / `runs.py` - `/api/threads/{id}/runs`, `/api/runs/*` - LangGraph-compatible runs and streaming
|
|
||||||
- `mcp.py` - `/api/mcp` - MCP server configuration
|
- `mcp.py` - `/api/mcp` - MCP server configuration
|
||||||
- `skills.py` - `/api/skills` - Skills management
|
- `skills.py` - `/api/skills` - Skills management
|
||||||
- `uploads.py` - `/api/threads/{id}/uploads` - File upload
|
- `uploads.py` - `/api/threads/{id}/uploads` - File upload
|
||||||
@@ -91,7 +91,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl
|
|||||||
- `artifacts.py` - `/api/threads/{id}/artifacts` - Artifact serving
|
- `artifacts.py` - `/api/threads/{id}/artifacts` - Artifact serving
|
||||||
- `suggestions.py` - `/api/threads/{id}/suggestions` - Follow-up suggestion generation
|
- `suggestions.py` - `/api/threads/{id}/suggestions` - Follow-up suggestion generation
|
||||||
|
|
||||||
The web conversation delete flow first deletes Gateway-managed thread state through the LangGraph-compatible route, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`.
|
The web conversation delete flow is now split across both backend surfaces: LangGraph handles `DELETE /api/langgraph/threads/{thread_id}` for thread state, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`.
|
||||||
|
|
||||||
### Agent Architecture
|
### Agent Architecture
|
||||||
|
|
||||||
@@ -199,7 +199,7 @@ class ThreadState(AgentState):
|
|||||||
│ Built-in Tools │ │ Configured Tools │ │ MCP Tools │
|
│ Built-in Tools │ │ Configured Tools │ │ MCP Tools │
|
||||||
│ (packages/harness/deerflow/tools/) │ │ (config.yaml) │ │ (extensions.json) │
|
│ (packages/harness/deerflow/tools/) │ │ (config.yaml) │ │ (extensions.json) │
|
||||||
├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤
|
├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤
|
||||||
│ - present_files │ │ - web_search │ │ - github │
|
│ - present_file │ │ - web_search │ │ - github │
|
||||||
│ - ask_clarification │ │ - web_fetch │ │ - filesystem │
|
│ - ask_clarification │ │ - web_fetch │ │ - filesystem │
|
||||||
│ - view_image │ │ - bash │ │ - postgres │
|
│ - view_image │ │ - bash │ │ - postgres │
|
||||||
│ │ │ - read_file │ │ - brave-search │
|
│ │ │ - read_file │ │ - brave-search │
|
||||||
@@ -353,10 +353,10 @@ SKILL.md Format:
|
|||||||
POST /api/langgraph/threads/{thread_id}/runs
|
POST /api/langgraph/threads/{thread_id}/runs
|
||||||
{"input": {"messages": [{"role": "user", "content": "Hello"}]}}
|
{"input": {"messages": [{"role": "user", "content": "Hello"}]}}
|
||||||
|
|
||||||
2. Nginx → Gateway API (8001)
|
2. Nginx → LangGraph Server (2024)
|
||||||
`/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes
|
Proxied to LangGraph server
|
||||||
|
|
||||||
3. Gateway embedded runtime
|
3. LangGraph Server
|
||||||
a. Load/create thread state
|
a. Load/create thread state
|
||||||
b. Execute middleware chain:
|
b. Execute middleware chain:
|
||||||
- ThreadDataMiddleware: Set up paths
|
- ThreadDataMiddleware: Set up paths
|
||||||
@@ -412,7 +412,7 @@ SKILL.md Format:
|
|||||||
### Thread Cleanup Flow
|
### Thread Cleanup Flow
|
||||||
|
|
||||||
```
|
```
|
||||||
1. Client deletes conversation via the LangGraph-compatible Gateway route
|
1. Client deletes conversation via LangGraph
|
||||||
DELETE /api/langgraph/threads/{thread_id}
|
DELETE /api/langgraph/threads/{thread_id}
|
||||||
|
|
||||||
2. Web UI follows up with Gateway cleanup
|
2. Web UI follows up with Gateway cleanup
|
||||||
|
|||||||
@@ -1,331 +0,0 @@
|
|||||||
# 用户认证与隔离设计
|
|
||||||
|
|
||||||
本文档描述 DeerFlow 当前内置认证模块的设计,而不是历史 RFC。它覆盖浏览器登录、API 认证、CSRF、用户隔离、首次初始化、密码重置、内部调用和升级迁移。
|
|
||||||
|
|
||||||
## 设计目标
|
|
||||||
|
|
||||||
认证模块的核心目标是把 DeerFlow 从“本地单用户工具”提升为“可多用户部署的 agent runtime”,并让用户身份贯穿 HTTP API、LangGraph-compatible runtime、文件系统、memory、自定义 agent 和反馈数据。
|
|
||||||
|
|
||||||
设计约束:
|
|
||||||
|
|
||||||
- 默认强制认证:除健康检查、文档和 auth bootstrap 端点外,HTTP 路由都必须有有效 session。
|
|
||||||
- 服务端持有所有权:客户端 metadata 不能声明 `user_id` 或 `owner_id`。
|
|
||||||
- 隔离默认开启:repository(仓储)、文件路径、memory、agent 配置默认按当前用户解析。
|
|
||||||
- 旧数据可升级:无认证版本留下的 thread 可以在 admin 存在后迁移到 admin。
|
|
||||||
- 密码不进日志:首次初始化由操作者设置密码;`reset_admin` 只写 0600 凭据文件。
|
|
||||||
|
|
||||||
非目标:
|
|
||||||
|
|
||||||
- 当前 OAuth 端点只是占位,尚未实现第三方登录。
|
|
||||||
- 当前用户角色只有 `admin` 和 `user`,尚未实现细粒度 RBAC。
|
|
||||||
- 当前登录限速是进程内字典,多 worker 下不是全局精确限速。
|
|
||||||
|
|
||||||
## 核心模型
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
graph TB
|
|
||||||
classDef actor fill:#D8CFC4,stroke:#6E6259,color:#2F2A26;
|
|
||||||
classDef api fill:#C9D7D2,stroke:#5D706A,color:#21302C;
|
|
||||||
classDef state fill:#D7D3E8,stroke:#6B6680,color:#29263A;
|
|
||||||
classDef data fill:#E5D2C4,stroke:#806A5B,color:#30251E;
|
|
||||||
|
|
||||||
Browser["Browser — access_token cookie and csrf_token cookie"]:::actor
|
|
||||||
AuthMiddleware["AuthMiddleware — strict session gate"]:::api
|
|
||||||
CSRFMiddleware["CSRFMiddleware — double-submit token and Origin check"]:::api
|
|
||||||
AuthRoutes["Auth routes — initialize login register logout me change-password"]:::api
|
|
||||||
UserContext["Current user ContextVar — request-scoped identity"]:::state
|
|
||||||
Repositories["Repositories — AUTO resolves user_id from context"]:::state
|
|
||||||
Files["Filesystem — users/{user_id}/threads/{thread_id}/user-data"]:::data
|
|
||||||
Memory["Memory and agents — users/{user_id}/memory.json and agents"]:::data
|
|
||||||
|
|
||||||
Browser --> AuthMiddleware
|
|
||||||
Browser --> CSRFMiddleware
|
|
||||||
AuthMiddleware --> AuthRoutes
|
|
||||||
AuthMiddleware --> UserContext
|
|
||||||
UserContext --> Repositories
|
|
||||||
UserContext --> Files
|
|
||||||
UserContext --> Memory
|
|
||||||
```
|
|
||||||
|
|
||||||
### 用户表
|
|
||||||
|
|
||||||
用户记录定义在 `app.gateway.auth.models.User`,持久化到 `users` 表。关键字段:
|
|
||||||
|
|
||||||
| 字段 | 语义 |
|
|
||||||
|---|---|
|
|
||||||
| `id` | 用户主键,JWT `sub` 使用该值 |
|
|
||||||
| `email` | 唯一登录名 |
|
|
||||||
| `password_hash` | bcrypt hash,OAuth 用户可为空 |
|
|
||||||
| `system_role` | `admin` 或 `user` |
|
|
||||||
| `needs_setup` | reset 后要求用户完成邮箱 / 密码设置 |
|
|
||||||
| `token_version` | 改密码或 reset 时递增,用于废弃旧 JWT |
|
|
||||||
|
|
||||||
### 运行时身份
|
|
||||||
|
|
||||||
认证成功后,`AuthMiddleware` 把用户同时写入:
|
|
||||||
|
|
||||||
- `request.state.user`
|
|
||||||
- `request.state.auth`
|
|
||||||
- `deerflow.runtime.user_context` 的 `ContextVar`
|
|
||||||
|
|
||||||
`ContextVar` 是这里的核心边界。上层 Gateway 负责写入身份,下层 persistence / file path 只读取结构化的当前用户,不反向依赖 `app.gateway.auth` 具体类型。
|
|
||||||
|
|
||||||
可以把 repository 调用的用户参数理解成一个三态 ADT:
|
|
||||||
|
|
||||||
```scala
|
|
||||||
enum UserScope:
|
|
||||||
case AutoFromContext
|
|
||||||
case Explicit(userId: String)
|
|
||||||
case BypassForMigration
|
|
||||||
```
|
|
||||||
|
|
||||||
对应 Python 实现是 `AUTO | str | None`:
|
|
||||||
|
|
||||||
- `AUTO`:从 `ContextVar` 解析当前用户;没有上下文则抛错。
|
|
||||||
- `str`:显式指定用户,主要用于测试或管理脚本。
|
|
||||||
- `None`:跳过用户过滤,只允许迁移脚本或 admin CLI 使用。
|
|
||||||
|
|
||||||
## 登录与初始化流程
|
|
||||||
|
|
||||||
### 首次初始化
|
|
||||||
|
|
||||||
首次启动时,如果没有 admin,服务不会自动创建账号,只记录日志提示访问 `/setup`。
|
|
||||||
|
|
||||||
流程:
|
|
||||||
|
|
||||||
1. 用户访问 `/setup`。
|
|
||||||
2. 前端调用 `GET /api/v1/auth/setup-status`。
|
|
||||||
3. 如果返回 `{"needs_setup": true}`,前端展示创建 admin 表单。
|
|
||||||
4. 表单提交 `POST /api/v1/auth/initialize`。
|
|
||||||
5. 服务端确认当前没有 admin,创建 `system_role="admin"`、`needs_setup=false` 的用户。
|
|
||||||
6. 服务端设置 `access_token` HttpOnly cookie,用户进入 workspace。
|
|
||||||
|
|
||||||
`/api/v1/auth/initialize` 只在没有 admin 时可用。并发初始化由数据库唯一约束兜底,失败方返回 409。
|
|
||||||
|
|
||||||
### 普通登录
|
|
||||||
|
|
||||||
`POST /api/v1/auth/login/local` 使用 `OAuth2PasswordRequestForm`:
|
|
||||||
|
|
||||||
- `username` 是邮箱。
|
|
||||||
- `password` 是密码。
|
|
||||||
- 成功后签发 JWT,放入 `access_token` HttpOnly cookie。
|
|
||||||
- 响应体只返回 `expires_in` 和 `needs_setup`,不返回 token。
|
|
||||||
|
|
||||||
登录失败会按客户端 IP 计数。IP 解析只在 TCP peer 属于 `AUTH_TRUSTED_PROXIES` 时信任 `X-Real-IP`,不使用 `X-Forwarded-For`。
|
|
||||||
|
|
||||||
### 注册
|
|
||||||
|
|
||||||
`POST /api/v1/auth/register` 创建普通 `user`,并自动登录。
|
|
||||||
|
|
||||||
当前实现允许在没有 admin 时注册普通用户,但 `setup-status` 仍会返回 `needs_setup=true`,因为 admin 仍不存在。这是当前产品策略边界:如果后续要求“必须先初始化 admin 才能注册普通用户”,需要在 `/register` 增加 admin-exists gate。
|
|
||||||
|
|
||||||
### 改密码与 reset setup
|
|
||||||
|
|
||||||
`POST /api/v1/auth/change-password` 需要当前密码和新密码:
|
|
||||||
|
|
||||||
- 校验当前密码。
|
|
||||||
- 更新 bcrypt hash。
|
|
||||||
- `token_version += 1`,使旧 JWT 立即失效。
|
|
||||||
- 重新签发 cookie。
|
|
||||||
- 如果 `needs_setup=true` 且传了 `new_email`,则更新邮箱并清除 `needs_setup`。
|
|
||||||
|
|
||||||
`python -m app.gateway.auth.reset_admin` 会:
|
|
||||||
|
|
||||||
- 找到 admin 或指定邮箱用户。
|
|
||||||
- 生成随机密码。
|
|
||||||
- 更新密码 hash。
|
|
||||||
- `token_version += 1`。
|
|
||||||
- 设置 `needs_setup=true`。
|
|
||||||
- 写入 `.deer-flow/admin_initial_credentials.txt`,权限 `0600`。
|
|
||||||
|
|
||||||
命令行只输出凭据文件路径,不输出明文密码。
|
|
||||||
|
|
||||||
## HTTP 认证边界
|
|
||||||
|
|
||||||
`AuthMiddleware` 是 fail-closed(默认拒绝)的全局认证门。
|
|
||||||
|
|
||||||
公开路径:
|
|
||||||
|
|
||||||
- `/health`
|
|
||||||
- `/docs`
|
|
||||||
- `/redoc`
|
|
||||||
- `/openapi.json`
|
|
||||||
- `/api/v1/auth/login/local`
|
|
||||||
- `/api/v1/auth/register`
|
|
||||||
- `/api/v1/auth/logout`
|
|
||||||
- `/api/v1/auth/setup-status`
|
|
||||||
- `/api/v1/auth/initialize`
|
|
||||||
|
|
||||||
其余路径都要求有效 `access_token` cookie。存在 cookie 但 JWT 无效、过期、用户不存在或 `token_version` 不匹配时,直接返回 401,而不是让请求穿透到业务路由。
|
|
||||||
|
|
||||||
路由级别的 owner check 由 `require_permission(..., owner_check=True)` 完成:
|
|
||||||
|
|
||||||
- 读类请求允许旧的未追踪 legacy thread 兼容读取。
|
|
||||||
- 写 / 删除类请求使用 `require_existing=True`,要求 thread row 存在且属于当前用户,避免删除后缺 row 导致其他用户误通过。
|
|
||||||
|
|
||||||
## CSRF 设计
|
|
||||||
|
|
||||||
DeerFlow 使用 Double Submit Cookie:
|
|
||||||
|
|
||||||
- 服务端设置 `csrf_token` cookie。
|
|
||||||
- 前端 state-changing 请求发送同值 `X-CSRF-Token` header。
|
|
||||||
- 服务端用 `secrets.compare_digest` 比较 cookie/header。
|
|
||||||
|
|
||||||
需要 CSRF 的方法:
|
|
||||||
|
|
||||||
- `POST`
|
|
||||||
- `PUT`
|
|
||||||
- `DELETE`
|
|
||||||
- `PATCH`
|
|
||||||
|
|
||||||
auth bootstrap 端点(login/register/initialize/logout)不要求 double-submit token,因为首次调用时浏览器还没有 token;但这些端点会校验 browser `Origin`,拒绝 hostile Origin,避免 login CSRF / session fixation。
|
|
||||||
|
|
||||||
## 用户隔离
|
|
||||||
|
|
||||||
### Thread metadata
|
|
||||||
|
|
||||||
Thread metadata 存在 `threads_meta`,关键隔离字段是 `user_id`。
|
|
||||||
|
|
||||||
创建 thread 时:
|
|
||||||
|
|
||||||
- 客户端传入的 `metadata.user_id` 和 `metadata.owner_id` 会被剥离。
|
|
||||||
- `ThreadMetaRepository.create(..., user_id=AUTO)` 从 `ContextVar` 解析真实用户。
|
|
||||||
- `/api/threads/search` 默认只返回当前用户的 thread。
|
|
||||||
|
|
||||||
读取 / 修改 / 删除时:
|
|
||||||
|
|
||||||
- `get()` 默认按当前用户过滤。
|
|
||||||
- `check_access()` 用于路由 owner check。
|
|
||||||
- 对其他用户的 thread 返回 404,避免泄露资源存在性。
|
|
||||||
|
|
||||||
### 文件系统
|
|
||||||
|
|
||||||
当前线程文件布局:
|
|
||||||
|
|
||||||
```text
|
|
||||||
{base_dir}/users/{user_id}/threads/{thread_id}/user-data/
|
|
||||||
├── workspace/
|
|
||||||
├── uploads/
|
|
||||||
└── outputs/
|
|
||||||
```
|
|
||||||
|
|
||||||
agent 在 sandbox 内看到统一虚拟路径:
|
|
||||||
|
|
||||||
```text
|
|
||||||
/mnt/user-data/workspace
|
|
||||||
/mnt/user-data/uploads
|
|
||||||
/mnt/user-data/outputs
|
|
||||||
```
|
|
||||||
|
|
||||||
`ThreadDataMiddleware` 使用 `get_effective_user_id()` 解析当前用户并生成线程路径。没有认证上下文时会落到 `default` 用户桶,主要用于内部调用、嵌入式 client 或无 HTTP 的本地执行路径。
|
|
||||||
|
|
||||||
### Memory
|
|
||||||
|
|
||||||
默认 memory 存储:
|
|
||||||
|
|
||||||
```text
|
|
||||||
{base_dir}/users/{user_id}/memory.json
|
|
||||||
{base_dir}/users/{user_id}/agents/{agent_name}/memory.json
|
|
||||||
```
|
|
||||||
|
|
||||||
有用户上下文时,空或相对 `memory.storage_path` 都使用上述 per-user 默认路径;只有绝对 `memory.storage_path` 会视为显式 opt-out(退出) per-user isolation,所有用户共享该路径。无用户上下文的 legacy 路径仍会把相对 `storage_path` 解析到 `Paths.base_dir` 下。
|
|
||||||
|
|
||||||
### 自定义 agent
|
|
||||||
|
|
||||||
用户自定义 agent 写入:
|
|
||||||
|
|
||||||
```text
|
|
||||||
{base_dir}/users/{user_id}/agents/{agent_name}/
|
|
||||||
├── config.yaml
|
|
||||||
├── SOUL.md
|
|
||||||
└── memory.json
|
|
||||||
```
|
|
||||||
|
|
||||||
旧布局 `{base_dir}/agents/{agent_name}/` 只作为只读兼容回退。更新或删除旧共享 agent 会要求先运行迁移脚本。
|
|
||||||
|
|
||||||
## 内部调用与 IM 渠道
|
|
||||||
|
|
||||||
IM channel worker 不是浏览器用户,不持有浏览器 cookie。它们通过 Gateway 内部认证:
|
|
||||||
|
|
||||||
- 请求带 `X-DeerFlow-Internal-Token`。
|
|
||||||
- 同时带匹配的 CSRF cookie/header。
|
|
||||||
- 服务端识别为内部用户,`id="default"`、`system_role="internal"`。
|
|
||||||
|
|
||||||
这意味着 channel 产生的数据默认进入 `default` 用户桶。这个选择适合“平台级 bot 身份”,但不是“每个 IM 用户单独隔离”。如果后续要做到外部 IM 用户隔离,需要把外部 platform user 映射到 DeerFlow user,并让 channel manager 设置对应的 scoped identity。
|
|
||||||
|
|
||||||
## LangGraph-compatible 认证
|
|
||||||
|
|
||||||
Gateway 内嵌 runtime 路径由 `AuthMiddleware` 和 `CSRFMiddleware` 保护。
|
|
||||||
|
|
||||||
仓库仍保留 `app.gateway.langgraph_auth`,用于 LangGraph Server 直连模式:
|
|
||||||
|
|
||||||
- `@auth.authenticate` 校验 JWT cookie、CSRF、用户存在性和 `token_version`。
|
|
||||||
- `@auth.on` 在写入 metadata 时注入 `user_id`,并在读路径返回 `{"user_id": current_user}` 过滤条件。
|
|
||||||
|
|
||||||
这保证 Gateway 路由和 LangGraph-compatible 直连模式使用同一 JWT 语义。
|
|
||||||
|
|
||||||
## 升级与迁移
|
|
||||||
|
|
||||||
从无认证版本升级时,可能存在没有 `user_id` 的历史 thread。
|
|
||||||
|
|
||||||
当前策略:
|
|
||||||
|
|
||||||
1. 首次启动如果没有 admin,只提示访问 `/setup`,不迁移。
|
|
||||||
2. 操作者创建 admin。
|
|
||||||
3. 后续启动时,`_ensure_admin_user()` 找到 admin,并把 LangGraph store 中缺少 `metadata.user_id` 的 thread 迁移到 admin。
|
|
||||||
|
|
||||||
文件系统旧布局迁移由脚本处理:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
PYTHONPATH=. python scripts/migrate_user_isolation.py --dry-run
|
|
||||||
PYTHONPATH=. python scripts/migrate_user_isolation.py --user-id <target-user-id>
|
|
||||||
```
|
|
||||||
|
|
||||||
迁移脚本覆盖 legacy `memory.json`、`threads/` 和 `agents/` 到 per-user layout。
|
|
||||||
|
|
||||||
## 安全不变量
|
|
||||||
|
|
||||||
必须长期保持的不变量:
|
|
||||||
|
|
||||||
- JWT 只在 HttpOnly cookie 中传输,不出现在响应 JSON。
|
|
||||||
- 任何非 public HTTP 路由都不能只靠“cookie 存在”放行,必须严格验证 JWT。
|
|
||||||
- `token_version` 不匹配必须拒绝,保证改密码 / reset 后旧 session 失效。
|
|
||||||
- 客户端 metadata 中的 `user_id` / `owner_id` 必须剥离。
|
|
||||||
- repository 默认 `AUTO` 必须从当前用户上下文解析,不能静默退化成全局查询。
|
|
||||||
- 只有迁移脚本和 admin CLI 可以显式传 `user_id=None` 绕过隔离。
|
|
||||||
- 本地文件路径必须通过 `Paths` 和 sandbox path validation 解析,不能拼接未校验的用户输入。
|
|
||||||
- 捕获认证、迁移、后台任务异常必须记录日志;不能空 catch。
|
|
||||||
|
|
||||||
## 已知边界
|
|
||||||
|
|
||||||
| 边界 | 当前行为 | 后续方向 |
|
|
||||||
|---|---|---|
|
|
||||||
| 无 admin 时注册普通用户 | 允许注册普通 `user` | 如产品要求先初始化 admin,给 `/register` 加 gate |
|
|
||||||
| 登录限速 | 进程内 dict,单 worker 精确,多 worker 近似 | Redis / DB-backed rate limiter |
|
|
||||||
| OAuth | 端点占位,未实现 | 接入 provider 并统一 `token_version` / role 语义 |
|
|
||||||
| IM 用户隔离 | channel 使用 `default` 内部用户 | 建立外部用户到 DeerFlow user 的映射 |
|
|
||||||
| 绝对 memory path | 显式共享 memory | UI / docs 明确提示 opt-out 风险 |
|
|
||||||
|
|
||||||
## 相关文件
|
|
||||||
|
|
||||||
| 文件 | 职责 |
|
|
||||||
|---|---|
|
|
||||||
| `app/gateway/auth_middleware.py` | 全局认证门、JWT 严格验证、写入 user context |
|
|
||||||
| `app/gateway/csrf_middleware.py` | CSRF double-submit 和 auth Origin 校验 |
|
|
||||||
| `app/gateway/routers/auth.py` | initialize/login/register/logout/me/change-password |
|
|
||||||
| `app/gateway/auth/jwt.py` | JWT 创建与解析 |
|
|
||||||
| `app/gateway/auth/reset_admin.py` | 密码 reset CLI |
|
|
||||||
| `app/gateway/auth/credential_file.py` | 0600 凭据文件写入 |
|
|
||||||
| `app/gateway/authz.py` | 路由权限与 owner check |
|
|
||||||
| `deerflow/runtime/user_context.py` | 当前用户 ContextVar 与 `AUTO` sentinel |
|
|
||||||
| `deerflow/persistence/thread_meta/` | thread metadata owner filter |
|
|
||||||
| `deerflow/config/paths.py` | per-user filesystem layout |
|
|
||||||
| `deerflow/agents/middlewares/thread_data_middleware.py` | run 时解析用户线程目录 |
|
|
||||||
| `deerflow/agents/memory/storage.py` | per-user memory storage |
|
|
||||||
| `deerflow/config/agents_config.py` | per-user custom agents |
|
|
||||||
| `app/channels/manager.py` | IM channel 内部认证调用 |
|
|
||||||
| `scripts/migrate_user_isolation.py` | legacy 数据迁移到 per-user layout |
|
|
||||||
| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库,包含 users / threads_meta / runs / feedback 等表 |
|
|
||||||
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
|
|
||||||
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
|
|
||||||
@@ -24,11 +24,11 @@ All other test plan sections were executed against either:
|
|||||||
|
|
||||||
| Case | Title | What it covers | Why not run |
|
| Case | Title | What it covers | Why not run |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| TC-DOCKER-01 | `deerflow.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
|
| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
|
||||||
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
|
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
|
||||||
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
|
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
|
||||||
| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` |
|
| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` |
|
||||||
| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
|
| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
|
||||||
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
|
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
|
||||||
|
|
||||||
## Coverage already provided by non-Docker tests
|
## Coverage already provided by non-Docker tests
|
||||||
@@ -41,8 +41,8 @@ the test cases that ran on sg_dev or local:
|
|||||||
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
|
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
|
||||||
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
|
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
|
||||||
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
|
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
|
||||||
| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies |
|
| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP |
|
||||||
| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
|
| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
|
||||||
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
|
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
|
||||||
|
|
||||||
## Reproduction steps when Docker becomes available
|
## Reproduction steps when Docker becomes available
|
||||||
@@ -72,6 +72,6 @@ Then run TC-DOCKER-01..06 from the test plan as written.
|
|||||||
about *container packaging* details (bind mounts, multi-worker, log
|
about *container packaging* details (bind mounts, multi-worker, log
|
||||||
collection), not about whether the auth code paths work.
|
collection), not about whether the auth code paths work.
|
||||||
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
|
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
|
||||||
the current reset flow (`reset_admin` → 0600 credentials file, no log leak).
|
the post-simplify reality (credentials file → 0600 file, no log leak).
|
||||||
The old "grep 'Password:' in docker logs" expectation would have failed
|
The old "grep 'Password:' in docker logs" expectation would have failed
|
||||||
silently and given a false sense of coverage.
|
silently and given a false sense of coverage.
|
||||||
|
|||||||
+105
-149
@@ -19,7 +19,7 @@
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 清除已有数据
|
# 清除已有数据
|
||||||
rm -f backend/.deer-flow/data/deerflow.db
|
rm -f backend/.deer-flow/users.db
|
||||||
|
|
||||||
# 选择模式启动
|
# 选择模式启动
|
||||||
make dev # 标准模式
|
make dev # 标准模式
|
||||||
@@ -28,11 +28,10 @@ make dev-pro # Gateway 模式
|
|||||||
```
|
```
|
||||||
|
|
||||||
**验证点:**
|
**验证点:**
|
||||||
- [ ] 控制台不输出 admin 邮箱或明文密码
|
- [ ] 控制台输出 admin 邮箱和随机密码
|
||||||
- [ ] 控制台提示 `First boot detected — no admin account exists.`
|
- [ ] 密码格式为 `secrets.token_urlsafe(16)` 的 22 字符字符串
|
||||||
- [ ] 控制台提示访问 `/setup` 完成 admin 创建
|
- [ ] 邮箱为 `admin@deerflow.dev`
|
||||||
- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": true}`
|
- [ ] 提示 `Change it after login: Settings -> Account`
|
||||||
- [ ] 前端访问 `/login` 会跳转 `/setup`
|
|
||||||
|
|
||||||
### 1.2 非首次启动
|
### 1.2 非首次启动
|
||||||
|
|
||||||
@@ -43,8 +42,7 @@ make dev
|
|||||||
|
|
||||||
**验证点:**
|
**验证点:**
|
||||||
- [ ] 控制台不输出密码
|
- [ ] 控制台不输出密码
|
||||||
- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": false}`
|
- [ ] 如果 admin 仍 `needs_setup=True`,控制台有 warning 提示
|
||||||
- [ ] 已登录用户如果 `needs_setup=True`,访问 workspace 会被引导到 `/setup` 完成改邮箱 / 改密码流程
|
|
||||||
|
|
||||||
### 1.3 环境变量配置
|
### 1.3 环境变量配置
|
||||||
|
|
||||||
@@ -78,22 +76,19 @@ make dev
|
|||||||
curl -s $BASE/api/v1/auth/setup-status | jq .
|
curl -s $BASE/api/v1/auth/setup-status | jq .
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:** 返回 `{"needs_setup": false}`(admin 在启动时已自动创建,`count_users() > 0`)。仅在启动完成前的极短窗口内可能返回 `true`。
|
||||||
- 干净数据库且尚未初始化 admin:返回 `{"needs_setup": true}`
|
|
||||||
- 已存在 admin:返回 `{"needs_setup": false}`
|
|
||||||
|
|
||||||
#### TC-API-02: 首次初始化 Admin
|
#### TC-API-02: Admin 首次登录
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -s -X POST $BASE/api/v1/auth/initialize \
|
curl -s -X POST $BASE/api/v1/auth/login/local \
|
||||||
-H "Content-Type: application/json" \
|
-d "username=admin@deerflow.dev&password=<控制台密码>" \
|
||||||
-d '{"email":"admin@example.com","password":"AdminPass1!"}' \
|
|
||||||
-c cookies.txt | jq .
|
-c cookies.txt | jq .
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- 状态码 201
|
- 状态码 200
|
||||||
- Body: `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}`
|
- Body: `{"expires_in": 604800, "needs_setup": true}`
|
||||||
- `cookies.txt` 包含 `access_token`(HttpOnly)和 `csrf_token`(非 HttpOnly)
|
- `cookies.txt` 包含 `access_token`(HttpOnly)和 `csrf_token`(非 HttpOnly)
|
||||||
|
|
||||||
#### TC-API-03: 获取当前用户
|
#### TC-API-03: 获取当前用户
|
||||||
@@ -102,9 +97,9 @@ curl -s -X POST $BASE/api/v1/auth/initialize \
|
|||||||
curl -s $BASE/api/v1/auth/me -b cookies.txt | jq .
|
curl -s $BASE/api/v1/auth/me -b cookies.txt | jq .
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}`
|
**预期:** `{"id": "...", "email": "admin@deerflow.dev", "system_role": "admin", "needs_setup": true}`
|
||||||
|
|
||||||
#### TC-API-04: 改密码流程
|
#### TC-API-04: Setup 流程(改邮箱 + 改密码)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}')
|
CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}')
|
||||||
@@ -112,36 +107,13 @@ curl -s -X POST $BASE/api/v1/auth/change-password \
|
|||||||
-b cookies.txt \
|
-b cookies.txt \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-H "X-CSRF-Token: $CSRF" \
|
-H "X-CSRF-Token: $CSRF" \
|
||||||
-d '{"current_password":"AdminPass1!","new_password":"NewPass123!"}' | jq .
|
-d '{"current_password":"<控制台密码>","new_password":"NewPass123!","new_email":"admin@example.com"}' | jq .
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- 状态码 200
|
- 状态码 200
|
||||||
- `{"message": "Password changed successfully"}`
|
- `{"message": "Password changed successfully"}`
|
||||||
- 再调 `/auth/me` 仍为 `admin@example.com`,`needs_setup` 仍为 `false`
|
- 再调 `/auth/me` 邮箱变为 `admin@example.com`,`needs_setup` 变为 `false`
|
||||||
|
|
||||||
#### TC-API-04a: reset_admin 后的 Setup 流程(改邮箱 + 改密码)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
|
||||||
# 从 .deer-flow/admin_initial_credentials.txt 读取 reset 后密码
|
|
||||||
|
|
||||||
curl -s -X POST $BASE/api/v1/auth/login/local \
|
|
||||||
-d "username=admin@example.com&password=<凭据文件密码>" \
|
|
||||||
-c cookies.txt | jq .
|
|
||||||
|
|
||||||
CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}')
|
|
||||||
curl -s -X POST $BASE/api/v1/auth/change-password \
|
|
||||||
-b cookies.txt \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "X-CSRF-Token: $CSRF" \
|
|
||||||
-d '{"current_password":"<凭据文件密码>","new_password":"AdminPass2!","new_email":"admin2@example.com"}' | jq .
|
|
||||||
```
|
|
||||||
|
|
||||||
**预期:**
|
|
||||||
- 登录返回 `{"expires_in": 604800, "needs_setup": true}`
|
|
||||||
- `change-password` 后 `/auth/me` 邮箱变为 `admin2@example.com`,`needs_setup` 变为 `false`
|
|
||||||
|
|
||||||
#### TC-API-05: 普通用户注册
|
#### TC-API-05: 普通用户注册
|
||||||
|
|
||||||
@@ -521,7 +493,7 @@ curl -s -X POST $BASE/api/v1/auth/register \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 检查数据库
|
# 检查数据库
|
||||||
sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM users LIMIT 3;"
|
sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMIT 3;"
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** `password_hash` 以 `$2b$` 开头(bcrypt 格式)
|
**预期:** `password_hash` 以 `$2b$` 开头(bcrypt 格式)
|
||||||
@@ -534,25 +506,24 @@ sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM us
|
|||||||
|
|
||||||
### 4.1 首次登录流程
|
### 4.1 首次登录流程
|
||||||
|
|
||||||
#### TC-UI-01: 无 admin 时访问 workspace 跳转 setup
|
#### TC-UI-01: 访问首页跳转登录
|
||||||
|
|
||||||
1. 打开 `http://localhost:2026/workspace`
|
1. 打开 `http://localhost:2026/workspace`
|
||||||
2. **预期:** 自动跳转到 `/setup`
|
2. **预期:** 自动跳转到 `/login`
|
||||||
|
|
||||||
#### TC-UI-02: Setup 页面创建 admin
|
#### TC-UI-02: Login 页面
|
||||||
|
|
||||||
1. 输入 admin 邮箱、密码、确认密码
|
1. 输入 admin 邮箱和控制台密码
|
||||||
2. 点击 Create Admin Account
|
2. 点击 Login
|
||||||
|
3. **预期:** 跳转到 `/setup`(因为 `needs_setup=true`)
|
||||||
|
|
||||||
|
#### TC-UI-03: Setup 页面
|
||||||
|
|
||||||
|
1. 输入新邮箱、控制台密码(current)、新密码、确认密码
|
||||||
|
2. 点击 Complete Setup
|
||||||
3. **预期:** 跳转到 `/workspace`
|
3. **预期:** 跳转到 `/workspace`
|
||||||
4. 刷新页面不跳回 `/setup`
|
4. 刷新页面不跳回 `/setup`
|
||||||
|
|
||||||
#### TC-UI-03: 已初始化后 Login 页面
|
|
||||||
|
|
||||||
1. 退出登录后访问 `/login`
|
|
||||||
2. 输入 admin 邮箱和密码
|
|
||||||
3. 点击 Login
|
|
||||||
4. **预期:** 跳转到 `/workspace`
|
|
||||||
|
|
||||||
#### TC-UI-04: Setup 密码不匹配
|
#### TC-UI-04: Setup 密码不匹配
|
||||||
|
|
||||||
1. 新密码和确认密码不一致
|
1. 新密码和确认密码不一致
|
||||||
@@ -631,7 +602,7 @@ sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM us
|
|||||||
#### TC-UI-15: reset_admin 后重新登录
|
#### TC-UI-15: reset_admin 后重新登录
|
||||||
|
|
||||||
1. 执行 `cd backend && python -m app.gateway.auth.reset_admin`
|
1. 执行 `cd backend && python -m app.gateway.auth.reset_admin`
|
||||||
2. 从 `.deer-flow/admin_initial_credentials.txt` 读取新密码并登录
|
2. 使用新密码登录
|
||||||
3. **预期:** 跳转到 `/setup` 页面(`needs_setup` 被重置为 true)
|
3. **预期:** 跳转到 `/setup` 页面(`needs_setup` 被重置为 true)
|
||||||
4. 旧 session 已失效
|
4. 旧 session 已失效
|
||||||
|
|
||||||
@@ -674,28 +645,18 @@ make install
|
|||||||
make dev
|
make dev
|
||||||
```
|
```
|
||||||
|
|
||||||
#### TC-UPG-01: 首次启动等待 admin 初始化
|
#### TC-UPG-01: 首次启动创建 admin
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 控制台不输出 admin 邮箱或随机密码
|
- [ ] 控制台输出 admin 邮箱(`admin@deerflow.dev`)和随机密码
|
||||||
- [ ] 访问 `/setup` 可创建第一个 admin
|
|
||||||
- [ ] 无报错,正常启动
|
- [ ] 无报错,正常启动
|
||||||
|
|
||||||
#### TC-UPG-02: 旧 Thread 迁移到 admin
|
#### TC-UPG-02: 旧 Thread 迁移到 admin
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 创建第一个 admin
|
|
||||||
curl -s -X POST http://localhost:2026/api/v1/auth/initialize \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{"email":"admin@example.com","password":"AdminPass1!"}' \
|
|
||||||
-c cookies.txt
|
|
||||||
|
|
||||||
# 重启一次:启动迁移只在已有 admin 的启动路径执行
|
|
||||||
make stop && make dev
|
|
||||||
|
|
||||||
# 登录 admin
|
# 登录 admin
|
||||||
curl -s -X POST http://localhost:2026/api/v1/auth/login/local \
|
curl -s -X POST http://localhost:2026/api/v1/auth/login/local \
|
||||||
-d "username=admin@example.com&password=AdminPass1!" \
|
-d "username=admin@deerflow.dev&password=<控制台密码>" \
|
||||||
-c cookies.txt
|
-c cookies.txt
|
||||||
|
|
||||||
# 查看 thread 列表
|
# 查看 thread 列表
|
||||||
@@ -709,8 +670,8 @@ curl -s -X POST http://localhost:2026/api/threads/search \
|
|||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 返回的 thread 数量 ≥ 旧版创建的数量
|
- [ ] 返回的 thread 数量 ≥ 旧版创建的数量
|
||||||
- [ ] 控制台日志有 `Migrated N orphan LangGraph thread(s) to admin`
|
- [ ] 控制台日志有 `Migrated N orphaned thread(s) to admin`
|
||||||
- [ ] 旧 thread 只对 admin 可见
|
- [ ] 每个 thread 的 `metadata.owner_id` 都已被设为 admin 的 ID
|
||||||
|
|
||||||
#### TC-UPG-03: 旧 Thread 内容完整
|
#### TC-UPG-03: 旧 Thread 内容完整
|
||||||
|
|
||||||
@@ -722,7 +683,7 @@ curl -s http://localhost:2026/api/threads/<old-thread-id> \
|
|||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] `metadata.title` 保留原值(如 `old-thread-1`)
|
- [ ] `metadata.title` 保留原值(如 `old-thread-1`)
|
||||||
- [ ] 响应不回显服务端保留的 `user_id` / `owner_id`
|
- [ ] `metadata.owner_id` 已填充
|
||||||
|
|
||||||
#### TC-UPG-04: 新用户看不到旧 Thread
|
#### TC-UPG-04: 新用户看不到旧 Thread
|
||||||
|
|
||||||
@@ -745,19 +706,18 @@ curl -s -X POST http://localhost:2026/api/threads/search \
|
|||||||
|
|
||||||
### 5.3 数据库 Schema 兼容
|
### 5.3 数据库 Schema 兼容
|
||||||
|
|
||||||
#### TC-UPG-05: 无 deerflow.db 时创建 schema 但不创建默认用户
|
#### TC-UPG-05: 无 users.db 时自动创建
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ls -la backend/.deer-flow/data/deerflow.db
|
ls -la backend/.deer-flow/users.db
|
||||||
sqlite3 backend/.deer-flow/data/deerflow.db "SELECT COUNT(*) FROM users;"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列;未调用 `/initialize` 前用户数为 0
|
**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列
|
||||||
|
|
||||||
#### TC-UPG-06: deerflow.db WAL 模式
|
#### TC-UPG-06: users.db WAL 模式
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sqlite3 backend/.deer-flow/data/deerflow.db "PRAGMA journal_mode;"
|
sqlite3 backend/.deer-flow/users.db "PRAGMA journal_mode;"
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** 返回 `wal`
|
**预期:** 返回 `wal`
|
||||||
@@ -808,9 +768,9 @@ make dev
|
|||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 服务正常启动(忽略 `deerflow.db`,无 auth 相关代码不报错)
|
- [ ] 服务正常启动(忽略 `users.db`,无 auth 相关代码不报错)
|
||||||
- [ ] 旧对话数据仍然可访问
|
- [ ] 旧对话数据仍然可访问
|
||||||
- [ ] `deerflow.db` 文件残留但不影响运行
|
- [ ] `users.db` 文件残留但不影响运行
|
||||||
|
|
||||||
#### TC-UPG-12: 再次升级到 auth 分支
|
#### TC-UPG-12: 再次升级到 auth 分支
|
||||||
|
|
||||||
@@ -821,47 +781,51 @@ make dev
|
|||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 识别已有 `deerflow.db`,不重新创建 admin
|
- [ ] 识别已有 `users.db`,不重新创建 admin
|
||||||
- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `deerflow.db`)
|
- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `users.db`)
|
||||||
|
|
||||||
### 5.7 Admin 初始化与 reset_admin
|
### 5.7 休眠 Admin(初始密码未使用/未更改)
|
||||||
|
|
||||||
> 首次启动不生成默认 admin,也不在日志输出密码。忘记密码时走 `reset_admin`,新密码写入 0600 凭据文件。
|
> 首次启动生成 admin + 随机密码,但运维未登录、未改密码。
|
||||||
|
> 密码只在首次启动的控制台闪过一次,后续启动不再显示。
|
||||||
|
|
||||||
#### TC-UPG-13: 未初始化 admin 时重启不创建默认账号
|
#### TC-UPG-13: 重启后自动重置密码并打印
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
rm -f backend/.deer-flow/data/deerflow.db
|
# 首次启动,记录密码
|
||||||
|
rm -f backend/.deer-flow/users.db
|
||||||
make dev
|
make dev
|
||||||
|
# 控制台输出密码 P0,不登录
|
||||||
make stop
|
make stop
|
||||||
|
|
||||||
|
# 隔了几天,再次启动
|
||||||
make dev
|
make dev
|
||||||
curl -s $BASE/api/v1/auth/setup-status | jq .
|
# 控制台输出新密码 P1
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 控制台不输出密码
|
- [ ] 控制台输出 `Admin account setup incomplete — password reset`
|
||||||
- [ ] `setup-status` 仍为 `{"needs_setup": true}`
|
- [ ] 输出新密码 P1(P0 已失效)
|
||||||
- [ ] 访问 `/setup` 仍可创建第一个 admin
|
- [ ] 用 P1 可以登录,P0 不可以
|
||||||
|
- [ ] 登录后 `needs_setup=true`,跳转 `/setup`
|
||||||
|
- [ ] `token_version` 递增(旧 session 如有也失效)
|
||||||
|
|
||||||
#### TC-UPG-14: 密码丢失 — reset_admin 写入凭据文件
|
#### TC-UPG-14: 密码丢失 — 无需 CLI,重启即可
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
# 忘记了控制台密码 → 直接重启服务
|
||||||
ls -la backend/.deer-flow/admin_initial_credentials.txt
|
make stop && make dev
|
||||||
cat backend/.deer-flow/admin_initial_credentials.txt
|
# 控制台自动输出新密码
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 命令行只输出凭据文件路径,不输出明文密码
|
- [ ] 无需 `reset_admin`,重启服务即可拿到新密码
|
||||||
- [ ] 凭据文件权限为 `0600`
|
- [ ] `reset_admin` CLI 仍然可用作手动备选方案
|
||||||
- [ ] 凭据文件包含 email + password 行
|
|
||||||
- [ ] 该用户下次登录返回 `needs_setup=true`
|
|
||||||
|
|
||||||
#### TC-UPG-15: 未初始化 admin 期间普通用户注册策略边界
|
#### TC-UPG-15: 休眠 admin 期间普通用户注册
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# admin 尚不存在,普通用户尝试注册
|
# admin 存在但从未登录,普通用户先注册
|
||||||
curl -s -X POST $BASE/api/v1/auth/register \
|
curl -s -X POST $BASE/api/v1/auth/register \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{"email":"earlybird@example.com","password":"EarlyPass1!"}' \
|
-d '{"email":"earlybird@example.com","password":"EarlyPass1!"}' \
|
||||||
@@ -869,11 +833,11 @@ curl -s -X POST $BASE/api/v1/auth/register \
|
|||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 当前代码允许注册普通用户并自动登录(201,角色为 `user`)
|
- [ ] 注册成功(201),角色为 `user`
|
||||||
- [ ] 但 `setup-status` 仍为 `{"needs_setup": true}`,因为 admin 仍不存在
|
- [ ] 无法提权为 admin
|
||||||
- [ ] 这是一个产品策略边界:若要求“必须先有 admin”,需要在 `/register` 增加 admin-exists gate
|
- [ ] 普通用户的数据与 admin 隔离
|
||||||
|
|
||||||
#### TC-UPG-16: 普通用户数据与后续 admin 隔离
|
#### TC-UPG-16: 休眠 admin 不影响后续操作
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 普通用户正常创建 thread、发消息
|
# 普通用户正常创建 thread、发消息
|
||||||
@@ -885,13 +849,14 @@ curl -s -X POST $BASE/api/threads \
|
|||||||
-d '{"metadata":{}}' | jq .thread_id
|
-d '{"metadata":{}}' | jq .thread_id
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** 普通用户正常创建 thread;后续 admin 创建后,搜索不到该普通用户 thread
|
**预期:** 正常创建,不受休眠 admin 影响
|
||||||
|
|
||||||
#### TC-UPG-17: reset_admin 后完成 Setup
|
#### TC-UPG-17: 休眠 admin 最终完成 Setup
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# 运维终于登录
|
||||||
curl -s -X POST $BASE/api/v1/auth/login/local \
|
curl -s -X POST $BASE/api/v1/auth/login/local \
|
||||||
-d "username=admin@example.com&password=<凭据文件密码>" \
|
-d "username=admin@deerflow.dev&password=<P0或P1>" \
|
||||||
-c admin.txt | jq .needs_setup
|
-c admin.txt | jq .needs_setup
|
||||||
# 预期: true
|
# 预期: true
|
||||||
|
|
||||||
@@ -901,7 +866,7 @@ curl -s -X POST $BASE/api/v1/auth/change-password \
|
|||||||
-b admin.txt \
|
-b admin.txt \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-H "X-CSRF-Token: $CSRF" \
|
-H "X-CSRF-Token: $CSRF" \
|
||||||
-d '{"current_password":"<凭据文件密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \
|
-d '{"current_password":"<密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \
|
||||||
-c admin.txt
|
-c admin.txt
|
||||||
|
|
||||||
# 验证
|
# 验证
|
||||||
@@ -911,7 +876,7 @@ curl -s $BASE/api/v1/auth/me -b admin.txt | jq '{email, needs_setup}'
|
|||||||
**预期:**
|
**预期:**
|
||||||
- [ ] `email` 变为 `admin@real.com`
|
- [ ] `email` 变为 `admin@real.com`
|
||||||
- [ ] `needs_setup` 变为 `false`
|
- [ ] `needs_setup` 变为 `false`
|
||||||
- [ ] 后续登录使用新密码
|
- [ ] 后续重启控制台不再有 warning
|
||||||
|
|
||||||
#### TC-UPG-18: 长期未用后 JWT 密钥轮换
|
#### TC-UPG-18: 长期未用后 JWT 密钥轮换
|
||||||
|
|
||||||
@@ -925,8 +890,8 @@ make stop && make dev
|
|||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] 服务正常启动
|
- [ ] 服务正常启动
|
||||||
- [ ] 账号密码仍可登录(密码存在 DB,与 JWT 密钥无关)
|
- [ ] 旧密码仍可登录(密码存在 DB,与 JWT 密钥无关)
|
||||||
- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)
|
- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)— 但因为从未登录过也没有旧 token
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -945,7 +910,7 @@ for i in 1 2 3; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
# 检查 admin 数量
|
# 检查 admin 数量
|
||||||
sqlite3 backend/.deer-flow/data/deerflow.db \
|
sqlite3 backend/.deer-flow/users.db \
|
||||||
"SELECT COUNT(*) FROM users WHERE system_role='admin';"
|
"SELECT COUNT(*) FROM users WHERE system_role='admin';"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1090,7 +1055,7 @@ curl -s -X POST $BASE/api/v1/auth/register \
|
|||||||
wait
|
wait
|
||||||
|
|
||||||
# 检查用户数
|
# 检查用户数
|
||||||
sqlite3 backend/.deer-flow/data/deerflow.db \
|
sqlite3 backend/.deer-flow/users.db \
|
||||||
"SELECT COUNT(*) FROM users WHERE email='race@example.com';"
|
"SELECT COUNT(*) FROM users WHERE email='race@example.com';"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1200,16 +1165,13 @@ curl -s -w "%{http_code}" -X DELETE "$BASE/api/threads/$TID" \
|
|||||||
```bash
|
```bash
|
||||||
cd backend
|
cd backend
|
||||||
python -m app.gateway.auth.reset_admin
|
python -m app.gateway.auth.reset_admin
|
||||||
cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p1.txt
|
# 记录密码 P1
|
||||||
P1=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p1.txt)
|
|
||||||
|
|
||||||
python -m app.gateway.auth.reset_admin
|
python -m app.gateway.auth.reset_admin
|
||||||
cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p2.txt
|
# 记录密码 P2
|
||||||
P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:**
|
**预期:**
|
||||||
- [ ] `.deer-flow/admin_initial_credentials.txt` 每次都会被重写,文件权限为 `0600`
|
|
||||||
- [ ] P1 ≠ P2(每次生成新随机密码)
|
- [ ] P1 ≠ P2(每次生成新随机密码)
|
||||||
- [ ] P1 不可用,只有 P2 有效
|
- [ ] P1 不可用,只有 P2 有效
|
||||||
- [ ] `token_version` 递增了 2
|
- [ ] `token_version` 递增了 2
|
||||||
@@ -1362,8 +1324,7 @@ done
|
|||||||
```bash
|
```bash
|
||||||
GW=http://localhost:8001
|
GW=http://localhost:8001
|
||||||
|
|
||||||
for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local \
|
for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local /api/v1/auth/register; do
|
||||||
/api/v1/auth/register /api/v1/auth/initialize /api/v1/auth/logout; do
|
|
||||||
echo "$path: $(curl -s -w '%{http_code}' -o /dev/null $GW$path)"
|
echo "$path: $(curl -s -w '%{http_code}' -o /dev/null $GW$path)"
|
||||||
done
|
done
|
||||||
# 预期: 200 或 405/422(方法不对但不是 401)
|
# 预期: 200 或 405/422(方法不对但不是 401)
|
||||||
@@ -1438,9 +1399,9 @@ done
|
|||||||
>
|
>
|
||||||
> 前置条件:
|
> 前置条件:
|
||||||
> - `.env` 中设置 `AUTH_JWT_SECRET`(否则每次容器重启 session 全部失效)
|
> - `.env` 中设置 `AUTH_JWT_SECRET`(否则每次容器重启 session 全部失效)
|
||||||
> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `deerflow.db`)
|
> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `users.db`)
|
||||||
|
|
||||||
#### TC-DOCKER-01: deerflow.db 通过 volume 持久化
|
#### TC-DOCKER-01: users.db 通过 volume 持久化
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 启动容器
|
# 启动容器
|
||||||
@@ -1455,13 +1416,13 @@ curl -s -X POST $BASE/api/v1/auth/register \
|
|||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{"email":"docker-test@example.com","password":"DockerTest1!"}' -w "\nHTTP %{http_code}"
|
-d '{"email":"docker-test@example.com","password":"DockerTest1!"}' -w "\nHTTP %{http_code}"
|
||||||
|
|
||||||
# 检查宿主机上的 deerflow.db
|
# 检查宿主机上的 users.db
|
||||||
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db
|
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db
|
||||||
sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db \
|
sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db \
|
||||||
"SELECT email FROM users WHERE email='docker-test@example.com';"
|
"SELECT email FROM users WHERE email='docker-test@example.com';"
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** deerflow.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。
|
**预期:** users.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。
|
||||||
|
|
||||||
#### TC-DOCKER-02: 重启容器后 session 保持
|
#### TC-DOCKER-02: 重启容器后 session 保持
|
||||||
|
|
||||||
@@ -1505,24 +1466,22 @@ done
|
|||||||
|
|
||||||
**已知限制:** In-process rate limiter 不跨 worker 共享。生产环境如需精确限速,需要 Redis 等外部存储。
|
**已知限制:** In-process rate limiter 不跨 worker 共享。生产环境如需精确限速,需要 Redis 等外部存储。
|
||||||
|
|
||||||
#### TC-DOCKER-04: IM 渠道使用内部认证
|
#### TC-DOCKER-04: IM 渠道不经过 auth
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 调 Gateway
|
# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 通信
|
||||||
# 请求携带 process-local internal auth header,并带匹配的 CSRF cookie/header
|
# 不走 nginx,不经过 AuthMiddleware
|
||||||
|
|
||||||
# 验证方式:检查 gateway 日志中 channel manager 的请求不包含 auth 错误
|
# 验证方式:检查 gateway 日志中 channel manager 的请求不包含 auth 错误
|
||||||
docker logs deer-flow-gateway 2>&1 | grep -E "ChannelManager|channel" | head -10
|
docker logs deer-flow-gateway 2>&1 | grep -E "ChannelManager|channel" | head -10
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** 无 auth 相关错误。渠道不依赖浏览器 cookie;服务端通过内部认证头把请求归入 `default` 用户桶。
|
**预期:** 无 auth 相关错误。渠道通过 `langgraph-sdk` 直连 LangGraph Server(`http://langgraph:2024`),不走 auth 层。
|
||||||
|
|
||||||
#### TC-DOCKER-05: reset_admin 密码写入 0600 凭证文件(不再走日志)
|
#### TC-DOCKER-05: admin 密码写入 0600 凭证文件(不再走日志)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 首次启动不会自动生成 admin 密码。先重置已有 admin,凭据文件写在挂载到宿主机的 DEER_FLOW_HOME 下。
|
# 凭证文件写在挂载到宿主机的 DEER_FLOW_HOME 下
|
||||||
docker exec deer-flow-gateway python -m app.gateway.auth.reset_admin --email docker-test@example.com
|
|
||||||
|
|
||||||
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/admin_initial_credentials.txt
|
ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/admin_initial_credentials.txt
|
||||||
# 预期文件权限: -rw------- (0600)
|
# 预期文件权限: -rw------- (0600)
|
||||||
|
|
||||||
@@ -1553,15 +1512,14 @@ sleep 15
|
|||||||
docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l
|
docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l
|
||||||
# 预期: 0
|
# 预期: 0
|
||||||
|
|
||||||
# auth 流程正常:未登录受保护接口返回 401
|
# auth 流程正常
|
||||||
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
|
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
|
||||||
# 预期: 401
|
# 预期: 401
|
||||||
|
|
||||||
curl -s -X POST $BASE/api/v1/auth/initialize \
|
curl -s -X POST $BASE/api/v1/auth/login/local \
|
||||||
-H "Content-Type: application/json" \
|
-d "username=admin@deerflow.dev&password=<日志密码>" \
|
||||||
-d '{"email":"admin@example.com","password":"AdminPass1!"}' \
|
|
||||||
-c cookies.txt -w "\nHTTP %{http_code}"
|
-c cookies.txt -w "\nHTTP %{http_code}"
|
||||||
# 预期: 201
|
# 预期: 200
|
||||||
```
|
```
|
||||||
|
|
||||||
### 7.4 补充边界用例
|
### 7.4 补充边界用例
|
||||||
@@ -1629,15 +1587,13 @@ curl -s -D - -X POST $BASE/api/v1/auth/login/local \
|
|||||||
#### TC-EDGE-05: HTTP 无 max_age / HTTPS 有 max_age
|
#### TC-EDGE-05: HTTP 无 max_age / HTTPS 有 max_age
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
GW=http://localhost:8001
|
|
||||||
|
|
||||||
# HTTP
|
# HTTP
|
||||||
curl -s -D - -X POST $GW/api/v1/auth/login/local \
|
curl -s -D - -X POST $BASE/api/v1/auth/login/local \
|
||||||
-d "username=admin@example.com&password=正确密码" 2>/dev/null \
|
-d "username=admin@example.com&password=正确密码" 2>/dev/null \
|
||||||
| grep "access_token=" | grep -oi "max-age=[0-9]*" || echo "NO max-age (HTTP session cookie)"
|
| grep "access_token=" | grep -oi "max-age=[0-9]*" || echo "NO max-age (HTTP session cookie)"
|
||||||
|
|
||||||
# HTTPS:直连 Gateway 才能用 X-Forwarded-Proto 模拟 HTTPS;nginx 会覆盖该 header
|
# HTTPS
|
||||||
curl -s -D - -X POST $GW/api/v1/auth/login/local \
|
curl -s -D - -X POST $BASE/api/v1/auth/login/local \
|
||||||
-H "X-Forwarded-Proto: https" \
|
-H "X-Forwarded-Proto: https" \
|
||||||
-d "username=admin@example.com&password=正确密码" 2>/dev/null \
|
-d "username=admin@example.com&password=正确密码" 2>/dev/null \
|
||||||
| grep "access_token=" | grep -oi "max-age=[0-9]*"
|
| grep "access_token=" | grep -oi "max-age=[0-9]*"
|
||||||
@@ -1756,10 +1712,10 @@ curl -s -X POST $BASE/api/threads \
|
|||||||
-b cookies.txt \
|
-b cookies.txt \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-H "X-CSRF-Token: $CSRF" \
|
-H "X-CSRF-Token: $CSRF" \
|
||||||
-d '{"metadata":{"owner_id":"victim-user-id","user_id":"victim-user-id"}}' | jq .metadata
|
-d '{"metadata":{"owner_id":"victim-user-id"}}' | jq .metadata.owner_id
|
||||||
```
|
```
|
||||||
|
|
||||||
**预期:** 返回的 `metadata` 不包含 `owner_id` 或 `user_id`。真实所有权写入 `threads_meta.user_id`,不从客户端 metadata 接收,也不通过 metadata 回显。
|
**预期:** 返回的 `metadata.owner_id` 应为当前登录用户的 ID,不是请求中注入的 `victim-user-id`。服务端应覆盖客户端提供的 `user_id`。
|
||||||
|
|
||||||
#### 7.5.6 HTTP Method 探测
|
#### 7.5.6 HTTP Method 探测
|
||||||
|
|
||||||
@@ -1840,6 +1796,6 @@ cd backend && PYTHONPATH=. uv run pytest \
|
|||||||
# 核心接口冒烟
|
# 核心接口冒烟
|
||||||
curl -s $BASE/health # 200
|
curl -s $BASE/health # 200
|
||||||
curl -s $BASE/api/models # 401 (无 cookie)
|
curl -s $BASE/api/models # 401 (无 cookie)
|
||||||
curl -s $BASE/api/v1/auth/setup-status # 200
|
curl -s -X POST $BASE/api/v1/auth/setup-status # 200
|
||||||
curl -s $BASE/api/v1/auth/me -b cookies.txt # 200 (有 cookie)
|
curl -s $BASE/api/v1/auth/me -b cookies.txt # 200 (有 cookie)
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -2,16 +2,13 @@
|
|||||||
|
|
||||||
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
|
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
|
||||||
|
|
||||||
完整设计见 [AUTH_DESIGN.md](AUTH_DESIGN.md)。
|
|
||||||
|
|
||||||
## 核心概念
|
## 核心概念
|
||||||
|
|
||||||
认证模块采用**始终强制**策略:
|
认证模块采用**始终强制**策略:
|
||||||
|
|
||||||
- 首次启动时不会自动创建账号;首次访问 `/setup` 时由操作者创建第一个 admin 账号
|
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
|
||||||
- 认证从一开始就是强制的,无竞争窗口
|
- 认证从一开始就是强制的,无竞争窗口
|
||||||
- 已有 admin 后,服务启动时会把历史对话(升级前创建且缺少 `user_id` 的 thread)迁移到 admin 名下
|
- 历史对话(升级前创建的 thread)自动迁移到 admin 名下
|
||||||
- 新数据按用户隔离:thread、workspace/uploads/outputs、memory、自定义 agent 都归属当前用户
|
|
||||||
|
|
||||||
## 升级步骤
|
## 升级步骤
|
||||||
|
|
||||||
@@ -28,41 +25,39 @@ cd backend && make install
|
|||||||
make dev
|
make dev
|
||||||
```
|
```
|
||||||
|
|
||||||
如果没有 admin 账号,控制台只会提示:
|
控制台会输出:
|
||||||
|
|
||||||
```
|
```
|
||||||
============================================================
|
============================================================
|
||||||
First boot detected — no admin account exists.
|
Admin account created on first boot
|
||||||
Visit /setup to complete admin account creation.
|
Email: admin@deerflow.dev
|
||||||
|
Password: aB3xK9mN_pQ7rT2w
|
||||||
|
Change it after login: Settings → Account
|
||||||
============================================================
|
============================================================
|
||||||
```
|
```
|
||||||
|
|
||||||
首次启动不会在日志里打印随机密码,也不会写入默认 admin。这样避免启动日志泄露凭据,也避免在操作者创建账号前出现可被猜测的默认身份。
|
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。
|
||||||
|
|
||||||
### 3. 创建 admin
|
### 3. 登录
|
||||||
|
|
||||||
访问 `http://localhost:2026/setup`,填写邮箱和密码创建第一个 admin 账号。创建成功后会自动登录并进入 workspace。
|
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。
|
||||||
|
|
||||||
如果这是从无认证版本升级,创建 admin 后重启一次服务,让启动迁移把缺少 `user_id` 的历史 thread 归属到 admin。
|
### 4. 修改密码
|
||||||
|
|
||||||
### 4. 登录
|
登录后进入 Settings → Account → Change Password。
|
||||||
|
|
||||||
后续访问 `http://localhost:2026/login`,使用已创建的邮箱和密码登录。
|
|
||||||
|
|
||||||
### 5. 添加用户(可选)
|
### 5. 添加用户(可选)
|
||||||
|
|
||||||
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话、上传文件、输出文件、memory 和自定义 agent。
|
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
|
||||||
|
|
||||||
## 安全机制
|
## 安全机制
|
||||||
|
|
||||||
| 机制 | 说明 |
|
| 机制 | 说明 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
|
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
|
||||||
| CSRF Double Submit Cookie | 受保护的 POST/PUT/PATCH/DELETE 请求需携带 `X-CSRF-Token`;登录/注册/初始化/登出走 auth 端点 Origin 校验 |
|
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
|
||||||
| bcrypt 密码哈希 | 密码不以明文存储 |
|
| bcrypt 密码哈希 | 密码不以明文存储 |
|
||||||
| Thread owner filter | `threads_meta.user_id` 由服务端认证上下文写入,搜索、读取、更新、删除默认按当前用户过滤 |
|
| 多租户隔离 | 用户只能访问自己的 thread |
|
||||||
| 文件系统隔离 | 线程数据写入 `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/`,sandbox 内统一映射为 `/mnt/user-data/` |
|
|
||||||
| Memory / agent 隔离 | 用户 memory 和自定义 agent 写入 `{base_dir}/users/{user_id}/...`;旧共享 agent 只作为只读兼容回退 |
|
|
||||||
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
|
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
|
||||||
|
|
||||||
## 常见操作
|
## 常见操作
|
||||||
@@ -79,27 +74,23 @@ python -m app.gateway.auth.reset_admin
|
|||||||
python -m app.gateway.auth.reset_admin --email user@example.com
|
python -m app.gateway.auth.reset_admin --email user@example.com
|
||||||
```
|
```
|
||||||
|
|
||||||
会把新的随机密码写入 `.deer-flow/admin_initial_credentials.txt`,文件权限为 `0600`。命令行只输出文件路径,不输出明文密码。
|
会输出新的随机密码。
|
||||||
|
|
||||||
### 完全重置
|
### 完全重置
|
||||||
|
|
||||||
删除统一 SQLite 数据库,重启后重新访问 `/setup` 创建新 admin:
|
删除用户数据库,重启后自动创建新 admin:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
rm -f backend/.deer-flow/data/deerflow.db
|
rm -f backend/.deer-flow/users.db
|
||||||
# 重启服务后访问 http://localhost:2026/setup
|
# 重启服务,控制台输出新密码
|
||||||
```
|
```
|
||||||
|
|
||||||
## 数据存储
|
## 数据存储
|
||||||
|
|
||||||
| 文件 | 内容 |
|
| 文件 | 内容 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库(users、threads_meta、runs、feedback 等应用数据) |
|
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) |
|
||||||
| `.deer-flow/users/{user_id}/threads/{thread_id}/user-data/` | 用户线程的 workspace、uploads、outputs |
|
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
|
||||||
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
|
|
||||||
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
|
|
||||||
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
|
|
||||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持) |
|
|
||||||
|
|
||||||
### 生产环境建议
|
### 生产环境建议
|
||||||
|
|
||||||
@@ -120,21 +111,19 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
|
|||||||
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
|
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
|
||||||
| `/api/v1/auth/change-password` | POST | 修改密码 |
|
| `/api/v1/auth/change-password` | POST | 修改密码 |
|
||||||
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
|
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
|
||||||
| `/api/v1/auth/initialize` | POST | 首次初始化第一个 admin(仅无 admin 时可调用) |
|
|
||||||
|
|
||||||
## 兼容性
|
## 兼容性
|
||||||
|
|
||||||
- **标准模式**(`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化
|
- **标准模式**(`make dev`):完全兼容,admin 自动创建
|
||||||
- **Gateway 模式**(`make dev-pro`):完全兼容
|
- **Gateway 模式**(`make dev-pro`):完全兼容
|
||||||
- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载
|
- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载
|
||||||
- **IM 渠道**(Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶
|
- **IM 渠道**(Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层
|
||||||
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
|
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
|
||||||
|
|
||||||
## 故障排查
|
## 故障排查
|
||||||
|
|
||||||
| 症状 | 原因 | 解决 |
|
| 症状 | 原因 | 解决 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
|
| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` |
|
||||||
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
|
|
||||||
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
||||||
| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 |
|
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
|
||||||
|
|||||||
@@ -259,8 +259,6 @@ sandbox:
|
|||||||
|
|
||||||
When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` values in the agent prompt so the agent can discover and operate on mounted directories directly instead of assuming everything must live under `/mnt/user-data`.
|
When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` values in the agent prompt so the agent can discover and operate on mounted directories directly instead of assuming everything must live under `/mnt/user-data`.
|
||||||
|
|
||||||
For bare-metal Docker sandbox runs that use localhost, DeerFlow binds the sandbox HTTP port to `127.0.0.1` by default so it is not exposed on every host interface. Docker-outside-of-Docker deployments that connect through `host.docker.internal` keep the broad legacy bind for compatibility. Set `DEER_FLOW_SANDBOX_BIND_HOST` explicitly if your deployment needs a different bind address.
|
|
||||||
|
|
||||||
### Skills
|
### Skills
|
||||||
|
|
||||||
Configure the skills directory for specialized workflows:
|
Configure the skills directory for specialized workflows:
|
||||||
@@ -321,16 +319,11 @@ models:
|
|||||||
- `DEEPSEEK_API_KEY` - DeepSeek API key
|
- `DEEPSEEK_API_KEY` - DeepSeek API key
|
||||||
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
|
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
|
||||||
- `TAVILY_API_KEY` - Tavily search API key
|
- `TAVILY_API_KEY` - Tavily search API key
|
||||||
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
|
|
||||||
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
|
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
|
||||||
- `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path
|
|
||||||
- `DEER_FLOW_HOME` - Runtime state directory (defaults to `.deer-flow` under the project root)
|
|
||||||
- `DEER_FLOW_SKILLS_PATH` - Skills directory when `skills.path` is omitted
|
|
||||||
- `GATEWAY_ENABLE_DOCS` - Set to `false` to disable Swagger UI (`/docs`), ReDoc (`/redoc`), and OpenAPI schema (`/openapi.json`) endpoints (default: `true`)
|
|
||||||
|
|
||||||
## Configuration Location
|
## Configuration Location
|
||||||
|
|
||||||
The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`). Set `DEER_FLOW_PROJECT_ROOT` when the process may start from another working directory, or set `DEER_FLOW_CONFIG_PATH` to point at a specific file.
|
The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`), not in the backend directory.
|
||||||
|
|
||||||
## Configuration Priority
|
## Configuration Priority
|
||||||
|
|
||||||
@@ -338,12 +331,12 @@ DeerFlow searches for configuration in this order:
|
|||||||
|
|
||||||
1. Path specified in code via `config_path` argument
|
1. Path specified in code via `config_path` argument
|
||||||
2. Path from `DEER_FLOW_CONFIG_PATH` environment variable
|
2. Path from `DEER_FLOW_CONFIG_PATH` environment variable
|
||||||
3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or under the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
|
3. `config.yaml` in current working directory (typically `backend/` when running)
|
||||||
4. Legacy backend/repository-root locations for monorepo compatibility
|
4. `config.yaml` in parent directory (project root: `deer-flow/`)
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
1. **Place `config.yaml` in project root** - Set `DEER_FLOW_PROJECT_ROOT` if the runtime starts elsewhere
|
1. **Place `config.yaml` in project root** - Not in `backend/` directory
|
||||||
2. **Never commit `config.yaml`** - It's already in `.gitignore`
|
2. **Never commit `config.yaml`** - It's already in `.gitignore`
|
||||||
3. **Use environment variables for secrets** - Don't hardcode API keys
|
3. **Use environment variables for secrets** - Don't hardcode API keys
|
||||||
4. **Keep `config.example.yaml` updated** - Document all new options
|
4. **Keep `config.example.yaml` updated** - Document all new options
|
||||||
@@ -354,7 +347,7 @@ DeerFlow searches for configuration in this order:
|
|||||||
|
|
||||||
### "Config file not found"
|
### "Config file not found"
|
||||||
- Ensure `config.yaml` exists in the **project root** directory (`deer-flow/config.yaml`)
|
- Ensure `config.yaml` exists in the **project root** directory (`deer-flow/config.yaml`)
|
||||||
- If the runtime starts outside the project root, set `DEER_FLOW_PROJECT_ROOT`
|
- The backend searches parent directory by default, so root location is preferred
|
||||||
- Alternatively, set `DEER_FLOW_CONFIG_PATH` environment variable to custom location
|
- Alternatively, set `DEER_FLOW_CONFIG_PATH` environment variable to custom location
|
||||||
|
|
||||||
### "Invalid API key"
|
### "Invalid API key"
|
||||||
@@ -364,7 +357,7 @@ DeerFlow searches for configuration in this order:
|
|||||||
### "Skills not loading"
|
### "Skills not loading"
|
||||||
- Check that `deer-flow/skills/` directory exists
|
- Check that `deer-flow/skills/` directory exists
|
||||||
- Verify skills have valid `SKILL.md` files
|
- Verify skills have valid `SKILL.md` files
|
||||||
- Check `skills.path` or `DEER_FLOW_SKILLS_PATH` if using a custom path
|
- Check `skills.path` configuration if using custom path
|
||||||
|
|
||||||
### "Docker sandbox fails to start"
|
### "Docker sandbox fails to start"
|
||||||
- Ensure Docker is running
|
- Ensure Docker is running
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
## 概述
|
## 概述
|
||||||
|
|
||||||
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并可选地将 Office 文档和 PDF 转换为 Markdown 格式。
|
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并自动将 Office 文档和 PDF 转换为 Markdown 格式。
|
||||||
|
|
||||||
## 功能特性
|
## 功能特性
|
||||||
|
|
||||||
- ✅ 支持多文件同时上传
|
- ✅ 支持多文件同时上传
|
||||||
- ✅ 可选地转换文档为 Markdown(PDF、PPT、Excel、Word)
|
- ✅ 自动转换文档为 Markdown(PDF、PPT、Excel、Word)
|
||||||
- ✅ 文件存储在线程隔离的目录中
|
- ✅ 文件存储在线程隔离的目录中
|
||||||
- ✅ Agent 自动感知已上传的文件
|
- ✅ Agent 自动感知已上传的文件
|
||||||
- ✅ 支持文件列表查询和删除
|
- ✅ 支持文件列表查询和删除
|
||||||
@@ -22,8 +22,6 @@ POST /api/threads/{thread_id}/uploads
|
|||||||
**请求体:** `multipart/form-data`
|
**请求体:** `multipart/form-data`
|
||||||
- `files`: 一个或多个文件
|
- `files`: 一个或多个文件
|
||||||
|
|
||||||
网关会在应用层限制上传规模,默认最多 10 个文件、单文件 50 MiB、单次请求总计 100 MiB。可通过 `config.yaml` 的 `uploads.max_files`、`uploads.max_file_size`、`uploads.max_total_size` 调整;前端会读取同一组限制并在选择文件时提示,超过限制时后端返回 `413 Payload Too Large`。
|
|
||||||
|
|
||||||
**响应:**
|
**响应:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -50,23 +48,7 @@ POST /api/threads/{thread_id}/uploads
|
|||||||
- `virtual_path`: Agent 在沙箱中使用的虚拟路径
|
- `virtual_path`: Agent 在沙箱中使用的虚拟路径
|
||||||
- `artifact_url`: 前端通过 HTTP 访问文件的 URL
|
- `artifact_url`: 前端通过 HTTP 访问文件的 URL
|
||||||
|
|
||||||
### 2. 查询上传限制
|
### 2. 列出已上传文件
|
||||||
```
|
|
||||||
GET /api/threads/{thread_id}/uploads/limits
|
|
||||||
```
|
|
||||||
|
|
||||||
返回网关当前生效的上传限制,供前端在用户选择文件前提示和拦截。
|
|
||||||
|
|
||||||
**响应:**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"max_files": 10,
|
|
||||||
"max_file_size": 52428800,
|
|
||||||
"max_total_size": 104857600
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 列出已上传文件
|
|
||||||
```
|
```
|
||||||
GET /api/threads/{thread_id}/uploads/list
|
GET /api/threads/{thread_id}/uploads/list
|
||||||
```
|
```
|
||||||
@@ -89,7 +71,7 @@ GET /api/threads/{thread_id}/uploads/list
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4. 删除文件
|
### 3. 删除文件
|
||||||
```
|
```
|
||||||
DELETE /api/threads/{thread_id}/uploads/{filename}
|
DELETE /api/threads/{thread_id}/uploads/{filename}
|
||||||
```
|
```
|
||||||
@@ -104,7 +86,7 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
|
|||||||
|
|
||||||
## 支持的文档格式
|
## 支持的文档格式
|
||||||
|
|
||||||
以下格式在显式启用 `uploads.auto_convert_documents: true` 时会自动转换为 Markdown:
|
以下格式会自动转换为 Markdown:
|
||||||
- PDF (`.pdf`)
|
- PDF (`.pdf`)
|
||||||
- PowerPoint (`.ppt`, `.pptx`)
|
- PowerPoint (`.ppt`, `.pptx`)
|
||||||
- Excel (`.xls`, `.xlsx`)
|
- Excel (`.xls`, `.xlsx`)
|
||||||
@@ -112,8 +94,6 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
|
|||||||
|
|
||||||
转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。
|
转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。
|
||||||
|
|
||||||
默认情况下,自动转换是关闭的,以避免在网关主机上对不受信任的 Office/PDF 上传执行解析。只有在受信任部署中明确接受此风险时,才应将 `uploads.auto_convert_documents` 设置为 `true`。
|
|
||||||
|
|
||||||
## Agent 集成
|
## Agent 集成
|
||||||
|
|
||||||
### 自动文件列举
|
### 自动文件列举
|
||||||
@@ -227,7 +207,6 @@ backend/.deer-flow/threads/
|
|||||||
- 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size`)
|
- 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size`)
|
||||||
- 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击
|
- 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击
|
||||||
- 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问
|
- 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问
|
||||||
- 自动文档转换默认关闭;如需启用,需在 `config.yaml` 中显式设置 `uploads.auto_convert_documents: true`
|
|
||||||
|
|
||||||
## 技术实现
|
## 技术实现
|
||||||
|
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ These are the tool names your provider will see in `request.tool_name`:
|
|||||||
| `web_search` | Web search query |
|
| `web_search` | Web search query |
|
||||||
| `web_fetch` | Fetch URL content |
|
| `web_fetch` | Fetch URL content |
|
||||||
| `image_search` | Image search |
|
| `image_search` | Image search |
|
||||||
| `present_files` | Present file to user |
|
| `present_file` | Present file to user |
|
||||||
| `view_image` | Display image |
|
| `view_image` | Display image |
|
||||||
| `ask_clarification` | Ask user a question |
|
| `ask_clarification` | Ask user a question |
|
||||||
| `task` | Delegate to subagent |
|
| `task` | Delegate to subagent |
|
||||||
|
|||||||
@@ -0,0 +1,343 @@
|
|||||||
|
# DeerFlow 后端拆分设计文档:Harness + App
|
||||||
|
|
||||||
|
> 状态:Draft
|
||||||
|
> 作者:DeerFlow Team
|
||||||
|
> 日期:2026-03-13
|
||||||
|
|
||||||
|
## 1. 背景与动机
|
||||||
|
|
||||||
|
DeerFlow 后端当前是一个单一 Python 包(`src.*`),包含了从底层 agent 编排到上层用户产品的所有代码。随着项目发展,这种结构带来了几个问题:
|
||||||
|
|
||||||
|
- **复用困难**:其他产品(CLI 工具、Slack bot、第三方集成)想用 agent 能力,必须依赖整个后端,包括 FastAPI、IM SDK 等不需要的依赖
|
||||||
|
- **职责模糊**:agent 编排逻辑和用户产品逻辑混在同一个 `src/` 下,边界不清晰
|
||||||
|
- **依赖膨胀**:LangGraph Server 运行时不需要 FastAPI/uvicorn/Slack SDK,但当前必须安装全部依赖
|
||||||
|
|
||||||
|
本文档提出将后端拆分为两部分:**deerflow-harness**(可发布的 agent 框架包)和 **app**(不打包的用户产品代码)。
|
||||||
|
|
||||||
|
## 2. 核心概念
|
||||||
|
|
||||||
|
### 2.1 Harness(线束/框架层)
|
||||||
|
|
||||||
|
Harness 是 agent 的构建与编排框架,回答 **"如何构建和运行 agent"** 的问题:
|
||||||
|
|
||||||
|
- Agent 工厂与生命周期管理
|
||||||
|
- Middleware pipeline
|
||||||
|
- 工具系统(内置工具 + MCP + 社区工具)
|
||||||
|
- 沙箱执行环境
|
||||||
|
- 子 agent 委派
|
||||||
|
- 记忆系统
|
||||||
|
- 技能加载与注入
|
||||||
|
- 模型工厂
|
||||||
|
- 配置系统
|
||||||
|
|
||||||
|
**Harness 是一个可发布的 Python 包**(`deerflow-harness`),可以独立安装和使用。
|
||||||
|
|
||||||
|
**Harness 的设计原则**:对上层应用完全无感知。它不知道也不关心谁在调用它——可以是 Web App、CLI、Slack Bot、或者一个单元测试。
|
||||||
|
|
||||||
|
### 2.2 App(应用层)
|
||||||
|
|
||||||
|
App 是面向用户的产品代码,回答 **"如何将 agent 呈现给用户"** 的问题:
|
||||||
|
|
||||||
|
- Gateway API(FastAPI REST 接口)
|
||||||
|
- IM Channels(飞书、Slack、Telegram 集成)
|
||||||
|
- Custom Agent 的 CRUD 管理
|
||||||
|
- 文件上传/下载的 HTTP 接口
|
||||||
|
|
||||||
|
**App 不打包、不发布**,它是 DeerFlow 项目内部的应用代码,直接运行。
|
||||||
|
|
||||||
|
**App 依赖 Harness,但 Harness 不依赖 App。**
|
||||||
|
|
||||||
|
### 2.3 边界划分
|
||||||
|
|
||||||
|
| 模块 | 归属 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `config/` | Harness | 配置系统是基础设施 |
|
||||||
|
| `reflection/` | Harness | 动态模块加载工具 |
|
||||||
|
| `utils/` | Harness | 通用工具函数 |
|
||||||
|
| `agents/` | Harness | Agent 工厂、middleware、state、memory |
|
||||||
|
| `subagents/` | Harness | 子 agent 委派系统 |
|
||||||
|
| `sandbox/` | Harness | 沙箱执行环境 |
|
||||||
|
| `tools/` | Harness | 工具注册与发现 |
|
||||||
|
| `mcp/` | Harness | MCP 协议集成 |
|
||||||
|
| `skills/` | Harness | 技能加载、解析、定义 schema |
|
||||||
|
| `models/` | Harness | LLM 模型工厂 |
|
||||||
|
| `community/` | Harness | 社区工具(tavily、jina 等) |
|
||||||
|
| `client.py` | Harness | 嵌入式 Python 客户端 |
|
||||||
|
| `gateway/` | App | FastAPI REST API |
|
||||||
|
| `channels/` | App | IM 平台集成 |
|
||||||
|
|
||||||
|
**关于 Custom Agents**:agent 定义格式(`config.yaml` + `SOUL.md` schema)由 Harness 层的 `config/agents_config.py` 定义,但文件的存储、CRUD、发现机制由 App 层的 `gateway/routers/agents.py` 负责。
|
||||||
|
|
||||||
|
## 3. 目标架构
|
||||||
|
|
||||||
|
### 3.1 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/
|
||||||
|
├── packages/
|
||||||
|
│ └── harness/
|
||||||
|
│ ├── pyproject.toml # deerflow-harness 包定义
|
||||||
|
│ └── deerflow/ # Python 包根(import 前缀: deerflow.*)
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── config/
|
||||||
|
│ ├── reflection/
|
||||||
|
│ ├── utils/
|
||||||
|
│ ├── agents/
|
||||||
|
│ │ ├── lead_agent/
|
||||||
|
│ │ ├── middlewares/
|
||||||
|
│ │ ├── memory/
|
||||||
|
│ │ ├── checkpointer/
|
||||||
|
│ │ └── thread_state.py
|
||||||
|
│ ├── subagents/
|
||||||
|
│ ├── sandbox/
|
||||||
|
│ ├── tools/
|
||||||
|
│ ├── mcp/
|
||||||
|
│ ├── skills/
|
||||||
|
│ ├── models/
|
||||||
|
│ ├── community/
|
||||||
|
│ └── client.py
|
||||||
|
├── app/ # 不打包(import 前缀: app.*)
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── gateway/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── app.py
|
||||||
|
│ │ ├── config.py
|
||||||
|
│ │ ├── path_utils.py
|
||||||
|
│ │ └── routers/
|
||||||
|
│ └── channels/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── base.py
|
||||||
|
│ ├── manager.py
|
||||||
|
│ ├── service.py
|
||||||
|
│ ├── store.py
|
||||||
|
│ ├── message_bus.py
|
||||||
|
│ ├── feishu.py
|
||||||
|
│ ├── slack.py
|
||||||
|
│ └── telegram.py
|
||||||
|
├── pyproject.toml # uv workspace root
|
||||||
|
├── langgraph.json
|
||||||
|
├── tests/
|
||||||
|
├── docs/
|
||||||
|
└── Makefile
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 Import 规则
|
||||||
|
|
||||||
|
两个层使用不同的 import 前缀,职责边界一目了然:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
# Harness 内部互相引用(deerflow.* 前缀)
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
from deerflow.agents import make_lead_agent
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
from deerflow.tools import get_available_tools
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
# App 内部互相引用(app.* 前缀)
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
from app.gateway.app import app
|
||||||
|
from app.gateway.routers.uploads import upload_files
|
||||||
|
from app.channels.service import start_channel_service
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
# App 调用 Harness(单向依赖,Harness 永远不 import app)
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
from deerflow.agents import make_lead_agent
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
from deerflow.skills import load_skills
|
||||||
|
from deerflow.config.extensions_config import get_extensions_config
|
||||||
|
```
|
||||||
|
|
||||||
|
**App 调用 Harness 示例 — Gateway 中启动 agent**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/gateway/routers/chat.py
|
||||||
|
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
async def create_chat_session(thread_id: str, model_name: str):
|
||||||
|
config = get_app_config()
|
||||||
|
model = create_chat_model(name=model_name)
|
||||||
|
agent = make_lead_agent(config=...)
|
||||||
|
# ... 使用 agent 处理用户消息
|
||||||
|
```
|
||||||
|
|
||||||
|
**App 调用 Harness 示例 — Channel 中查询 skills**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/channels/manager.py
|
||||||
|
from deerflow.skills import load_skills
|
||||||
|
from deerflow.agents.memory.updater import get_memory_data
|
||||||
|
|
||||||
|
def handle_status_command():
|
||||||
|
skills = load_skills(enabled_only=True)
|
||||||
|
memory = get_memory_data()
|
||||||
|
return f"Skills: {len(skills)}, Memory facts: {len(memory.get('facts', []))}"
|
||||||
|
```
|
||||||
|
|
||||||
|
**禁止方向**:Harness 代码中绝不能出现 `from app.` 或 `import app.`。
|
||||||
|
|
||||||
|
### 3.3 为什么 App 不打包
|
||||||
|
|
||||||
|
| 方面 | 打包(放 packages/ 下) | 不打包(放 backend/app/) |
|
||||||
|
|------|------------------------|--------------------------|
|
||||||
|
| 命名空间 | 需要 pkgutil `extend_path` 合并,或独立前缀 | 天然独立,`app.*` vs `deerflow.*` |
|
||||||
|
| 发布需求 | 没有——App 是项目内部代码 | 不需要 pyproject.toml |
|
||||||
|
| 复杂度 | 需要管理两个包的构建、版本、依赖声明 | 直接运行,零额外配置 |
|
||||||
|
| 运行方式 | `pip install deerflow-app` | `PYTHONPATH=. uvicorn app.gateway.app:app` |
|
||||||
|
|
||||||
|
App 的唯一消费者是 DeerFlow 项目自身,没有独立发布的需求。放在 `backend/app/` 下作为普通 Python 包,通过 `PYTHONPATH` 或 editable install 让 Python 找到即可。
|
||||||
|
|
||||||
|
### 3.4 依赖关系
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ app/ (不打包,直接运行) │
|
||||||
|
│ ├── fastapi, uvicorn │
|
||||||
|
│ ├── slack-sdk, lark-oapi, ... │
|
||||||
|
│ └── import deerflow.* │
|
||||||
|
└──────────────┬──────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ deerflow-harness (可发布的包) │
|
||||||
|
│ ├── langgraph, langchain │
|
||||||
|
│ ├── markitdown, pydantic, ... │
|
||||||
|
│ └── 零 app 依赖 │
|
||||||
|
└─────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**依赖分类**:
|
||||||
|
|
||||||
|
| 分类 | 依赖包 |
|
||||||
|
|------|--------|
|
||||||
|
| Harness only | agent-sandbox, langchain*, langgraph*, markdownify, markitdown, pydantic, pyyaml, readabilipy, tavily-python, firecrawl-py, tiktoken, ddgs, duckdb, httpx, kubernetes, dotenv |
|
||||||
|
| App only | fastapi, uvicorn, sse-starlette, python-multipart, lark-oapi, slack-sdk, python-telegram-bot, markdown-to-mrkdwn |
|
||||||
|
| Shared | langgraph-sdk(channels 用 HTTP client), pydantic, httpx |
|
||||||
|
|
||||||
|
### 3.5 Workspace 配置
|
||||||
|
|
||||||
|
`backend/pyproject.toml`(workspace root):
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "deer-flow"
|
||||||
|
version = "0.1.0"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = ["deerflow-harness"]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||||
|
# App 的额外依赖(fastapi 等)也声明在 workspace root,因为 app 不打包
|
||||||
|
app = ["fastapi", "uvicorn", "sse-starlette", "python-multipart"]
|
||||||
|
channels = ["lark-oapi", "slack-sdk", "python-telegram-bot"]
|
||||||
|
|
||||||
|
[tool.uv.workspace]
|
||||||
|
members = ["packages/harness"]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
deerflow-harness = { workspace = true }
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. 当前的跨层依赖问题
|
||||||
|
|
||||||
|
在拆分之前,需要先解决 `client.py` 中两处从 harness 到 app 的反向依赖:
|
||||||
|
|
||||||
|
### 4.1 `_validate_skill_frontmatter`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# client.py — harness 导入了 app 层代码
|
||||||
|
from src.gateway.routers.skills import _validate_skill_frontmatter
|
||||||
|
```
|
||||||
|
|
||||||
|
**解决方案**:将该函数提取到 `deerflow/skills/validation.py`。这是一个纯逻辑函数(解析 YAML frontmatter、校验字段),与 FastAPI 无关。
|
||||||
|
|
||||||
|
### 4.2 `CONVERTIBLE_EXTENSIONS` + `convert_file_to_markdown`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# client.py — harness 导入了 app 层代码
|
||||||
|
from src.gateway.routers.uploads import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
|
||||||
|
```
|
||||||
|
|
||||||
|
**解决方案**:将它们提取到 `deerflow/utils/file_conversion.py`。仅依赖 `markitdown` + `pathlib`,是通用工具函数。
|
||||||
|
|
||||||
|
## 5. 基础设施变更
|
||||||
|
|
||||||
|
### 5.1 LangGraph Server
|
||||||
|
|
||||||
|
LangGraph Server 只需要 harness 包。`langgraph.json` 更新:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"dependencies": ["./packages/harness"],
|
||||||
|
"graphs": {
|
||||||
|
"lead_agent": "deerflow.agents:make_lead_agent"
|
||||||
|
},
|
||||||
|
"checkpointer": {
|
||||||
|
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 Gateway API
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# serve.sh / Makefile
|
||||||
|
# PYTHONPATH 包含 backend/ 根目录,使 app.* 和 deerflow.* 都能被找到
|
||||||
|
PYTHONPATH=. uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 Nginx
|
||||||
|
|
||||||
|
无需变更(只做 URL 路由,不涉及 Python 模块路径)。
|
||||||
|
|
||||||
|
### 5.4 Docker
|
||||||
|
|
||||||
|
Dockerfile 中的 module 引用从 `src.` 改为 `deerflow.` / `app.`,`COPY` 命令需覆盖 `packages/` 和 `app/` 目录。
|
||||||
|
|
||||||
|
## 6. 实施计划
|
||||||
|
|
||||||
|
分 3 个 PR 递进执行:
|
||||||
|
|
||||||
|
### PR 1:提取共享工具函数(Low Risk)
|
||||||
|
|
||||||
|
1. 创建 `src/skills/validation.py`,从 `gateway/routers/skills.py` 提取 `_validate_skill_frontmatter`
|
||||||
|
2. 创建 `src/utils/file_conversion.py`,从 `gateway/routers/uploads.py` 提取文件转换逻辑
|
||||||
|
3. 更新 `client.py`、`gateway/routers/skills.py`、`gateway/routers/uploads.py` 的 import
|
||||||
|
4. 运行全部测试确认无回归
|
||||||
|
|
||||||
|
### PR 2:Rename + 物理拆分(High Risk,原子操作)
|
||||||
|
|
||||||
|
1. 创建 `packages/harness/` 目录,创建 `pyproject.toml`
|
||||||
|
2. `git mv` 将 harness 相关模块从 `src/` 移入 `packages/harness/deerflow/`
|
||||||
|
3. `git mv` 将 app 相关模块从 `src/` 移入 `app/`
|
||||||
|
4. 全局替换 import:
|
||||||
|
- harness 模块:`src.*` → `deerflow.*`(所有 `.py` 文件、`langgraph.json`、测试、文档)
|
||||||
|
- app 模块:`src.gateway.*` → `app.gateway.*`、`src.channels.*` → `app.channels.*`
|
||||||
|
5. 更新 workspace root `pyproject.toml`
|
||||||
|
6. 更新 `langgraph.json`、`Makefile`、`Dockerfile`
|
||||||
|
7. `uv sync` + 全部测试 + 手动验证服务启动
|
||||||
|
|
||||||
|
### PR 3:边界检查 + 文档(Low Risk)
|
||||||
|
|
||||||
|
1. 添加 lint 规则:检查 harness 不 import app 模块
|
||||||
|
2. 更新 `CLAUDE.md`、`README.md`
|
||||||
|
|
||||||
|
## 7. 风险与缓解
|
||||||
|
|
||||||
|
| 风险 | 影响 | 缓解措施 |
|
||||||
|
|------|------|----------|
|
||||||
|
| 全局 rename 误伤 | 字符串中的 `src` 被错误替换 | 正则精确匹配 `\bsrc\.`,review diff |
|
||||||
|
| LangGraph Server 找不到模块 | 服务启动失败 | `langgraph.json` 的 `dependencies` 指向正确的 harness 包路径 |
|
||||||
|
| App 的 `PYTHONPATH` 缺失 | Gateway/Channel 启动 import 报错 | Makefile/Docker 统一设置 `PYTHONPATH=.` |
|
||||||
|
| `config.yaml` 中的 `use` 字段引用旧路径 | 运行时模块解析失败 | `config.yaml` 中的 `use` 字段同步更新为 `deerflow.*` |
|
||||||
|
| 测试中 `sys.path` 混乱 | 测试失败 | 用 editable install(`uv sync`)确保 deerflow 可导入,`conftest.py` 中添加 `app/` 到 `sys.path` |
|
||||||
|
|
||||||
|
## 8. 未来演进
|
||||||
|
|
||||||
|
- **独立发布**:harness 可以发布到内部 PyPI,让其他项目直接 `pip install deerflow-harness`
|
||||||
|
- **插件化 App**:不同的 app(web、CLI、bot)可以各自独立,都依赖同一个 harness
|
||||||
|
- **更细粒度拆分**:如果 harness 内部模块继续增长,可以进一步拆分(如 `deerflow-sandbox`、`deerflow-mcp`)
|
||||||
@@ -45,41 +45,6 @@ Example:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Custom Tool Interceptors
|
|
||||||
|
|
||||||
You can register custom interceptors that run before every MCP tool call. This is useful for injecting per-request headers (e.g., user auth tokens from the LangGraph execution context), logging, or metrics.
|
|
||||||
|
|
||||||
Declare interceptors in `extensions_config.json` using the `mcpInterceptors` field:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"mcpInterceptors": [
|
|
||||||
"my_package.mcp.auth:build_auth_interceptor"
|
|
||||||
],
|
|
||||||
"mcpServers": { ... }
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Each entry is a Python import path in `module:variable` format (resolved via `resolve_variable`). The variable must be a **no-arg builder function** that returns an async interceptor compatible with `MultiServerMCPClient`’s `tool_interceptors` interface, or `None` to skip.
|
|
||||||
|
|
||||||
Example interceptor that injects auth headers from LangGraph metadata:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def build_auth_interceptor():
|
|
||||||
async def interceptor(request, handler):
|
|
||||||
from langgraph.config import get_config
|
|
||||||
metadata = get_config().get("metadata", {})
|
|
||||||
headers = dict(request.headers or {})
|
|
||||||
if token := metadata.get("auth_token"):
|
|
||||||
headers["X-Auth-Token"] = token
|
|
||||||
return await handler(request.override(headers=headers))
|
|
||||||
return interceptor
|
|
||||||
```
|
|
||||||
|
|
||||||
- A single string value is accepted and normalized to a one-element list.
|
|
||||||
- Invalid paths or builder failures are logged as warnings without blocking other interceptors.
|
|
||||||
- The builder return value must be `callable`; non-callable values are skipped with a warning.
|
|
||||||
|
|
||||||
## How It Works
|
## How It Works
|
||||||
|
|
||||||
MCP servers expose tools that are automatically discovered and integrated into DeerFlow’s agent system at runtime. Once enabled, these tools become available to agents without additional code changes.
|
MCP servers expose tools that are automatically discovered and integrated into DeerFlow’s agent system at runtime. Once enabled, these tools become available to agents without additional code changes.
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ This directory contains detailed documentation for the DeerFlow backend.
|
|||||||
|----------|-------------|
|
|----------|-------------|
|
||||||
| [ARCHITECTURE.md](ARCHITECTURE.md) | System architecture overview |
|
| [ARCHITECTURE.md](ARCHITECTURE.md) | System architecture overview |
|
||||||
| [API.md](API.md) | Complete API reference |
|
| [API.md](API.md) | Complete API reference |
|
||||||
| [AUTH_DESIGN.md](AUTH_DESIGN.md) | User authentication, CSRF, and per-user isolation design |
|
|
||||||
| [CONFIGURATION.md](CONFIGURATION.md) | Configuration options |
|
| [CONFIGURATION.md](CONFIGURATION.md) | Configuration options |
|
||||||
| [SETUP.md](SETUP.md) | Quick setup guide |
|
| [SETUP.md](SETUP.md) | Quick setup guide |
|
||||||
|
|
||||||
@@ -43,7 +42,6 @@ docs/
|
|||||||
├── README.md # This file
|
├── README.md # This file
|
||||||
├── ARCHITECTURE.md # System architecture
|
├── ARCHITECTURE.md # System architecture
|
||||||
├── API.md # API reference
|
├── API.md # API reference
|
||||||
├── AUTH_DESIGN.md # User authentication and isolation design
|
|
||||||
├── CONFIGURATION.md # Configuration guide
|
├── CONFIGURATION.md # Configuration guide
|
||||||
├── SETUP.md # Setup instructions
|
├── SETUP.md # Setup instructions
|
||||||
├── FILE_UPLOAD.md # File upload feature
|
├── FILE_UPLOAD.md # File upload feature
|
||||||
|
|||||||
+7
-13
@@ -23,9 +23,6 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r
|
|||||||
# Option A: Set environment variables (recommended)
|
# Option A: Set environment variables (recommended)
|
||||||
export OPENAI_API_KEY="your-key-here"
|
export OPENAI_API_KEY="your-key-here"
|
||||||
|
|
||||||
# Optional: pin the project root when running from another directory
|
|
||||||
export DEER_FLOW_PROJECT_ROOT="/path/to/deer-flow"
|
|
||||||
|
|
||||||
# Option B: Edit config.yaml directly
|
# Option B: Edit config.yaml directly
|
||||||
vim config.yaml # or your preferred editor
|
vim config.yaml # or your preferred editor
|
||||||
```
|
```
|
||||||
@@ -38,20 +35,17 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r
|
|||||||
|
|
||||||
## Important Notes
|
## Important Notes
|
||||||
|
|
||||||
- **Location**: `config.yaml` should be in `deer-flow/` (project root)
|
- **Location**: `config.yaml` should be in `deer-flow/` (project root), not `deer-flow/backend/`
|
||||||
- **Git**: `config.yaml` is automatically ignored by git (contains secrets)
|
- **Git**: `config.yaml` is automatically ignored by git (contains secrets)
|
||||||
- **Runtime root**: Set `DEER_FLOW_PROJECT_ROOT` if DeerFlow may start from outside the project root
|
- **Priority**: If both `backend/config.yaml` and `../config.yaml` exist, backend version takes precedence
|
||||||
- **Runtime data**: State defaults to `.deer-flow` under the project root; set `DEER_FLOW_HOME` to move it
|
|
||||||
- **Skills**: Skills default to `skills/` under the project root; set `DEER_FLOW_SKILLS_PATH` or `skills.path` to move them
|
|
||||||
|
|
||||||
## Configuration File Locations
|
## Configuration File Locations
|
||||||
|
|
||||||
The backend searches for `config.yaml` in this order:
|
The backend searches for `config.yaml` in this order:
|
||||||
|
|
||||||
1. Explicit `config_path` argument from code
|
1. `DEER_FLOW_CONFIG_PATH` environment variable (if set)
|
||||||
2. `DEER_FLOW_CONFIG_PATH` environment variable (if set)
|
2. `backend/config.yaml` (current directory when running from backend/)
|
||||||
3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
|
3. `deer-flow/config.yaml` (parent directory - **recommended location**)
|
||||||
4. Legacy backend/repository-root locations for monorepo compatibility
|
|
||||||
|
|
||||||
**Recommended**: Place `config.yaml` in project root (`deer-flow/config.yaml`).
|
**Recommended**: Place `config.yaml` in project root (`deer-flow/config.yaml`).
|
||||||
|
|
||||||
@@ -83,8 +77,8 @@ python -c "from deerflow.config.app_config import AppConfig; print(AppConfig.res
|
|||||||
|
|
||||||
If it can't find the config:
|
If it can't find the config:
|
||||||
1. Ensure you've copied `config.example.yaml` to `config.yaml`
|
1. Ensure you've copied `config.example.yaml` to `config.yaml`
|
||||||
2. Verify you're in the project root, or set `DEER_FLOW_PROJECT_ROOT`
|
2. Verify you're in the correct directory
|
||||||
3. Check the file exists: `ls -la config.yaml`
|
3. Check the file exists: `ls -la ../config.yaml`
|
||||||
|
|
||||||
### Permission denied
|
### Permission denied
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
|
|||||||
- [x] Add Plan Mode with TodoList middleware
|
- [x] Add Plan Mode with TodoList middleware
|
||||||
- [x] Add vision model support with ViewImageMiddleware
|
- [x] Add vision model support with ViewImageMiddleware
|
||||||
- [x] Skills system with SKILL.md format
|
- [x] Skills system with SKILL.md format
|
||||||
- [x] Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
|
|
||||||
|
|
||||||
## Planned Features
|
## Planned Features
|
||||||
|
|
||||||
@@ -22,9 +21,10 @@
|
|||||||
- [ ] Support for more document formats in upload
|
- [ ] Support for more document formats in upload
|
||||||
- [ ] Skill marketplace / remote skill installation
|
- [ ] Skill marketplace / remote skill installation
|
||||||
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
|
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
|
||||||
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
- Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
|
||||||
|
- Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
||||||
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
||||||
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
||||||
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
||||||
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
|
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
|
||||||
|
|
||||||
|
|||||||
@@ -41,13 +41,6 @@ summarization:
|
|||||||
|
|
||||||
# Custom summary prompt (optional)
|
# Custom summary prompt (optional)
|
||||||
summary_prompt: null
|
summary_prompt: null
|
||||||
|
|
||||||
# Tool names treated as skill file reads for skill rescue
|
|
||||||
skill_file_read_tool_names:
|
|
||||||
- read_file
|
|
||||||
- read
|
|
||||||
- view
|
|
||||||
- cat
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Configuration Options
|
### Configuration Options
|
||||||
@@ -132,26 +125,6 @@ keep:
|
|||||||
- **Default**: `null` (uses LangChain's default prompt)
|
- **Default**: `null` (uses LangChain's default prompt)
|
||||||
- **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context.
|
- **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context.
|
||||||
|
|
||||||
#### `preserve_recent_skill_count`
|
|
||||||
- **Type**: Integer (≥ 0)
|
|
||||||
- **Default**: `5`
|
|
||||||
- **Description**: Number of most-recently-loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`, e.g. `/mnt/skills/...`) that are rescued from summarization. Prevents the agent from losing skill instructions after compression. Set to `0` to disable skill rescue entirely.
|
|
||||||
|
|
||||||
#### `preserve_recent_skill_tokens`
|
|
||||||
- **Type**: Integer (≥ 0)
|
|
||||||
- **Default**: `25000`
|
|
||||||
- **Description**: Total token budget reserved for rescued skill reads. Once this budget is exhausted, older skill bundles are allowed to be summarized.
|
|
||||||
|
|
||||||
#### `preserve_recent_skill_tokens_per_skill`
|
|
||||||
- **Type**: Integer (≥ 0)
|
|
||||||
- **Default**: `5000`
|
|
||||||
- **Description**: Per-skill token cap. Any individual skill read whose tool result exceeds this size is not rescued (it falls through to the summarizer like ordinary content).
|
|
||||||
|
|
||||||
#### `skill_file_read_tool_names`
|
|
||||||
- **Type**: List of strings
|
|
||||||
- **Default**: `["read_file", "read", "view", "cat"]`
|
|
||||||
- **Description**: Tool names treated as skill file reads during summarization rescue. A tool call is only eligible for skill rescue when its name appears in this list and its target path is under `skills.container_path`.
|
|
||||||
|
|
||||||
**Default Prompt Behavior:**
|
**Default Prompt Behavior:**
|
||||||
The default LangChain prompt instructs the model to:
|
The default LangChain prompt instructs the model to:
|
||||||
- Extract highest quality/most relevant context
|
- Extract highest quality/most relevant context
|
||||||
@@ -174,7 +147,6 @@ The default LangChain prompt instructs the model to:
|
|||||||
- A single summary message is added
|
- A single summary message is added
|
||||||
- Recent messages are preserved
|
- Recent messages are preserved
|
||||||
6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together
|
6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together
|
||||||
7. **Skill Rescue**: Before the summary is generated, the most recently loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`) are lifted out of the summarization set and prepended to the preserved tail. Selection walks newest-first under three budgets: `preserve_recent_skill_count`, `preserve_recent_skill_tokens`, and `preserve_recent_skill_tokens_per_skill`. The triggering AIMessage and all of its paired ToolMessages move together so tool_call ↔ tool_result pairing stays intact.
|
|
||||||
|
|
||||||
### Token Counting
|
### Token Counting
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,6 @@
|
|||||||
"path": "./app/gateway/langgraph_auth.py:auth"
|
"path": "./app/gateway/langgraph_auth.py:auth"
|
||||||
},
|
},
|
||||||
"checkpointer": {
|
"checkpointer": {
|
||||||
"path": "./packages/harness/deerflow/runtime/checkpointer/async_provider.py:make_checkpointer"
|
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ def _assemble_from_features(
|
|||||||
9. MemoryMiddleware (memory feature)
|
9. MemoryMiddleware (memory feature)
|
||||||
10. ViewImageMiddleware (vision feature)
|
10. ViewImageMiddleware (vision feature)
|
||||||
11. SubagentLimitMiddleware (subagent feature)
|
11. SubagentLimitMiddleware (subagent feature)
|
||||||
12. LoopDetectionMiddleware (loop_detection feature)
|
12. LoopDetectionMiddleware (always)
|
||||||
13. ClarificationMiddleware (always last)
|
13. ClarificationMiddleware (always last)
|
||||||
|
|
||||||
Two-phase ordering:
|
Two-phase ordering:
|
||||||
@@ -254,11 +254,9 @@ def _assemble_from_features(
|
|||||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||||
|
|
||||||
chain.append(ViewImageMiddleware())
|
chain.append(ViewImageMiddleware())
|
||||||
|
from deerflow.tools.builtins import view_image_tool
|
||||||
|
|
||||||
if feat.sandbox is not False:
|
extra_tools.append(view_image_tool)
|
||||||
from deerflow.tools.builtins import view_image_tool
|
|
||||||
|
|
||||||
extra_tools.append(view_image_tool)
|
|
||||||
|
|
||||||
# --- [11] Subagent ---
|
# --- [11] Subagent ---
|
||||||
if feat.subagent is not False:
|
if feat.subagent is not False:
|
||||||
@@ -272,15 +270,10 @@ def _assemble_from_features(
|
|||||||
|
|
||||||
extra_tools.append(task_tool)
|
extra_tools.append(task_tool)
|
||||||
|
|
||||||
# --- [12] LoopDetection ---
|
# --- [12] LoopDetection (always) ---
|
||||||
if feat.loop_detection is not False:
|
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||||
if isinstance(feat.loop_detection, AgentMiddleware):
|
|
||||||
chain.append(feat.loop_detection)
|
|
||||||
else:
|
|
||||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
|
||||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
|
||||||
|
|
||||||
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
|
chain.append(LoopDetectionMiddleware())
|
||||||
|
|
||||||
# --- [13] Clarification (always last among built-ins) ---
|
# --- [13] Clarification (always last among built-ins) ---
|
||||||
chain.append(ClarificationMiddleware())
|
chain.append(ClarificationMiddleware())
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ class RuntimeFeatures:
|
|||||||
vision: bool | AgentMiddleware = False
|
vision: bool | AgentMiddleware = False
|
||||||
auto_title: bool | AgentMiddleware = False
|
auto_title: bool | AgentMiddleware = False
|
||||||
guardrail: Literal[False] | AgentMiddleware = False
|
guardrail: Literal[False] | AgentMiddleware = False
|
||||||
loop_detection: bool | AgentMiddleware = True
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -5,39 +5,28 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
|
||||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||||
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
|
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
|
||||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||||
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
|
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
|
||||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
from deerflow.config.agents_config import load_agent_config
|
||||||
from deerflow.config.app_config import AppConfig, get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
|
from deerflow.config.summarization_config import get_summarization_config
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
|
||||||
from deerflow.skills.types import Skill
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_runtime_config(config: RunnableConfig) -> dict:
|
def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||||
"""Merge legacy configurable options with LangGraph runtime context."""
|
|
||||||
cfg = dict(config.get("configurable", {}) or {})
|
|
||||||
context = config.get("context", {}) or {}
|
|
||||||
if isinstance(context, dict):
|
|
||||||
cfg.update(context)
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model_name(requested_model_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
|
|
||||||
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||||
app_config = app_config or get_app_config()
|
app_config = get_app_config()
|
||||||
default_model_name = app_config.models[0].name if app_config.models else None
|
default_model_name = app_config.models[0].name if app_config.models else None
|
||||||
if default_model_name is None:
|
if default_model_name is None:
|
||||||
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
||||||
@@ -50,10 +39,9 @@ def _resolve_model_name(requested_model_name: str | None = None, *, app_config:
|
|||||||
return default_model_name
|
return default_model_name
|
||||||
|
|
||||||
|
|
||||||
def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
|
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||||
"""Create and configure the summarization middleware from config."""
|
"""Create and configure the summarization middleware from config."""
|
||||||
resolved_app_config = app_config or get_app_config()
|
config = get_summarization_config()
|
||||||
config = resolved_app_config.summarization
|
|
||||||
|
|
||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
return None
|
return None
|
||||||
@@ -74,9 +62,9 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
|
|||||||
# as middleware rather than lead_agent (SummarizationMiddleware is a
|
# as middleware rather than lead_agent (SummarizationMiddleware is a
|
||||||
# LangChain built-in, so we tag the model at creation time).
|
# LangChain built-in, so we tag the model at creation time).
|
||||||
if config.model_name:
|
if config.model_name:
|
||||||
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config)
|
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||||
else:
|
else:
|
||||||
model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config)
|
model = create_chat_model(thinking_enabled=False)
|
||||||
model = model.with_config(tags=["middleware:summarize"])
|
model = model.with_config(tags=["middleware:summarize"])
|
||||||
|
|
||||||
# Prepare kwargs
|
# Prepare kwargs
|
||||||
@@ -92,24 +80,7 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
|
|||||||
if config.summary_prompt is not None:
|
if config.summary_prompt is not None:
|
||||||
kwargs["summary_prompt"] = config.summary_prompt
|
kwargs["summary_prompt"] = config.summary_prompt
|
||||||
|
|
||||||
hooks: list[BeforeSummarizationHook] = []
|
return SummarizationMiddleware(**kwargs)
|
||||||
if resolved_app_config.memory.enabled:
|
|
||||||
hooks.append(memory_flush_hook)
|
|
||||||
|
|
||||||
# The logic below relies on two assumptions holding true: this factory is
|
|
||||||
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
|
|
||||||
# config is not expected to change after startup.
|
|
||||||
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
|
|
||||||
|
|
||||||
return DeerFlowSummarizationMiddleware(
|
|
||||||
**kwargs,
|
|
||||||
skills_container_path=skills_container_path,
|
|
||||||
skill_file_read_tool_names=config.skill_file_read_tool_names,
|
|
||||||
before_summarization=hooks,
|
|
||||||
preserve_recent_skill_count=config.preserve_recent_skill_count,
|
|
||||||
preserve_recent_skill_tokens=config.preserve_recent_skill_tokens,
|
|
||||||
preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||||
@@ -237,14 +208,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
|||||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||||
def _build_middlewares(
|
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
|
||||||
config: RunnableConfig,
|
|
||||||
model_name: str | None,
|
|
||||||
agent_name: str | None = None,
|
|
||||||
custom_middlewares: list[AgentMiddleware] | None = None,
|
|
||||||
*,
|
|
||||||
app_config: AppConfig | None = None,
|
|
||||||
):
|
|
||||||
"""Build middleware chain based on runtime configuration.
|
"""Build middleware chain based on runtime configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -255,59 +219,50 @@ def _build_middlewares(
|
|||||||
Returns:
|
Returns:
|
||||||
List of middleware instances.
|
List of middleware instances.
|
||||||
"""
|
"""
|
||||||
resolved_app_config = app_config or get_app_config()
|
middlewares = build_lead_runtime_middlewares(lazy_init=True)
|
||||||
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
|
|
||||||
|
|
||||||
# Always inject current date (and optionally memory) as <system-reminder> into the
|
|
||||||
# first HumanMessage to keep the system prompt fully static for prefix-cache reuse.
|
|
||||||
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
|
|
||||||
|
|
||||||
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
|
||||||
|
|
||||||
# Add summarization middleware if enabled
|
# Add summarization middleware if enabled
|
||||||
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
summarization_middleware = _create_summarization_middleware()
|
||||||
if summarization_middleware is not None:
|
if summarization_middleware is not None:
|
||||||
middlewares.append(summarization_middleware)
|
middlewares.append(summarization_middleware)
|
||||||
|
|
||||||
# Add TodoList middleware if plan mode is enabled
|
# Add TodoList middleware if plan mode is enabled
|
||||||
cfg = _get_runtime_config(config)
|
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
|
||||||
is_plan_mode = cfg.get("is_plan_mode", False)
|
|
||||||
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
|
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
|
||||||
if todo_list_middleware is not None:
|
if todo_list_middleware is not None:
|
||||||
middlewares.append(todo_list_middleware)
|
middlewares.append(todo_list_middleware)
|
||||||
|
|
||||||
# Add TokenUsageMiddleware when token_usage tracking is enabled
|
# Add TokenUsageMiddleware when token_usage tracking is enabled
|
||||||
if resolved_app_config.token_usage.enabled:
|
if get_app_config().token_usage.enabled:
|
||||||
middlewares.append(TokenUsageMiddleware())
|
middlewares.append(TokenUsageMiddleware())
|
||||||
|
|
||||||
# Add TitleMiddleware
|
# Add TitleMiddleware
|
||||||
middlewares.append(TitleMiddleware(app_config=resolved_app_config))
|
middlewares.append(TitleMiddleware())
|
||||||
|
|
||||||
# Add MemoryMiddleware (after TitleMiddleware)
|
# Add MemoryMiddleware (after TitleMiddleware)
|
||||||
middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=resolved_app_config.memory))
|
middlewares.append(MemoryMiddleware(agent_name=agent_name))
|
||||||
|
|
||||||
# Add ViewImageMiddleware only if the current model supports vision.
|
# Add ViewImageMiddleware only if the current model supports vision.
|
||||||
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||||
model_config = resolved_app_config.get_model_config(model_name) if model_name else None
|
app_config = get_app_config()
|
||||||
|
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||||
if model_config is not None and model_config.supports_vision:
|
if model_config is not None and model_config.supports_vision:
|
||||||
middlewares.append(ViewImageMiddleware())
|
middlewares.append(ViewImageMiddleware())
|
||||||
|
|
||||||
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
|
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
|
||||||
if resolved_app_config.tool_search.enabled:
|
if app_config.tool_search.enabled:
|
||||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||||
|
|
||||||
middlewares.append(DeferredToolFilterMiddleware())
|
middlewares.append(DeferredToolFilterMiddleware())
|
||||||
|
|
||||||
# Add SubagentLimitMiddleware to truncate excess parallel task calls
|
# Add SubagentLimitMiddleware to truncate excess parallel task calls
|
||||||
subagent_enabled = cfg.get("subagent_enabled", False)
|
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
|
||||||
if subagent_enabled:
|
if subagent_enabled:
|
||||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
|
||||||
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
||||||
|
|
||||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||||
loop_detection_config = resolved_app_config.loop_detection
|
middlewares.append(LoopDetectionMiddleware())
|
||||||
if loop_detection_config.enabled:
|
|
||||||
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
|
|
||||||
|
|
||||||
# Inject custom middlewares before ClarificationMiddleware
|
# Inject custom middlewares before ClarificationMiddleware
|
||||||
if custom_middlewares:
|
if custom_middlewares:
|
||||||
@@ -318,42 +273,12 @@ def _build_middlewares(
|
|||||||
return middlewares
|
return middlewares
|
||||||
|
|
||||||
|
|
||||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
|
||||||
if is_bootstrap:
|
|
||||||
return {"bootstrap"}
|
|
||||||
if agent_config and agent_config.skills is not None:
|
|
||||||
return set(agent_config.skills)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
|
|
||||||
try:
|
|
||||||
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
|
|
||||||
|
|
||||||
skills = get_enabled_skills_for_config(app_config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to load skills for allowed-tools policy")
|
|
||||||
raise
|
|
||||||
|
|
||||||
if available_skills is None:
|
|
||||||
return skills
|
|
||||||
return [skill for skill in skills if skill.name in available_skills]
|
|
||||||
|
|
||||||
|
|
||||||
def make_lead_agent(config: RunnableConfig):
|
def make_lead_agent(config: RunnableConfig):
|
||||||
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
|
|
||||||
runtime_config = _get_runtime_config(config)
|
|
||||||
runtime_app_config = runtime_config.get("app_config")
|
|
||||||
return _make_lead_agent(config, app_config=runtime_app_config or get_app_config())
|
|
||||||
|
|
||||||
|
|
||||||
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|
||||||
# Lazy import to avoid circular dependency
|
# Lazy import to avoid circular dependency
|
||||||
from deerflow.tools import get_available_tools
|
from deerflow.tools import get_available_tools
|
||||||
from deerflow.tools.builtins import setup_agent, update_agent
|
from deerflow.tools.builtins import setup_agent
|
||||||
|
|
||||||
cfg = _get_runtime_config(config)
|
cfg = config.get("configurable", {})
|
||||||
resolved_app_config = app_config
|
|
||||||
|
|
||||||
thinking_enabled = cfg.get("thinking_enabled", True)
|
thinking_enabled = cfg.get("thinking_enabled", True)
|
||||||
reasoning_effort = cfg.get("reasoning_effort", None)
|
reasoning_effort = cfg.get("reasoning_effort", None)
|
||||||
@@ -362,17 +287,17 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
subagent_enabled = cfg.get("subagent_enabled", False)
|
subagent_enabled = cfg.get("subagent_enabled", False)
|
||||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
||||||
is_bootstrap = cfg.get("is_bootstrap", False)
|
is_bootstrap = cfg.get("is_bootstrap", False)
|
||||||
agent_name = validate_agent_name(cfg.get("agent_name"))
|
agent_name = cfg.get("agent_name")
|
||||||
|
|
||||||
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
||||||
available_skills = _available_skill_names(agent_config, is_bootstrap)
|
|
||||||
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
||||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||||
|
|
||||||
# Final model name resolution: request → agent config → global default, with fallback for unknown names
|
# Final model name resolution: request → agent config → global default, with fallback for unknown names
|
||||||
model_name = _resolve_model_name(requested_model_name or agent_model_name, app_config=resolved_app_config)
|
model_name = _resolve_model_name(requested_model_name or agent_model_name)
|
||||||
|
|
||||||
model_config = resolved_app_config.get_model_config(model_name)
|
app_config = get_app_config()
|
||||||
|
model_config = app_config.get_model_config(model_name)
|
||||||
|
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
|
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
|
||||||
@@ -403,44 +328,26 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
"reasoning_effort": reasoning_effort,
|
"reasoning_effort": reasoning_effort,
|
||||||
"is_plan_mode": is_plan_mode,
|
"is_plan_mode": is_plan_mode,
|
||||||
"subagent_enabled": subagent_enabled,
|
"subagent_enabled": subagent_enabled,
|
||||||
"tool_groups": agent_config.tool_groups if agent_config else None,
|
|
||||||
"available_skills": sorted(available_skills) if available_skills is not None else None,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
|
|
||||||
|
|
||||||
if is_bootstrap:
|
if is_bootstrap:
|
||||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||||
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||||
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
|
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
|
||||||
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
|
middleware=_build_middlewares(config, model_name=model_name),
|
||||||
system_prompt=apply_prompt_template(
|
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
|
||||||
subagent_enabled=subagent_enabled,
|
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
|
||||||
available_skills=set(["bootstrap"]),
|
|
||||||
app_config=resolved_app_config,
|
|
||||||
),
|
|
||||||
state_schema=ThreadState,
|
state_schema=ThreadState,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Custom agents can update their own SOUL.md / config via update_agent.
|
|
||||||
# The default agent (no agent_name) does not see this tool.
|
|
||||||
extra_tools = [update_agent] if agent_name else []
|
|
||||||
# Default lead agent (unchanged behavior)
|
# Default lead agent (unchanged behavior)
|
||||||
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||||
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
|
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
|
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||||
system_prompt=apply_prompt_template(
|
system_prompt=apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
|
||||||
agent_name=agent_name,
|
|
||||||
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
|
|
||||||
app_config=resolved_app_config,
|
|
||||||
),
|
),
|
||||||
state_schema=ThreadState,
|
state_schema=ThreadState,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,32 +1,26 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from deerflow.config.agents_config import load_agent_soul
|
from deerflow.config.agents_config import load_agent_soul
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills import load_skills
|
||||||
from deerflow.skills.types import Skill, SkillCategory
|
from deerflow.skills.types import Skill
|
||||||
from deerflow.subagents import get_available_subagent_names
|
from deerflow.subagents import get_available_subagent_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
||||||
_enabled_skills_lock = threading.Lock()
|
_enabled_skills_lock = threading.Lock()
|
||||||
_enabled_skills_cache: list[Skill] | None = None
|
_enabled_skills_cache: list[Skill] | None = None
|
||||||
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
|
|
||||||
_enabled_skills_refresh_active = False
|
_enabled_skills_refresh_active = False
|
||||||
_enabled_skills_refresh_version = 0
|
_enabled_skills_refresh_version = 0
|
||||||
_enabled_skills_refresh_event = threading.Event()
|
_enabled_skills_refresh_event = threading.Event()
|
||||||
|
|
||||||
|
|
||||||
def _load_enabled_skills_sync() -> list[Skill]:
|
def _load_enabled_skills_sync() -> list[Skill]:
|
||||||
return list(get_or_new_skill_storage().load_skills(enabled_only=True))
|
return list(load_skills(enabled_only=True))
|
||||||
|
|
||||||
|
|
||||||
def _start_enabled_skills_refresh_thread() -> None:
|
def _start_enabled_skills_refresh_thread() -> None:
|
||||||
@@ -84,7 +78,6 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
|
|||||||
_get_cached_skills_prompt_section.cache_clear()
|
_get_cached_skills_prompt_section.cache_clear()
|
||||||
with _enabled_skills_lock:
|
with _enabled_skills_lock:
|
||||||
_enabled_skills_cache = None
|
_enabled_skills_cache = None
|
||||||
_enabled_skills_by_config_cache.clear()
|
|
||||||
_enabled_skills_refresh_version += 1
|
_enabled_skills_refresh_version += 1
|
||||||
_enabled_skills_refresh_event.clear()
|
_enabled_skills_refresh_event.clear()
|
||||||
if _enabled_skills_refresh_active:
|
if _enabled_skills_refresh_active:
|
||||||
@@ -108,15 +101,6 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
|
|||||||
|
|
||||||
|
|
||||||
def _get_enabled_skills():
|
def _get_enabled_skills():
|
||||||
return get_cached_enabled_skills()
|
|
||||||
|
|
||||||
|
|
||||||
def get_cached_enabled_skills() -> list[Skill]:
|
|
||||||
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
|
|
||||||
|
|
||||||
Safe to call from request paths: never blocks on disk I/O. Returns an empty
|
|
||||||
list on cache miss; the next call will see the warmed result.
|
|
||||||
"""
|
|
||||||
with _enabled_skills_lock:
|
with _enabled_skills_lock:
|
||||||
cached = _enabled_skills_cache
|
cached = _enabled_skills_cache
|
||||||
|
|
||||||
@@ -127,33 +111,8 @@ def get_cached_enabled_skills() -> list[Skill]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
def _skill_mutability_label(category: str) -> str:
|
||||||
"""Return enabled skills using the caller's config source.
|
return "[custom, editable]" if category == "custom" else "[built-in]"
|
||||||
|
|
||||||
When a concrete ``app_config`` is supplied, cache the loaded skills by that
|
|
||||||
config object's identity so request-scoped config injection still resolves
|
|
||||||
skill paths from the matching config without rescanning storage on every
|
|
||||||
agent factory call.
|
|
||||||
"""
|
|
||||||
if app_config is None:
|
|
||||||
return _get_enabled_skills()
|
|
||||||
|
|
||||||
cache_key = id(app_config)
|
|
||||||
with _enabled_skills_lock:
|
|
||||||
cached = _enabled_skills_by_config_cache.get(cache_key)
|
|
||||||
if cached is not None:
|
|
||||||
cached_config, cached_skills = cached
|
|
||||||
if cached_config is app_config:
|
|
||||||
return list(cached_skills)
|
|
||||||
|
|
||||||
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
|
||||||
with _enabled_skills_lock:
|
|
||||||
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
|
|
||||||
return list(skills)
|
|
||||||
|
|
||||||
|
|
||||||
def _skill_mutability_label(category: SkillCategory | str) -> str:
|
|
||||||
return "[custom, editable]" if category == SkillCategory.CUSTOM else "[built-in]"
|
|
||||||
|
|
||||||
|
|
||||||
def clear_skills_system_prompt_cache() -> None:
|
def clear_skills_system_prompt_cache() -> None:
|
||||||
@@ -164,6 +123,31 @@ async def refresh_skills_system_prompt_cache_async() -> None:
|
|||||||
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
|
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_skills_system_prompt_cache_state() -> None:
|
||||||
|
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
|
||||||
|
|
||||||
|
_get_cached_skills_prompt_section.cache_clear()
|
||||||
|
with _enabled_skills_lock:
|
||||||
|
_enabled_skills_cache = None
|
||||||
|
_enabled_skills_refresh_active = False
|
||||||
|
_enabled_skills_refresh_version = 0
|
||||||
|
_enabled_skills_refresh_event.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_enabled_skills_cache() -> None:
|
||||||
|
"""Backward-compatible test helper for direct synchronous reload."""
|
||||||
|
try:
|
||||||
|
skills = _load_enabled_skills_sync()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to load enabled skills for prompt injection")
|
||||||
|
skills = []
|
||||||
|
|
||||||
|
with _enabled_skills_lock:
|
||||||
|
_enabled_skills_cache = skills
|
||||||
|
_enabled_skills_refresh_active = False
|
||||||
|
_enabled_skills_refresh_event.set()
|
||||||
|
|
||||||
|
|
||||||
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
|
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
|
||||||
if not skill_evolution_enabled:
|
if not skill_evolution_enabled:
|
||||||
return ""
|
return ""
|
||||||
@@ -180,37 +164,7 @@ Skip simple one-off tasks.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_available_subagents_description(available_names: list[str], bash_available: bool, *, app_config: AppConfig | None = None) -> str:
|
def _build_subagent_section(max_concurrent: int) -> str:
|
||||||
"""Dynamically build subagent type descriptions from registry.
|
|
||||||
|
|
||||||
Mirrors Codex's pattern where agent_type_description is dynamically generated
|
|
||||||
from all registered roles, so the LLM knows about every available type.
|
|
||||||
"""
|
|
||||||
# Built-in descriptions (kept for backward compatibility with existing prompt quality)
|
|
||||||
builtin_descriptions = {
|
|
||||||
"general-purpose": "For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.",
|
|
||||||
"bash": (
|
|
||||||
"For command execution (git, build, test, deploy operations)" if bash_available else "Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Lazy import moved outside loop to avoid repeated import overhead
|
|
||||||
from deerflow.subagents.registry import get_subagent_config
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for name in available_names:
|
|
||||||
if name in builtin_descriptions:
|
|
||||||
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
|
|
||||||
else:
|
|
||||||
config = get_subagent_config(name, app_config=app_config)
|
|
||||||
if config is not None:
|
|
||||||
desc = config.description.split("\n")[0].strip() # First line only for brevity
|
|
||||||
lines.append(f"- **{name}**: {desc}")
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_subagent_section(max_concurrent: int, *, app_config: AppConfig | None = None) -> str:
|
|
||||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -220,12 +174,13 @@ def _build_subagent_section(max_concurrent: int, *, app_config: AppConfig | None
|
|||||||
Formatted subagent section string.
|
Formatted subagent section string.
|
||||||
"""
|
"""
|
||||||
n = max_concurrent
|
n = max_concurrent
|
||||||
available_names = get_available_subagent_names(app_config=app_config) if app_config is not None else get_available_subagent_names()
|
bash_available = "bash" in get_available_subagent_names()
|
||||||
bash_available = "bash" in available_names
|
available_subagents = (
|
||||||
|
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
|
||||||
# Dynamically build subagent type descriptions from registry (aligned with Codex's
|
if bash_available
|
||||||
# agent_type_description pattern where all registered roles are listed in the tool spec).
|
else "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n"
|
||||||
available_subagents = _build_available_subagents_description(available_names, bash_available, app_config=app_config)
|
"- **bash**: Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
|
||||||
|
)
|
||||||
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
|
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
|
||||||
direct_execution_example = (
|
direct_execution_example = (
|
||||||
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
|
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
|
||||||
@@ -366,7 +321,8 @@ You are {agent_name}, an open-source super agent.
|
|||||||
</role>
|
</role>
|
||||||
|
|
||||||
{soul}
|
{soul}
|
||||||
{self_update_section}
|
{memory_context}
|
||||||
|
|
||||||
<thinking_style>
|
<thinking_style>
|
||||||
- Think concisely and strategically about the user's request BEFORE taking action
|
- Think concisely and strategically about the user's request BEFORE taking action
|
||||||
- Break down the task: What is clear? What is ambiguous? What is missing?
|
- Break down the task: What is clear? What is ambiguous? What is missing?
|
||||||
@@ -464,7 +420,7 @@ You: "Deploying to staging..." [proceed]
|
|||||||
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
|
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
|
||||||
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
|
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
|
||||||
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
|
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
|
||||||
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_files` tool
|
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
|
||||||
{acp_section}
|
{acp_section}
|
||||||
</working_directory>
|
</working_directory>
|
||||||
|
|
||||||
@@ -551,28 +507,21 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
|
def _get_memory_context(agent_name: str | None = None) -> str:
|
||||||
"""Get memory context for injection into system prompt.
|
"""Get memory context for injection into system prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||||
app_config: Explicit application config. When provided, memory options
|
|
||||||
are read from this value instead of the global config singleton.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted memory context string wrapped in XML tags, or empty string if disabled.
|
Formatted memory context string wrapped in XML tags, or empty string if disabled.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||||
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
if app_config is None:
|
config = get_memory_config()
|
||||||
from deerflow.config.memory_config import get_memory_config
|
|
||||||
|
|
||||||
config = get_memory_config()
|
|
||||||
else:
|
|
||||||
config = app_config.memory
|
|
||||||
|
|
||||||
if not config.enabled or not config.injection_enabled:
|
if not config.enabled or not config.injection_enabled:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@@ -586,8 +535,8 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
|||||||
{memory_content}
|
{memory_content}
|
||||||
</memory>
|
</memory>
|
||||||
"""
|
"""
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("Failed to load memory context")
|
logger.error("Failed to load memory context: %s", e)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@@ -623,24 +572,19 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
|||||||
</skill_system>"""
|
</skill_system>"""
|
||||||
|
|
||||||
|
|
||||||
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
|
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||||
"""Generate the skills prompt section with available skills list."""
|
"""Generate the skills prompt section with available skills list."""
|
||||||
skills = get_enabled_skills_for_config(app_config)
|
skills = _get_enabled_skills()
|
||||||
|
|
||||||
if app_config is None:
|
try:
|
||||||
try:
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
config = get_app_config()
|
config = get_app_config()
|
||||||
container_base_path = config.skills.container_path
|
|
||||||
skill_evolution_enabled = config.skill_evolution.enabled
|
|
||||||
except Exception:
|
|
||||||
container_base_path = "/mnt/skills"
|
|
||||||
skill_evolution_enabled = False
|
|
||||||
else:
|
|
||||||
config = app_config
|
|
||||||
container_base_path = config.skills.container_path
|
container_base_path = config.skills.container_path
|
||||||
skill_evolution_enabled = config.skill_evolution.enabled
|
skill_evolution_enabled = config.skill_evolution.enabled
|
||||||
|
except Exception:
|
||||||
|
container_base_path = "/mnt/skills"
|
||||||
|
skill_evolution_enabled = False
|
||||||
|
|
||||||
if not skills and not skill_evolution_enabled:
|
if not skills and not skill_evolution_enabled:
|
||||||
return ""
|
return ""
|
||||||
@@ -664,27 +608,7 @@ def get_agent_soul(agent_name: str | None) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _build_self_update_section(agent_name: str | None) -> str:
|
def get_deferred_tools_prompt_section() -> str:
|
||||||
"""Prompt block that teaches the custom agent to persist self-updates via update_agent."""
|
|
||||||
if not agent_name:
|
|
||||||
return ""
|
|
||||||
return f"""<self_update>
|
|
||||||
You are running as the custom agent **{agent_name}** with a persisted SOUL.md and config.yaml.
|
|
||||||
|
|
||||||
When the user asks you to update your own description, personality, behaviour, skill set, tool groups, or default model,
|
|
||||||
you MUST persist the change with the `update_agent` tool. Do NOT use `bash`, `write_file`, or any sandbox tool to edit
|
|
||||||
SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace and the changes will be lost on the next turn.
|
|
||||||
|
|
||||||
Rules:
|
|
||||||
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
|
|
||||||
- Only pass the fields that should change. Omit the others to preserve them.
|
|
||||||
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
|
|
||||||
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
|
|
||||||
</self_update>
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
|
|
||||||
"""Generate <available-deferred-tools> block for the system prompt.
|
"""Generate <available-deferred-tools> block for the system prompt.
|
||||||
|
|
||||||
Lists only deferred tool names so the agent knows what exists
|
Lists only deferred tool names so the agent knows what exists
|
||||||
@@ -693,17 +617,12 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) ->
|
|||||||
"""
|
"""
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||||
|
|
||||||
if app_config is None:
|
try:
|
||||||
try:
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
config = get_app_config()
|
if not get_app_config().tool_search.enabled:
|
||||||
except Exception:
|
|
||||||
return ""
|
return ""
|
||||||
else:
|
except Exception:
|
||||||
config = app_config
|
|
||||||
|
|
||||||
if not config.tool_search.enabled:
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
registry = get_deferred_registry()
|
registry = get_deferred_registry()
|
||||||
@@ -714,19 +633,15 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) ->
|
|||||||
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
||||||
|
|
||||||
|
|
||||||
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
def _build_acp_section() -> str:
|
||||||
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
||||||
if app_config is None:
|
try:
|
||||||
try:
|
from deerflow.config.acp_config import get_acp_agents
|
||||||
from deerflow.config.acp_config import get_acp_agents
|
|
||||||
|
|
||||||
agents = get_acp_agents()
|
agents = get_acp_agents()
|
||||||
except Exception:
|
if not agents:
|
||||||
return ""
|
return ""
|
||||||
else:
|
except Exception:
|
||||||
agents = getattr(app_config, "acp_agents", {}) or {}
|
|
||||||
|
|
||||||
if not agents:
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -734,24 +649,19 @@ def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
|||||||
"- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n"
|
"- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n"
|
||||||
"- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n"
|
"- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n"
|
||||||
"- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n"
|
"- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n"
|
||||||
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_files`"
|
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_file`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
|
def _build_custom_mounts_section() -> str:
|
||||||
"""Build a prompt section for explicitly configured sandbox mounts."""
|
"""Build a prompt section for explicitly configured sandbox mounts."""
|
||||||
if app_config is None:
|
try:
|
||||||
try:
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
config = get_app_config()
|
mounts = get_app_config().sandbox.mounts or []
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
|
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
|
||||||
return ""
|
return ""
|
||||||
else:
|
|
||||||
config = app_config
|
|
||||||
|
|
||||||
mounts = config.sandbox.mounts or []
|
|
||||||
|
|
||||||
if not mounts:
|
if not mounts:
|
||||||
return ""
|
return ""
|
||||||
@@ -765,17 +675,13 @@ def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
|
|||||||
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
|
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
|
||||||
|
|
||||||
|
|
||||||
def apply_prompt_template(
|
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
|
||||||
subagent_enabled: bool = False,
|
# Get memory context
|
||||||
max_concurrent_subagents: int = 3,
|
memory_context = _get_memory_context(agent_name)
|
||||||
*,
|
|
||||||
agent_name: str | None = None,
|
|
||||||
available_skills: set[str] | None = None,
|
|
||||||
app_config: AppConfig | None = None,
|
|
||||||
) -> str:
|
|
||||||
# Include subagent section only if enabled (from runtime parameter)
|
# Include subagent section only if enabled (from runtime parameter)
|
||||||
n = max_concurrent_subagents
|
n = max_concurrent_subagents
|
||||||
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
|
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
|
||||||
|
|
||||||
# Add subagent reminder to critical_reminders if enabled
|
# Add subagent reminder to critical_reminders if enabled
|
||||||
subagent_reminder = (
|
subagent_reminder = (
|
||||||
@@ -796,28 +702,27 @@ def apply_prompt_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get skills section
|
# Get skills section
|
||||||
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
|
skills_section = get_skills_prompt_section(available_skills)
|
||||||
|
|
||||||
# Get deferred tools section (tool_search)
|
# Get deferred tools section (tool_search)
|
||||||
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
|
deferred_tools_section = get_deferred_tools_prompt_section()
|
||||||
|
|
||||||
# Build ACP agent section only if ACP agents are configured
|
# Build ACP agent section only if ACP agents are configured
|
||||||
acp_section = _build_acp_section(app_config=app_config)
|
acp_section = _build_acp_section()
|
||||||
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
|
custom_mounts_section = _build_custom_mounts_section()
|
||||||
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||||
|
|
||||||
# Build and return the fully static system prompt.
|
# Format the prompt with dynamic skills and memory
|
||||||
# Memory and current date are injected per-turn via DynamicContextMiddleware
|
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
# as a <system-reminder> in the first HumanMessage, keeping this prompt
|
|
||||||
# identical across users and sessions for maximum prefix-cache reuse.
|
|
||||||
return SYSTEM_PROMPT_TEMPLATE.format(
|
|
||||||
agent_name=agent_name or "DeerFlow 2.0",
|
agent_name=agent_name or "DeerFlow 2.0",
|
||||||
soul=get_agent_soul(agent_name),
|
soul=get_agent_soul(agent_name),
|
||||||
self_update_section=_build_self_update_section(agent_name),
|
|
||||||
skills_section=skills_section,
|
skills_section=skills_section,
|
||||||
deferred_tools_section=deferred_tools_section,
|
deferred_tools_section=deferred_tools_section,
|
||||||
|
memory_context=memory_context,
|
||||||
subagent_section=subagent_section,
|
subagent_section=subagent_section,
|
||||||
subagent_reminder=subagent_reminder,
|
subagent_reminder=subagent_reminder,
|
||||||
subagent_thinking=subagent_thinking,
|
subagent_thinking=subagent_thinking,
|
||||||
acp_section=acp_and_mounts_section,
|
acp_section=acp_and_mounts_section,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
||||||
|
|||||||
@@ -1,109 +0,0 @@
|
|||||||
"""Shared helpers for turning conversations into memory update inputs."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from copy import copy
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
|
||||||
_CORRECTION_PATTERNS = (
|
|
||||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
|
||||||
re.compile(r"不对"),
|
|
||||||
re.compile(r"你理解错了"),
|
|
||||||
re.compile(r"你理解有误"),
|
|
||||||
re.compile(r"重试"),
|
|
||||||
re.compile(r"重新来"),
|
|
||||||
re.compile(r"换一种"),
|
|
||||||
re.compile(r"改用"),
|
|
||||||
)
|
|
||||||
_REINFORCEMENT_PATTERNS = (
|
|
||||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
|
||||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
|
||||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
|
||||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_message_text(message: Any) -> str:
|
|
||||||
"""Extract plain text from message content for filtering and signal detection."""
|
|
||||||
content = getattr(message, "content", "")
|
|
||||||
if isinstance(content, list):
|
|
||||||
text_parts: list[str] = []
|
|
||||||
for part in content:
|
|
||||||
if isinstance(part, str):
|
|
||||||
text_parts.append(part)
|
|
||||||
elif isinstance(part, dict):
|
|
||||||
text_val = part.get("text")
|
|
||||||
if isinstance(text_val, str):
|
|
||||||
text_parts.append(text_val)
|
|
||||||
return " ".join(text_parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
|
||||||
"""Keep only user inputs and final assistant responses for memory updates."""
|
|
||||||
filtered = []
|
|
||||||
skip_next_ai = False
|
|
||||||
for msg in messages:
|
|
||||||
msg_type = getattr(msg, "type", None)
|
|
||||||
|
|
||||||
if msg_type == "human":
|
|
||||||
content_str = extract_message_text(msg)
|
|
||||||
if "<uploaded_files>" in content_str:
|
|
||||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
|
||||||
if not stripped:
|
|
||||||
skip_next_ai = True
|
|
||||||
continue
|
|
||||||
clean_msg = copy(msg)
|
|
||||||
clean_msg.content = stripped
|
|
||||||
filtered.append(clean_msg)
|
|
||||||
skip_next_ai = False
|
|
||||||
else:
|
|
||||||
filtered.append(msg)
|
|
||||||
skip_next_ai = False
|
|
||||||
elif msg_type == "ai":
|
|
||||||
tool_calls = getattr(msg, "tool_calls", None)
|
|
||||||
if not tool_calls:
|
|
||||||
if skip_next_ai:
|
|
||||||
skip_next_ai = False
|
|
||||||
continue
|
|
||||||
filtered.append(msg)
|
|
||||||
|
|
||||||
return filtered
|
|
||||||
|
|
||||||
|
|
||||||
def detect_correction(messages: list[Any]) -> bool:
|
|
||||||
"""Detect explicit user corrections in recent conversation turns."""
|
|
||||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
|
||||||
|
|
||||||
for msg in recent_user_msgs:
|
|
||||||
content = extract_message_text(msg).strip()
|
|
||||||
if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
|
||||||
"""Detect explicit positive reinforcement signals in recent conversation turns."""
|
|
||||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
|
||||||
|
|
||||||
for msg in recent_user_msgs:
|
|
||||||
content = extract_message_text(msg).strip()
|
|
||||||
if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
@@ -40,15 +40,6 @@ class MemoryUpdateQueue:
|
|||||||
self._timer: threading.Timer | None = None
|
self._timer: threading.Timer | None = None
|
||||||
self._processing = False
|
self._processing = False
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _queue_key(
|
|
||||||
thread_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
agent_name: str | None,
|
|
||||||
) -> tuple[str, str | None, str | None]:
|
|
||||||
"""Return the debounce identity for a memory update target."""
|
|
||||||
return (thread_id, user_id, agent_name)
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -75,94 +66,49 @@ class MemoryUpdateQueue:
|
|||||||
return
|
return
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._enqueue_locked(
|
existing_context = next(
|
||||||
|
(context for context in self._queue if context.thread_id == thread_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||||
|
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||||
|
context = ConversationContext(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
correction_detected=correction_detected,
|
correction_detected=merged_correction_detected,
|
||||||
reinforcement_detected=reinforcement_detected,
|
reinforcement_detected=merged_reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if this thread already has a pending update
|
||||||
|
# If so, replace it with the newer one
|
||||||
|
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||||
|
self._queue.append(context)
|
||||||
|
|
||||||
|
# Reset or start the debounce timer
|
||||||
self._reset_timer()
|
self._reset_timer()
|
||||||
|
|
||||||
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
|
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
|
||||||
|
|
||||||
def add_nowait(
|
|
||||||
self,
|
|
||||||
thread_id: str,
|
|
||||||
messages: list[Any],
|
|
||||||
agent_name: str | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
correction_detected: bool = False,
|
|
||||||
reinforcement_detected: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Add a conversation and start processing immediately in the background."""
|
|
||||||
config = get_memory_config()
|
|
||||||
if not config.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
self._enqueue_locked(
|
|
||||||
thread_id=thread_id,
|
|
||||||
messages=messages,
|
|
||||||
agent_name=agent_name,
|
|
||||||
user_id=user_id,
|
|
||||||
correction_detected=correction_detected,
|
|
||||||
reinforcement_detected=reinforcement_detected,
|
|
||||||
)
|
|
||||||
self._schedule_timer(0)
|
|
||||||
|
|
||||||
logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
|
|
||||||
|
|
||||||
def _enqueue_locked(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
thread_id: str,
|
|
||||||
messages: list[Any],
|
|
||||||
agent_name: str | None,
|
|
||||||
user_id: str | None,
|
|
||||||
correction_detected: bool,
|
|
||||||
reinforcement_detected: bool,
|
|
||||||
) -> None:
|
|
||||||
queue_key = self._queue_key(thread_id, user_id, agent_name)
|
|
||||||
existing_context = next(
|
|
||||||
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
|
||||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
|
||||||
context = ConversationContext(
|
|
||||||
thread_id=thread_id,
|
|
||||||
messages=messages,
|
|
||||||
agent_name=agent_name,
|
|
||||||
user_id=user_id,
|
|
||||||
correction_detected=merged_correction_detected,
|
|
||||||
reinforcement_detected=merged_reinforcement_detected,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
|
|
||||||
self._queue.append(context)
|
|
||||||
|
|
||||||
def _reset_timer(self) -> None:
|
def _reset_timer(self) -> None:
|
||||||
"""Reset the debounce timer."""
|
"""Reset the debounce timer."""
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
self._schedule_timer(config.debounce_seconds)
|
|
||||||
|
|
||||||
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
|
||||||
|
|
||||||
def _schedule_timer(self, delay_seconds: float) -> None:
|
|
||||||
"""Schedule queue processing after the provided delay."""
|
|
||||||
# Cancel existing timer if any
|
# Cancel existing timer if any
|
||||||
if self._timer is not None:
|
if self._timer is not None:
|
||||||
self._timer.cancel()
|
self._timer.cancel()
|
||||||
|
|
||||||
|
# Start new timer
|
||||||
self._timer = threading.Timer(
|
self._timer = threading.Timer(
|
||||||
delay_seconds,
|
config.debounce_seconds,
|
||||||
self._process_queue,
|
self._process_queue,
|
||||||
)
|
)
|
||||||
self._timer.daemon = True
|
self._timer.daemon = True
|
||||||
self._timer.start()
|
self._timer.start()
|
||||||
|
|
||||||
|
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
||||||
|
|
||||||
def _process_queue(self) -> None:
|
def _process_queue(self) -> None:
|
||||||
"""Process all queued conversation contexts."""
|
"""Process all queued conversation contexts."""
|
||||||
# Import here to avoid circular dependency
|
# Import here to avoid circular dependency
|
||||||
@@ -170,8 +116,8 @@ class MemoryUpdateQueue:
|
|||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._processing:
|
if self._processing:
|
||||||
# Preserve immediate flush semantics even if another worker is active.
|
# Already processing, reschedule
|
||||||
self._schedule_timer(0)
|
self._reset_timer()
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self._queue:
|
if not self._queue:
|
||||||
@@ -225,13 +171,6 @@ class MemoryUpdateQueue:
|
|||||||
|
|
||||||
self._process_queue()
|
self._process_queue()
|
||||||
|
|
||||||
def flush_nowait(self) -> None:
|
|
||||||
"""Start queue processing immediately in a background thread."""
|
|
||||||
with self._lock:
|
|
||||||
# Daemon thread: queued messages may be lost if the process exits
|
|
||||||
# before _process_queue completes. Acceptable for best-effort memory updates.
|
|
||||||
self._schedule_timer(0)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear the queue without processing.
|
"""Clear the queue without processing.
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import abc
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -67,8 +66,6 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
||||||
# Value: (memory_data, file_mtime)
|
# Value: (memory_data, file_mtime)
|
||||||
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
||||||
# Guards all reads and writes to _memory_cache across concurrent callers.
|
|
||||||
self._cache_lock = threading.Lock()
|
|
||||||
|
|
||||||
def _validate_agent_name(self, agent_name: str) -> None:
|
def _validate_agent_name(self, agent_name: str) -> None:
|
||||||
"""Validate that the agent name is safe to use in filesystem paths.
|
"""Validate that the agent name is safe to use in filesystem paths.
|
||||||
@@ -116,60 +113,48 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
logger.warning("Failed to load memory file: %s", e)
|
logger.warning("Failed to load memory file: %s", e)
|
||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]:
|
|
||||||
return (user_id, agent_name)
|
|
||||||
|
|
||||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data (cached with file modification time check)."""
|
"""Load memory data (cached with file modification time check)."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
cache_key = self._cache_key(agent_name, user_id=user_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||||
except OSError:
|
except OSError:
|
||||||
current_mtime = None
|
current_mtime = None
|
||||||
|
|
||||||
with self._cache_lock:
|
cache_key = (user_id, agent_name)
|
||||||
cached = self._memory_cache.get(cache_key)
|
cached = self._memory_cache.get(cache_key)
|
||||||
if cached is not None and cached[1] == current_mtime:
|
|
||||||
return cached[0]
|
|
||||||
|
|
||||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
if cached is None or cached[1] != current_mtime:
|
||||||
|
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||||
with self._cache_lock:
|
|
||||||
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
||||||
|
return memory_data
|
||||||
|
|
||||||
return memory_data
|
return cached[0]
|
||||||
|
|
||||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Reload memory data from file, forcing cache invalidation."""
|
"""Reload memory data from file, forcing cache invalidation."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||||
cache_key = self._cache_key(agent_name, user_id=user_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||||
except OSError:
|
except OSError:
|
||||||
mtime = None
|
mtime = None
|
||||||
|
|
||||||
with self._cache_lock:
|
cache_key = (user_id, agent_name)
|
||||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||||
return memory_data
|
return memory_data
|
||||||
|
|
||||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||||
"""Save memory data to file and update cache."""
|
"""Save memory data to file and update cache."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
cache_key = self._cache_key(agent_name, user_id=user_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# Shallow-copy before adding lastUpdated so the caller's dict is not
|
memory_data["lastUpdated"] = utc_now_iso_z()
|
||||||
# mutated as a side-effect, and the cache reference is not silently
|
|
||||||
# updated before the file write succeeds.
|
|
||||||
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
|
|
||||||
|
|
||||||
temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
|
temp_path = file_path.with_suffix(".tmp")
|
||||||
with open(temp_path, "w", encoding="utf-8") as f:
|
with open(temp_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(memory_data, f, indent=2, ensure_ascii=False)
|
json.dump(memory_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
@@ -180,8 +165,8 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
except OSError:
|
except OSError:
|
||||||
mtime = None
|
mtime = None
|
||||||
|
|
||||||
with self._cache_lock:
|
cache_key = (user_id, agent_name)
|
||||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||||
logger.info("Memory saved to %s", file_path)
|
logger.info("Memory saved to %s", file_path)
|
||||||
return True
|
return True
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
"""Hooks fired before summarization removes messages from state."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
|
||||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
|
||||||
from deerflow.config.memory_config import get_memory_config
|
|
||||||
from deerflow.runtime.user_context import resolve_runtime_user_id
|
|
||||||
|
|
||||||
|
|
||||||
def memory_flush_hook(event: SummarizationEvent) -> None:
|
|
||||||
"""Flush messages about to be summarized into the memory queue."""
|
|
||||||
if not get_memory_config().enabled or not event.thread_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
|
|
||||||
user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
|
|
||||||
assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
|
|
||||||
if not user_messages or not assistant_messages:
|
|
||||||
return
|
|
||||||
|
|
||||||
correction_detected = detect_correction(filtered_messages)
|
|
||||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
|
||||||
user_id = resolve_runtime_user_id(event.runtime)
|
|
||||||
queue = get_memory_queue()
|
|
||||||
queue.add_nowait(
|
|
||||||
thread_id=event.thread_id,
|
|
||||||
messages=filtered_messages,
|
|
||||||
agent_name=event.agent_name,
|
|
||||||
user_id=user_id,
|
|
||||||
correction_detected=correction_detected,
|
|
||||||
reinforcement_detected=reinforcement_detected,
|
|
||||||
)
|
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
"""Memory updater for reading, writing, and updating memory data."""
|
"""Memory updater for reading, writing, and updating memory data."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import atexit
|
|
||||||
import concurrent.futures
|
|
||||||
import copy
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -26,18 +22,6 @@ from deerflow.models import create_chat_model
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Thread pool for offloading sync memory updates when called from an async
|
|
||||||
# context. Unlike the previous asyncio.run() approach, this runs *sync*
|
|
||||||
# model.invoke() calls — no event loop is created, so the langchain async
|
|
||||||
# httpx client pool (globally cached via @lru_cache) is never touched and
|
|
||||||
# cross-loop connection reuse is impossible.
|
|
||||||
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
|
|
||||||
max_workers=4,
|
|
||||||
thread_name_prefix="memory-updater-sync",
|
|
||||||
)
|
|
||||||
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
|
|
||||||
|
|
||||||
|
|
||||||
def _create_empty_memory() -> dict[str, Any]:
|
def _create_empty_memory() -> dict[str, Any]:
|
||||||
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
|
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
|
||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
@@ -290,154 +274,6 @@ class MemoryUpdater:
|
|||||||
model_name = self._model_name or config.model_name
|
model_name = self._model_name or config.model_name
|
||||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||||
|
|
||||||
def _build_correction_hint(
|
|
||||||
self,
|
|
||||||
correction_detected: bool,
|
|
||||||
reinforcement_detected: bool,
|
|
||||||
) -> str:
|
|
||||||
"""Build optional prompt hints for correction and reinforcement signals."""
|
|
||||||
correction_hint = ""
|
|
||||||
if correction_detected:
|
|
||||||
correction_hint = (
|
|
||||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
|
||||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
|
||||||
"and record the correct approach as a fact with category "
|
|
||||||
'"correction" and confidence >= 0.95 when appropriate.'
|
|
||||||
)
|
|
||||||
if reinforcement_detected:
|
|
||||||
reinforcement_hint = (
|
|
||||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
|
||||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
|
||||||
"Record the confirmed approach, style, or preference as a fact with category "
|
|
||||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
|
||||||
)
|
|
||||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
|
||||||
|
|
||||||
return correction_hint
|
|
||||||
|
|
||||||
def _prepare_update_prompt(
|
|
||||||
self,
|
|
||||||
messages: list[Any],
|
|
||||||
agent_name: str | None,
|
|
||||||
correction_detected: bool,
|
|
||||||
reinforcement_detected: bool,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> tuple[dict[str, Any], str] | None:
|
|
||||||
"""Load memory and build the update prompt for a conversation."""
|
|
||||||
config = get_memory_config()
|
|
||||||
if not config.enabled or not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
current_memory = get_memory_data(agent_name, user_id=user_id)
|
|
||||||
conversation_text = format_conversation_for_update(messages)
|
|
||||||
if not conversation_text.strip():
|
|
||||||
return None
|
|
||||||
|
|
||||||
correction_hint = self._build_correction_hint(
|
|
||||||
correction_detected=correction_detected,
|
|
||||||
reinforcement_detected=reinforcement_detected,
|
|
||||||
)
|
|
||||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
|
||||||
current_memory=json.dumps(current_memory, indent=2, ensure_ascii=False),
|
|
||||||
conversation=conversation_text,
|
|
||||||
correction_hint=correction_hint,
|
|
||||||
)
|
|
||||||
return current_memory, prompt
|
|
||||||
|
|
||||||
def _finalize_update(
|
|
||||||
self,
|
|
||||||
current_memory: dict[str, Any],
|
|
||||||
response_content: Any,
|
|
||||||
thread_id: str | None,
|
|
||||||
agent_name: str | None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Parse the model response, apply updates, and persist memory."""
|
|
||||||
response_text = _extract_text(response_content).strip()
|
|
||||||
|
|
||||||
if response_text.startswith("```"):
|
|
||||||
lines = response_text.split("\n")
|
|
||||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
|
||||||
|
|
||||||
update_data = json.loads(response_text)
|
|
||||||
# Deep-copy before in-place mutation so a subsequent save() failure
|
|
||||||
# cannot corrupt the still-cached original object reference.
|
|
||||||
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
|
||||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
|
||||||
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
|
||||||
|
|
||||||
async def aupdate_memory(
|
|
||||||
self,
|
|
||||||
messages: list[Any],
|
|
||||||
thread_id: str | None = None,
|
|
||||||
agent_name: str | None = None,
|
|
||||||
correction_detected: bool = False,
|
|
||||||
reinforcement_detected: bool = False,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Update memory asynchronously by delegating to the sync path.
|
|
||||||
|
|
||||||
Uses ``asyncio.to_thread`` to run the *sync* ``model.invoke()`` path
|
|
||||||
in a worker thread so no second event loop is created and the
|
|
||||||
langchain async httpx client pool (shared with the lead agent) is
|
|
||||||
never touched. This eliminates the cross-loop connection-reuse bug
|
|
||||||
described in issue #2615.
|
|
||||||
"""
|
|
||||||
return await asyncio.to_thread(
|
|
||||||
self._do_update_memory_sync,
|
|
||||||
messages=messages,
|
|
||||||
thread_id=thread_id,
|
|
||||||
agent_name=agent_name,
|
|
||||||
correction_detected=correction_detected,
|
|
||||||
reinforcement_detected=reinforcement_detected,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _do_update_memory_sync(
|
|
||||||
self,
|
|
||||||
messages: list[Any],
|
|
||||||
thread_id: str | None = None,
|
|
||||||
agent_name: str | None = None,
|
|
||||||
correction_detected: bool = False,
|
|
||||||
reinforcement_detected: bool = False,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""Pure-sync memory update using ``model.invoke()``.
|
|
||||||
|
|
||||||
Uses the *sync* LLM call path so no event loop is created. This
|
|
||||||
guarantees that the langchain provider's globally cached async
|
|
||||||
httpx ``AsyncClient`` / connection pool (the one shared with the
|
|
||||||
lead agent) is never touched — no cross-loop connection reuse is
|
|
||||||
possible.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prepared = self._prepare_update_prompt(
|
|
||||||
messages=messages,
|
|
||||||
agent_name=agent_name,
|
|
||||||
correction_detected=correction_detected,
|
|
||||||
reinforcement_detected=reinforcement_detected,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if prepared is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_memory, prompt = prepared
|
|
||||||
model = self._get_model()
|
|
||||||
response = model.invoke(prompt, config={"run_name": "memory_agent"})
|
|
||||||
return self._finalize_update(
|
|
||||||
current_memory=current_memory,
|
|
||||||
response_content=response.content,
|
|
||||||
thread_id=thread_id,
|
|
||||||
agent_name=agent_name,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Memory update failed: %s", e)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def update_memory(
|
def update_memory(
|
||||||
self,
|
self,
|
||||||
messages: list[Any],
|
messages: list[Any],
|
||||||
@@ -447,16 +283,7 @@ class MemoryUpdater:
|
|||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Synchronously update memory using the sync LLM path.
|
"""Update memory based on conversation messages.
|
||||||
|
|
||||||
Uses ``model.invoke()`` (sync HTTP) which operates on a completely
|
|
||||||
separate connection pool from the async ``AsyncClient`` shared by
|
|
||||||
the lead agent. This eliminates the cross-loop connection-reuse
|
|
||||||
bug described in issue #2615.
|
|
||||||
|
|
||||||
When called from within a running event loop (e.g. from a LangGraph
|
|
||||||
node), the blocking sync call is offloaded to a thread pool so the
|
|
||||||
caller's loop is not blocked.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of conversation messages.
|
messages: List of conversation messages.
|
||||||
@@ -469,35 +296,78 @@ class MemoryUpdater:
|
|||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
config = get_memory_config()
|
||||||
loop = asyncio.get_running_loop()
|
if not config.enabled:
|
||||||
except RuntimeError:
|
return False
|
||||||
loop = None
|
|
||||||
|
|
||||||
if loop is not None and loop.is_running():
|
if not messages:
|
||||||
try:
|
return False
|
||||||
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(
|
|
||||||
self._do_update_memory_sync,
|
try:
|
||||||
messages=messages,
|
# Get current memory
|
||||||
thread_id=thread_id,
|
current_memory = get_memory_data(agent_name, user_id=user_id)
|
||||||
agent_name=agent_name,
|
|
||||||
correction_detected=correction_detected,
|
# Format conversation for prompt
|
||||||
reinforcement_detected=reinforcement_detected,
|
conversation_text = format_conversation_for_update(messages)
|
||||||
user_id=user_id,
|
|
||||||
)
|
if not conversation_text.strip():
|
||||||
return future.result()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to offload memory update to executor")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return self._do_update_memory_sync(
|
# Build prompt
|
||||||
messages=messages,
|
correction_hint = ""
|
||||||
thread_id=thread_id,
|
if correction_detected:
|
||||||
agent_name=agent_name,
|
correction_hint = (
|
||||||
correction_detected=correction_detected,
|
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||||
reinforcement_detected=reinforcement_detected,
|
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||||
user_id=user_id,
|
"and record the correct approach as a fact with category "
|
||||||
)
|
'"correction" and confidence >= 0.95 when appropriate.'
|
||||||
|
)
|
||||||
|
if reinforcement_detected:
|
||||||
|
reinforcement_hint = (
|
||||||
|
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||||
|
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||||
|
"Record the confirmed approach, style, or preference as a fact with category "
|
||||||
|
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||||
|
)
|
||||||
|
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||||
|
|
||||||
|
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||||
|
current_memory=json.dumps(current_memory, indent=2),
|
||||||
|
conversation=conversation_text,
|
||||||
|
correction_hint=correction_hint,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call LLM
|
||||||
|
model = self._get_model()
|
||||||
|
response = model.invoke(prompt)
|
||||||
|
response_text = _extract_text(response.content).strip()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
# Remove markdown code blocks if present
|
||||||
|
if response_text.startswith("```"):
|
||||||
|
lines = response_text.split("\n")
|
||||||
|
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||||
|
|
||||||
|
update_data = json.loads(response_text)
|
||||||
|
|
||||||
|
# Apply updates
|
||||||
|
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||||
|
|
||||||
|
# Strip file-upload mentions from all summaries before saving.
|
||||||
|
# Uploaded files are session-scoped and won't exist in future sessions,
|
||||||
|
# so recording upload events in long-term memory causes the agent to
|
||||||
|
# try (and fail) to locate those files in subsequent conversations.
|
||||||
|
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||||
|
|
||||||
|
# Save
|
||||||
|
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Memory update failed: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
def _apply_updates(
|
def _apply_updates(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from hashlib import sha256
|
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
@@ -37,13 +36,6 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
|||||||
|
|
||||||
state_schema = ClarificationMiddlewareState
|
state_schema = ClarificationMiddlewareState
|
||||||
|
|
||||||
def _stable_message_id(self, tool_call_id: str, formatted_message: str) -> str:
|
|
||||||
"""Build a deterministic message ID so retried clarification calls replace, not append."""
|
|
||||||
if tool_call_id:
|
|
||||||
return f"clarification:{tool_call_id}"
|
|
||||||
digest = sha256(formatted_message.encode("utf-8")).hexdigest()[:16]
|
|
||||||
return f"clarification:{digest}"
|
|
||||||
|
|
||||||
def _is_chinese(self, text: str) -> bool:
|
def _is_chinese(self, text: str) -> bool:
|
||||||
"""Check if text contains Chinese characters.
|
"""Check if text contains Chinese characters.
|
||||||
|
|
||||||
@@ -139,7 +131,6 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
|||||||
# Create a ToolMessage with the formatted question
|
# Create a ToolMessage with the formatted question
|
||||||
# This will be added to the message history
|
# This will be added to the message history
|
||||||
tool_message = ToolMessage(
|
tool_message = ToolMessage(
|
||||||
id=self._stable_message_id(tool_call_id, formatted_message),
|
|
||||||
content=formatted_message,
|
content=formatted_message,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
name="ask_clarification",
|
name="ask_clarification",
|
||||||
|
|||||||
+25
-100
@@ -13,7 +13,6 @@ at the correct positions (immediately after each dangling AIMessage), not append
|
|||||||
to the end of the message list as before_model + add_messages reducer would do.
|
to the end of the message list as before_model + add_messages reducer would do.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -34,132 +33,58 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
offending AIMessage so the LLM receives a well-formed conversation.
|
offending AIMessage so the LLM receives a well-formed conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _message_tool_calls(msg) -> list[dict]:
|
|
||||||
"""Return normalized tool calls from structured fields or raw provider payloads.
|
|
||||||
|
|
||||||
LangChain stores malformed provider function calls in ``invalid_tool_calls``.
|
|
||||||
They do not execute, but provider adapters may still serialize enough of
|
|
||||||
the call id/name back into the next request that strict OpenAI-compatible
|
|
||||||
validators expect a matching ToolMessage. Treat them as dangling calls so
|
|
||||||
the next model request stays well-formed and the model sees a recoverable
|
|
||||||
tool error instead of another provider 400.
|
|
||||||
"""
|
|
||||||
normalized: list[dict] = []
|
|
||||||
|
|
||||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
|
||||||
normalized.extend(list(tool_calls))
|
|
||||||
|
|
||||||
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
|
|
||||||
if not tool_calls:
|
|
||||||
for raw_tc in raw_tool_calls:
|
|
||||||
if not isinstance(raw_tc, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
function = raw_tc.get("function")
|
|
||||||
name = raw_tc.get("name")
|
|
||||||
if not name and isinstance(function, dict):
|
|
||||||
name = function.get("name")
|
|
||||||
|
|
||||||
args = raw_tc.get("args", {})
|
|
||||||
if not args and isinstance(function, dict):
|
|
||||||
raw_args = function.get("arguments")
|
|
||||||
if isinstance(raw_args, str):
|
|
||||||
try:
|
|
||||||
parsed_args = json.loads(raw_args)
|
|
||||||
except (TypeError, ValueError, json.JSONDecodeError):
|
|
||||||
parsed_args = {}
|
|
||||||
args = parsed_args if isinstance(parsed_args, dict) else {}
|
|
||||||
|
|
||||||
normalized.append(
|
|
||||||
{
|
|
||||||
"id": raw_tc.get("id"),
|
|
||||||
"name": name or "unknown",
|
|
||||||
"args": args if isinstance(args, dict) else {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
for invalid_tc in getattr(msg, "invalid_tool_calls", None) or []:
|
|
||||||
if not isinstance(invalid_tc, dict):
|
|
||||||
continue
|
|
||||||
normalized.append(
|
|
||||||
{
|
|
||||||
"id": invalid_tc.get("id"),
|
|
||||||
"name": invalid_tc.get("name") or "unknown",
|
|
||||||
"args": {},
|
|
||||||
"invalid": True,
|
|
||||||
"error": invalid_tc.get("error"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _synthetic_tool_message_content(tool_call: dict) -> str:
|
|
||||||
if tool_call.get("invalid"):
|
|
||||||
error = tool_call.get("error")
|
|
||||||
if isinstance(error, str) and error:
|
|
||||||
return f"[Tool call could not be executed because its arguments were invalid: {error}]"
|
|
||||||
return "[Tool call could not be executed because its arguments were invalid.]"
|
|
||||||
return "[Tool call was interrupted and did not return a result.]"
|
|
||||||
|
|
||||||
def _build_patched_messages(self, messages: list) -> list | None:
|
def _build_patched_messages(self, messages: list) -> list | None:
|
||||||
"""Return messages with tool results grouped after their tool-call AIMessage.
|
"""Return a new message list with patches inserted at the correct positions.
|
||||||
|
|
||||||
This normalizes model-bound causal order before provider serialization while
|
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||||
preserving already-valid transcripts unchanged.
|
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||||
|
Returns None if no patches are needed.
|
||||||
"""
|
"""
|
||||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
# Collect IDs of all existing ToolMessages
|
||||||
|
existing_tool_msg_ids: set[str] = set()
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||||
|
|
||||||
tool_call_ids: set[str] = set()
|
# Check if any patching is needed
|
||||||
|
needs_patch = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
for tc in self._message_tool_calls(msg):
|
for tc in getattr(msg, "tool_calls", None) or []:
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if tc_id:
|
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||||
tool_call_ids.add(tc_id)
|
needs_patch = True
|
||||||
|
break
|
||||||
|
if needs_patch:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not needs_patch:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build new list with patches inserted right after each dangling AIMessage
|
||||||
patched: list = []
|
patched: list = []
|
||||||
consumed_tool_msg_ids: set[str] = set()
|
patched_ids: set[str] = set()
|
||||||
patch_count = 0
|
patch_count = 0
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
patched.append(msg)
|
patched.append(msg)
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
|
for tc in getattr(msg, "tool_calls", None) or []:
|
||||||
for tc in self._message_tool_calls(msg):
|
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||||
continue
|
|
||||||
|
|
||||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
|
||||||
if existing_tool_msg is not None:
|
|
||||||
patched.append(existing_tool_msg)
|
|
||||||
consumed_tool_msg_ids.add(tc_id)
|
|
||||||
else:
|
|
||||||
patched.append(
|
patched.append(
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=self._synthetic_tool_message_content(tc),
|
content="[Tool call was interrupted and did not return a result.]",
|
||||||
tool_call_id=tc_id,
|
tool_call_id=tc_id,
|
||||||
name=tc.get("name", "unknown"),
|
name=tc.get("name", "unknown"),
|
||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
consumed_tool_msg_ids.add(tc_id)
|
patched_ids.add(tc_id)
|
||||||
patch_count += 1
|
patch_count += 1
|
||||||
|
|
||||||
if patched == messages:
|
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||||
return None
|
|
||||||
|
|
||||||
if patch_count:
|
|
||||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
|
||||||
return patched
|
return patched
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
+1
-48
@@ -16,9 +16,6 @@ from typing import override
|
|||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||||
from langchain_core.messages import ToolMessage
|
|
||||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
|
||||||
from langgraph.types import Command
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +35,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if not registry:
|
if not registry:
|
||||||
return request
|
return request
|
||||||
|
|
||||||
deferred_names = registry.deferred_names
|
deferred_names = {e.name for e in registry.entries}
|
||||||
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
||||||
|
|
||||||
if len(active_tools) < len(request.tools):
|
if len(active_tools) < len(request.tools):
|
||||||
@@ -46,28 +43,6 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
return request.override(tools=active_tools)
|
return request.override(tools=active_tools)
|
||||||
|
|
||||||
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
|
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
|
||||||
|
|
||||||
registry = get_deferred_registry()
|
|
||||||
if not registry:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tool_name = str(request.tool_call.get("name") or "")
|
|
||||||
if not tool_name:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not registry.contains(tool_name):
|
|
||||||
return None
|
|
||||||
|
|
||||||
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
|
|
||||||
return ToolMessage(
|
|
||||||
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def wrap_model_call(
|
def wrap_model_call(
|
||||||
self,
|
self,
|
||||||
@@ -76,17 +51,6 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
return handler(self._filter_tools(request))
|
return handler(self._filter_tools(request))
|
||||||
|
|
||||||
@override
|
|
||||||
def wrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
blocked = self._blocked_tool_message(request)
|
|
||||||
if blocked is not None:
|
|
||||||
return blocked
|
|
||||||
return handler(request)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
self,
|
self,
|
||||||
@@ -94,14 +58,3 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
return await handler(self._filter_tools(request))
|
return await handler(self._filter_tools(request))
|
||||||
|
|
||||||
@override
|
|
||||||
async def awrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
blocked = self._blocked_tool_message(request)
|
|
||||||
if blocked is not None:
|
|
||||||
return blocked
|
|
||||||
return await handler(request)
|
|
||||||
|
|||||||
@@ -1,204 +0,0 @@
|
|||||||
"""Middleware to inject dynamic context (memory, current date) as a system-reminder.
|
|
||||||
|
|
||||||
The system prompt is kept fully static for maximum prefix-cache reuse across users
|
|
||||||
and sessions. The current date is always injected. Per-user memory is also injected
|
|
||||||
when ``memory.injection_enabled`` is True in the app config. Both are delivered once
|
|
||||||
per conversation as a dedicated <system-reminder> HumanMessage inserted before the
|
|
||||||
first user message (frozen-snapshot pattern).
|
|
||||||
|
|
||||||
When a conversation spans midnight the middleware detects the date change and injects
|
|
||||||
a lightweight date-update reminder as a separate HumanMessage before the current turn.
|
|
||||||
This correction is persisted so subsequent turns on the new day see a consistent history
|
|
||||||
and do not re-inject.
|
|
||||||
|
|
||||||
Reminder format:
|
|
||||||
|
|
||||||
<system-reminder>
|
|
||||||
<memory>...</memory>
|
|
||||||
|
|
||||||
<current_date>2026-05-08, Friday</current_date>
|
|
||||||
</system-reminder>
|
|
||||||
|
|
||||||
Date-update format:
|
|
||||||
|
|
||||||
<system-reminder>
|
|
||||||
<current_date>2026-05-09, Saturday</current_date>
|
|
||||||
</system-reminder>
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import TYPE_CHECKING, override
|
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
|
||||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
|
||||||
_SUMMARY_MESSAGE_NAME = "summary"
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_date(content: str) -> str | None:
|
|
||||||
"""Return the first <current_date> value found in *content*, or None."""
|
|
||||||
m = _DATE_RE.search(content)
|
|
||||||
return m.group(1) if m else None
|
|
||||||
|
|
||||||
|
|
||||||
def is_dynamic_context_reminder(message: object) -> bool:
|
|
||||||
"""Return whether *message* is a hidden dynamic-context reminder."""
|
|
||||||
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
|
|
||||||
|
|
||||||
|
|
||||||
def _last_injected_date(messages: list) -> str | None:
|
|
||||||
"""Scan messages in reverse and return the most recently injected date.
|
|
||||||
|
|
||||||
Detection uses the ``dynamic_context_reminder`` additional_kwargs flag rather
|
|
||||||
than content substring matching, so user messages containing ``<system-reminder>``
|
|
||||||
are not mistakenly treated as injected reminders.
|
|
||||||
"""
|
|
||||||
for msg in reversed(messages):
|
|
||||||
if is_dynamic_context_reminder(msg):
|
|
||||||
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
|
|
||||||
return _extract_date(content_str)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_user_injection_target(message: object) -> bool:
|
|
||||||
"""Return whether *message* can receive a dynamic-context reminder."""
|
|
||||||
return isinstance(message, HumanMessage) and not is_dynamic_context_reminder(message) and message.name != _SUMMARY_MESSAGE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicContextMiddleware(AgentMiddleware):
|
|
||||||
"""Inject memory and current date into HumanMessages as a <system-reminder>.
|
|
||||||
|
|
||||||
First turn
|
|
||||||
----------
|
|
||||||
Prepends a full system-reminder (memory + date) to the first HumanMessage and
|
|
||||||
persists it (same message ID). The first message is then frozen for the whole
|
|
||||||
session — its content never changes again, so the prefix cache can hit on every
|
|
||||||
subsequent turn.
|
|
||||||
|
|
||||||
Midnight crossing
|
|
||||||
-----------------
|
|
||||||
If the conversation spans midnight, the current date differs from the date that
|
|
||||||
was injected earlier. In that case a lightweight date-update reminder is prepended
|
|
||||||
to the **current** (last) HumanMessage and persisted. Subsequent turns on the new
|
|
||||||
day see the corrected date in history and skip re-injection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent_name: str | None = None, *, app_config: AppConfig | None = None):
|
|
||||||
super().__init__()
|
|
||||||
self._agent_name = agent_name
|
|
||||||
self._app_config = app_config
|
|
||||||
|
|
||||||
def _build_full_reminder(self) -> str:
|
|
||||||
from deerflow.agents.lead_agent.prompt import _get_memory_context
|
|
||||||
|
|
||||||
# Memory injection is gated by injection_enabled; date is always included.
|
|
||||||
injection_enabled = self._app_config.memory.injection_enabled if self._app_config else True
|
|
||||||
memory_context = _get_memory_context(self._agent_name, app_config=self._app_config) if injection_enabled else ""
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
|
||||||
|
|
||||||
lines: list[str] = ["<system-reminder>"]
|
|
||||||
if memory_context:
|
|
||||||
lines.append(memory_context.strip())
|
|
||||||
lines.append("") # blank line separating memory from date
|
|
||||||
lines.append(f"<current_date>{current_date}</current_date>")
|
|
||||||
lines.append("</system-reminder>")
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
def _build_date_update_reminder(self) -> str:
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
|
||||||
return "\n".join(
|
|
||||||
[
|
|
||||||
"<system-reminder>",
|
|
||||||
f"<current_date>{current_date}</current_date>",
|
|
||||||
"</system-reminder>",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _make_reminder_and_user_messages(original: HumanMessage, reminder_content: str) -> tuple[HumanMessage, HumanMessage]:
|
|
||||||
"""Return (reminder_msg, user_msg) using the ID-swap technique.
|
|
||||||
|
|
||||||
reminder_msg takes the original message's ID so that add_messages replaces it
|
|
||||||
in-place (preserving position). user_msg carries the original content with a
|
|
||||||
derived ``{id}__user`` ID and is appended immediately after by add_messages.
|
|
||||||
|
|
||||||
If the original message has no ID a stable UUID is generated so the derived
|
|
||||||
``{id}__user`` ID never collapses to the ambiguous ``None__user`` string.
|
|
||||||
"""
|
|
||||||
stable_id = original.id or str(uuid.uuid4())
|
|
||||||
reminder_msg = HumanMessage(
|
|
||||||
content=reminder_content,
|
|
||||||
id=stable_id,
|
|
||||||
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
|
||||||
)
|
|
||||||
user_msg = HumanMessage(
|
|
||||||
content=original.content,
|
|
||||||
id=f"{stable_id}__user",
|
|
||||||
name=original.name,
|
|
||||||
additional_kwargs=original.additional_kwargs,
|
|
||||||
)
|
|
||||||
return reminder_msg, user_msg
|
|
||||||
|
|
||||||
def _inject(self, state) -> dict | None:
|
|
||||||
messages = list(state.get("messages", []))
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
|
||||||
last_date = _last_injected_date(messages)
|
|
||||||
logger.debug(
|
|
||||||
"DynamicContextMiddleware._inject: msg_count=%d last_date=%r current_date=%r",
|
|
||||||
len(messages),
|
|
||||||
last_date,
|
|
||||||
current_date,
|
|
||||||
)
|
|
||||||
|
|
||||||
if last_date is None:
|
|
||||||
# ── First turn: inject full reminder as a separate HumanMessage ─────
|
|
||||||
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
|
|
||||||
if first_idx is None:
|
|
||||||
return None
|
|
||||||
full_reminder = self._build_full_reminder()
|
|
||||||
logger.info(
|
|
||||||
"DynamicContextMiddleware: injecting full reminder (len=%d, has_memory=%s) into first HumanMessage id=%r",
|
|
||||||
len(full_reminder),
|
|
||||||
"<memory>" in full_reminder,
|
|
||||||
messages[first_idx].id,
|
|
||||||
)
|
|
||||||
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[first_idx], full_reminder)
|
|
||||||
return {"messages": [reminder_msg, user_msg]}
|
|
||||||
|
|
||||||
if last_date == current_date:
|
|
||||||
# ── Same day: nothing to do ──────────────────────────────────────────
|
|
||||||
return None
|
|
||||||
|
|
||||||
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
|
|
||||||
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
|
|
||||||
if last_human_idx is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[last_human_idx], self._build_date_update_reminder())
|
|
||||||
logger.info("DynamicContextMiddleware: midnight crossing detected — injected date update before current turn")
|
|
||||||
return {"messages": [reminder_msg, user_msg]}
|
|
||||||
|
|
||||||
@override
|
|
||||||
def before_agent(self, state, runtime: Runtime) -> dict | None:
|
|
||||||
return self._inject(state)
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
|
||||||
return self._inject(state)
|
|
||||||
+2
-95
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from email.utils import parsedate_to_datetime
|
from email.utils import parsedate_to_datetime
|
||||||
@@ -20,8 +19,6 @@ from langchain.agents.middleware.types import (
|
|||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langgraph.errors import GraphBubbleUp
|
from langgraph.errors import GraphBubbleUp
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||||
@@ -70,71 +67,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
retry_base_delay_ms: int = 1000
|
retry_base_delay_ms: int = 1000
|
||||||
retry_cap_delay_ms: int = 8000
|
retry_cap_delay_ms: int = 8000
|
||||||
|
|
||||||
def __init__(self, *, app_config: AppConfig, **kwargs: Any) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
|
|
||||||
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
|
|
||||||
|
|
||||||
# Circuit Breaker state
|
|
||||||
self._circuit_lock = threading.Lock()
|
|
||||||
self._circuit_failure_count = 0
|
|
||||||
self._circuit_open_until = 0.0
|
|
||||||
self._circuit_state = "closed"
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
|
|
||||||
def _check_circuit(self) -> bool:
|
|
||||||
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
|
|
||||||
with self._circuit_lock:
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
if self._circuit_state == "open":
|
|
||||||
if now < self._circuit_open_until:
|
|
||||||
return True
|
|
||||||
self._circuit_state = "half_open"
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
|
|
||||||
if self._circuit_state == "half_open":
|
|
||||||
if self._circuit_probe_in_flight:
|
|
||||||
return True
|
|
||||||
self._circuit_probe_in_flight = True
|
|
||||||
return False
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _record_success(self) -> None:
|
|
||||||
with self._circuit_lock:
|
|
||||||
if self._circuit_state != "closed" or self._circuit_failure_count > 0:
|
|
||||||
logger.info("Circuit breaker reset (Closed). LLM service recovered.")
|
|
||||||
self._circuit_failure_count = 0
|
|
||||||
self._circuit_open_until = 0.0
|
|
||||||
self._circuit_state = "closed"
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
|
|
||||||
def _record_failure(self) -> None:
|
|
||||||
with self._circuit_lock:
|
|
||||||
if self._circuit_state == "half_open":
|
|
||||||
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
|
|
||||||
self._circuit_state = "open"
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
logger.error(
|
|
||||||
"Circuit breaker probe failed (Open). Will probe again after %ds.",
|
|
||||||
self.circuit_recovery_timeout_sec,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._circuit_failure_count += 1
|
|
||||||
if self._circuit_failure_count >= self.circuit_failure_threshold:
|
|
||||||
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
|
|
||||||
if self._circuit_state != "open":
|
|
||||||
self._circuit_state = "open"
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
logger.error(
|
|
||||||
"Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
|
|
||||||
self.circuit_failure_threshold,
|
|
||||||
self.circuit_recovery_timeout_sec,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||||
detail = _extract_error_detail(exc)
|
detail = _extract_error_detail(exc)
|
||||||
lowered = detail.lower()
|
lowered = detail.lower()
|
||||||
@@ -151,8 +83,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
"APITimeoutError",
|
"APITimeoutError",
|
||||||
"APIConnectionError",
|
"APIConnectionError",
|
||||||
"InternalServerError",
|
"InternalServerError",
|
||||||
"ReadError", # httpx.ReadError: connection dropped mid-stream
|
|
||||||
"RemoteProtocolError", # httpx: server closed connection unexpectedly
|
|
||||||
}:
|
}:
|
||||||
return True, "transient"
|
return True, "transient"
|
||||||
if status_code in _RETRIABLE_STATUS_CODES:
|
if status_code in _RETRIABLE_STATUS_CODES:
|
||||||
@@ -174,9 +104,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||||
|
|
||||||
def _build_circuit_breaker_message(self) -> str:
|
|
||||||
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
|
|
||||||
|
|
||||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||||
detail = _extract_error_detail(exc)
|
detail = _extract_error_detail(exc)
|
||||||
if reason == "quota":
|
if reason == "quota":
|
||||||
@@ -211,20 +138,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
if self._check_circuit():
|
|
||||||
return AIMessage(content=self._build_circuit_breaker_message())
|
|
||||||
|
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
response = handler(request)
|
return handler(request)
|
||||||
self._record_success()
|
|
||||||
return response
|
|
||||||
except GraphBubbleUp:
|
except GraphBubbleUp:
|
||||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||||
with self._circuit_lock:
|
|
||||||
if self._circuit_state == "half_open":
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
retriable, reason = self._classify_error(exc)
|
retriable, reason = self._classify_error(exc)
|
||||||
@@ -247,8 +166,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
_extract_error_detail(exc),
|
_extract_error_detail(exc),
|
||||||
exc_info=exc,
|
exc_info=exc,
|
||||||
)
|
)
|
||||||
if retriable:
|
|
||||||
self._record_failure()
|
|
||||||
return AIMessage(content=self._build_user_message(exc, reason))
|
return AIMessage(content=self._build_user_message(exc, reason))
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -257,20 +174,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
if self._check_circuit():
|
|
||||||
return AIMessage(content=self._build_circuit_breaker_message())
|
|
||||||
|
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
response = await handler(request)
|
return await handler(request)
|
||||||
self._record_success()
|
|
||||||
return response
|
|
||||||
except GraphBubbleUp:
|
except GraphBubbleUp:
|
||||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||||
with self._circuit_lock:
|
|
||||||
if self._circuit_state == "half_open":
|
|
||||||
self._circuit_probe_in_flight = False
|
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
retriable, reason = self._classify_error(exc)
|
retriable, reason = self._classify_error(exc)
|
||||||
@@ -293,8 +202,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
_extract_error_detail(exc),
|
_extract_error_detail(exc),
|
||||||
exc_info=exc,
|
exc_info=exc,
|
||||||
)
|
)
|
||||||
if retriable:
|
|
||||||
self._record_failure()
|
|
||||||
return AIMessage(content=self._build_user_message(exc, reason))
|
return AIMessage(content=self._build_user_message(exc, reason))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+18
-151
@@ -12,23 +12,18 @@ Detection strategy:
|
|||||||
response so the agent is forced to produce a final text answer.
|
response so the agent is forced to produce a final text answer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from copy import deepcopy
|
from typing import override
|
||||||
from typing import TYPE_CHECKING, override
|
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Defaults — can be overridden via constructor
|
# Defaults — can be overridden via constructor
|
||||||
@@ -36,8 +31,6 @@ _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
|
|||||||
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
|
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
|
||||||
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||||
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
|
|
||||||
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||||
@@ -132,21 +125,12 @@ def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
|||||||
|
|
||||||
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||||
|
|
||||||
_TOOL_FREQ_WARNING_MSG = (
|
|
||||||
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
|
||||||
)
|
|
||||||
|
|
||||||
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
|
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
|
||||||
|
|
||||||
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
|
|
||||||
|
|
||||||
|
|
||||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||||
"""Detects and breaks repetitive tool call loops.
|
"""Detects and breaks repetitive tool call loops.
|
||||||
|
|
||||||
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
|
|
||||||
construct via :meth:`from_config` to ensure values pass Pydantic validation.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
warn_threshold: Number of identical tool call sets before injecting
|
warn_threshold: Number of identical tool call sets before injecting
|
||||||
a warning message. Default: 3.
|
a warning message. Default: 3.
|
||||||
@@ -156,20 +140,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
Default: 20.
|
Default: 20.
|
||||||
max_tracked_threads: Maximum number of threads to track before
|
max_tracked_threads: Maximum number of threads to track before
|
||||||
evicting the least recently used. Default: 100.
|
evicting the least recently used. Default: 100.
|
||||||
tool_freq_warn: Number of calls to the same tool *type* (regardless
|
|
||||||
of arguments) before injecting a frequency warning. Catches
|
|
||||||
cross-file read loops that hash-based detection misses.
|
|
||||||
Default: 30.
|
|
||||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
|
||||||
forcing a stop. Default: 50.
|
|
||||||
tool_freq_overrides: Per-tool overrides for frequency thresholds,
|
|
||||||
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
|
|
||||||
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
|
|
||||||
that specific tool. Tools not listed here fall back to the global
|
|
||||||
thresholds. Useful for raising limits on intentionally
|
|
||||||
high-frequency tools (e.g. ``bash`` in batch pipelines) without
|
|
||||||
weakening protection on all other tools. Default: ``None``
|
|
||||||
(no overrides).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -178,36 +148,16 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
hard_limit: int = _DEFAULT_HARD_LIMIT,
|
hard_limit: int = _DEFAULT_HARD_LIMIT,
|
||||||
window_size: int = _DEFAULT_WINDOW_SIZE,
|
window_size: int = _DEFAULT_WINDOW_SIZE,
|
||||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
|
||||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
|
||||||
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.warn_threshold = warn_threshold
|
self.warn_threshold = warn_threshold
|
||||||
self.hard_limit = hard_limit
|
self.hard_limit = hard_limit
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.max_tracked_threads = max_tracked_threads
|
self.max_tracked_threads = max_tracked_threads
|
||||||
self.tool_freq_warn = tool_freq_warn
|
|
||||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
|
||||||
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
# Per-thread tracking using OrderedDict for LRU eviction
|
||||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
|
||||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
|
|
||||||
"""Construct from a Pydantic-validated config, trusting its validation."""
|
|
||||||
return cls(
|
|
||||||
warn_threshold=config.warn_threshold,
|
|
||||||
hard_limit=config.hard_limit,
|
|
||||||
window_size=config.window_size,
|
|
||||||
max_tracked_threads=config.max_tracked_threads,
|
|
||||||
tool_freq_warn=config.tool_freq_warn,
|
|
||||||
tool_freq_hard_limit=config.tool_freq_hard_limit,
|
|
||||||
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||||
@@ -224,19 +174,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
while len(self._history) > self.max_tracked_threads:
|
while len(self._history) > self.max_tracked_threads:
|
||||||
evicted_id, _ = self._history.popitem(last=False)
|
evicted_id, _ = self._history.popitem(last=False)
|
||||||
self._warned.pop(evicted_id, None)
|
self._warned.pop(evicted_id, None)
|
||||||
self._tool_freq.pop(evicted_id, None)
|
|
||||||
self._tool_freq_warned.pop(evicted_id, None)
|
|
||||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||||
|
|
||||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||||
"""Track tool calls and check for loops.
|
"""Track tool calls and check for loops.
|
||||||
|
|
||||||
Two detection layers:
|
|
||||||
1. **Hash-based** (existing): catches identical tool call sets.
|
|
||||||
2. **Frequency-based** (new): catches the same *tool type* being
|
|
||||||
called many times with varying arguments (e.g. ``read_file``
|
|
||||||
on 40 different files).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(warning_message_or_none, should_hard_stop)
|
(warning_message_or_none, should_hard_stop)
|
||||||
"""
|
"""
|
||||||
@@ -271,7 +213,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
count = history.count(call_hash)
|
count = history.count(call_hash)
|
||||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||||
|
|
||||||
# --- Layer 1: hash-based (identical call sets) ---
|
|
||||||
if count >= self.hard_limit:
|
if count >= self.hard_limit:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loop hard limit reached — forcing stop",
|
"Loop hard limit reached — forcing stop",
|
||||||
@@ -298,45 +239,8 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
return _WARNING_MSG, False
|
return _WARNING_MSG, False
|
||||||
|
# Warning already injected for this hash — suppress
|
||||||
# --- Layer 2: per-tool-type frequency ---
|
return None, False
|
||||||
freq = self._tool_freq[thread_id]
|
|
||||||
for tc in tool_calls:
|
|
||||||
name = tc.get("name", "")
|
|
||||||
if not name:
|
|
||||||
continue
|
|
||||||
freq[name] += 1
|
|
||||||
tc_count = freq[name]
|
|
||||||
|
|
||||||
if name in self._tool_freq_overrides:
|
|
||||||
eff_warn, eff_hard = self._tool_freq_overrides[name]
|
|
||||||
else:
|
|
||||||
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
|
|
||||||
|
|
||||||
if tc_count >= eff_hard:
|
|
||||||
logger.error(
|
|
||||||
"Tool frequency hard limit reached — forcing stop",
|
|
||||||
extra={
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"tool_name": name,
|
|
||||||
"count": tc_count,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
|
||||||
|
|
||||||
if tc_count >= eff_warn:
|
|
||||||
warned = self._tool_freq_warned[thread_id]
|
|
||||||
if name not in warned:
|
|
||||||
warned.add(name)
|
|
||||||
logger.warning(
|
|
||||||
"Tool frequency warning — too many calls to same tool type",
|
|
||||||
extra={
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"tool_name": name,
|
|
||||||
"count": tc_count,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
|
|
||||||
|
|
||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
@@ -357,26 +261,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||||
return str(content) + f"\n\n{text}"
|
return str(content) + f"\n\n{text}"
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
|
|
||||||
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
|
|
||||||
update = {
|
|
||||||
"tool_calls": [],
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
|
|
||||||
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
|
|
||||||
for key in ("tool_calls", "function_call"):
|
|
||||||
additional_kwargs.pop(key, None)
|
|
||||||
update["additional_kwargs"] = additional_kwargs
|
|
||||||
|
|
||||||
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
|
|
||||||
if response_metadata.get("finish_reason") == "tool_calls":
|
|
||||||
response_metadata["finish_reason"] = "stop"
|
|
||||||
update["response_metadata"] = response_metadata
|
|
||||||
|
|
||||||
return update
|
|
||||||
|
|
||||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
warning, hard_stop = self._track_and_check(state, runtime)
|
warning, hard_stop = self._track_and_check(state, runtime)
|
||||||
|
|
||||||
@@ -384,35 +268,22 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
# Strip tool_calls from the last AIMessage to force text output
|
# Strip tool_calls from the last AIMessage to force text output
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
last_msg = messages[-1]
|
last_msg = messages[-1]
|
||||||
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
|
stripped_msg = last_msg.model_copy(
|
||||||
stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
|
update={
|
||||||
|
"tool_calls": [],
|
||||||
|
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
|
||||||
|
}
|
||||||
|
)
|
||||||
return {"messages": [stripped_msg]}
|
return {"messages": [stripped_msg]}
|
||||||
|
|
||||||
if warning:
|
if warning:
|
||||||
# WORKAROUND for v2.0-m1 — see #2724.
|
# Inject as HumanMessage instead of SystemMessage to avoid
|
||||||
#
|
# Anthropic's "multiple non-consecutive system messages" error.
|
||||||
# Append the warning to the AIMessage content instead of
|
# Anthropic models require system messages only at the start of
|
||||||
# injecting a separate HumanMessage. Inserting any non-tool
|
# the conversation; injecting one mid-conversation crashes
|
||||||
# message between an AIMessage(tool_calls=...) and its
|
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||||
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
|
# with all providers. See #1299.
|
||||||
# validation ("tool_call_ids did not have response messages")
|
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
||||||
# because the tools node has not run yet at after_model time.
|
|
||||||
# tool_calls are preserved so the tools node still executes.
|
|
||||||
#
|
|
||||||
# This is a temporary mitigation: mutating an existing
|
|
||||||
# AIMessage to carry framework-authored text leaks loop-warning
|
|
||||||
# text into downstream consumers (MemoryMiddleware fact
|
|
||||||
# extraction, TitleMiddleware, telemetry, model replay) as if
|
|
||||||
# the model said it. The proper fix is to defer warning
|
|
||||||
# injection from after_model to wrap_model_call so every prior
|
|
||||||
# ToolMessage is already in the request — see RFC #2517 (which
|
|
||||||
# lists "loop intervention does not leave invalid
|
|
||||||
# tool-call/tool-message state" as acceptance criteria) and
|
|
||||||
# the prototype on `fix/loop-detection-tool-call-pairing`.
|
|
||||||
messages = state.get("messages", [])
|
|
||||||
last_msg = messages[-1]
|
|
||||||
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
|
|
||||||
return {"messages": [patched_msg]}
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -430,10 +301,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if thread_id:
|
if thread_id:
|
||||||
self._history.pop(thread_id, None)
|
self._history.pop(thread_id, None)
|
||||||
self._warned.pop(thread_id, None)
|
self._warned.pop(thread_id, None)
|
||||||
self._tool_freq.pop(thread_id, None)
|
|
||||||
self._tool_freq_warned.pop(thread_id, None)
|
|
||||||
else:
|
else:
|
||||||
self._history.clear()
|
self._history.clear()
|
||||||
self._warned.clear()
|
self._warned.clear()
|
||||||
self._tool_freq.clear()
|
|
||||||
self._tool_freq_warned.clear()
|
|
||||||
|
|||||||
@@ -1,23 +1,51 @@
|
|||||||
"""Middleware for memory mechanism."""
|
"""Middleware for memory mechanism."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, override
|
import re
|
||||||
|
from typing import Any, override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||||
|
_CORRECTION_PATTERNS = (
|
||||||
|
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||||
|
re.compile(r"不对"),
|
||||||
|
re.compile(r"你理解错了"),
|
||||||
|
re.compile(r"你理解有误"),
|
||||||
|
re.compile(r"重试"),
|
||||||
|
re.compile(r"重新来"),
|
||||||
|
re.compile(r"换一种"),
|
||||||
|
re.compile(r"改用"),
|
||||||
|
)
|
||||||
|
|
||||||
|
_REINFORCEMENT_PATTERNS = (
|
||||||
|
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddlewareState(AgentState):
|
class MemoryMiddlewareState(AgentState):
|
||||||
"""Compatible with the `ThreadState` schema."""
|
"""Compatible with the `ThreadState` schema."""
|
||||||
@@ -25,6 +53,125 @@ class MemoryMiddlewareState(AgentState):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_message_text(message: Any) -> str:
|
||||||
|
"""Extract plain text from message content for filtering and signal detection."""
|
||||||
|
content = getattr(message, "content", "")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, str):
|
||||||
|
text_parts.append(part)
|
||||||
|
elif isinstance(part, dict):
|
||||||
|
text_val = part.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
text_parts.append(text_val)
|
||||||
|
return " ".join(text_parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||||
|
"""Filter messages to keep only user inputs and final assistant responses.
|
||||||
|
|
||||||
|
This filters out:
|
||||||
|
- Tool messages (intermediate tool call results)
|
||||||
|
- AI messages with tool_calls (intermediate steps, not final responses)
|
||||||
|
- The <uploaded_files> block injected by UploadsMiddleware into human messages
|
||||||
|
(file paths are session-scoped and must not persist in long-term memory).
|
||||||
|
The user's actual question is preserved; only turns whose content is entirely
|
||||||
|
the upload block (nothing remains after stripping) are dropped along with
|
||||||
|
their paired assistant response.
|
||||||
|
|
||||||
|
Only keeps:
|
||||||
|
- Human messages (with the ephemeral upload block removed)
|
||||||
|
- AI messages without tool_calls (final assistant responses), unless the
|
||||||
|
paired human turn was upload-only and had no real user text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of all conversation messages.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list containing only user inputs and final assistant responses.
|
||||||
|
"""
|
||||||
|
filtered = []
|
||||||
|
skip_next_ai = False
|
||||||
|
for msg in messages:
|
||||||
|
msg_type = getattr(msg, "type", None)
|
||||||
|
|
||||||
|
if msg_type == "human":
|
||||||
|
content_str = _extract_message_text(msg)
|
||||||
|
if "<uploaded_files>" in content_str:
|
||||||
|
# Strip the ephemeral upload block; keep the user's real question.
|
||||||
|
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||||
|
if not stripped:
|
||||||
|
# Nothing left — the entire turn was upload bookkeeping;
|
||||||
|
# skip it and the paired assistant response.
|
||||||
|
skip_next_ai = True
|
||||||
|
continue
|
||||||
|
# Rebuild the message with cleaned content so the user's question
|
||||||
|
# is still available for memory summarisation.
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
clean_msg = copy(msg)
|
||||||
|
clean_msg.content = stripped
|
||||||
|
filtered.append(clean_msg)
|
||||||
|
skip_next_ai = False
|
||||||
|
else:
|
||||||
|
filtered.append(msg)
|
||||||
|
skip_next_ai = False
|
||||||
|
elif msg_type == "ai":
|
||||||
|
tool_calls = getattr(msg, "tool_calls", None)
|
||||||
|
if not tool_calls:
|
||||||
|
if skip_next_ai:
|
||||||
|
skip_next_ai = False
|
||||||
|
continue
|
||||||
|
filtered.append(msg)
|
||||||
|
# Skip tool messages and AI messages with tool_calls
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def detect_correction(messages: list[Any]) -> bool:
|
||||||
|
"""Detect explicit user corrections in recent conversation turns.
|
||||||
|
|
||||||
|
The queue keeps only one pending context per thread, so callers pass the
|
||||||
|
latest filtered message list. Checking only recent user turns keeps signal
|
||||||
|
detection conservative while avoiding stale corrections from long histories.
|
||||||
|
"""
|
||||||
|
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||||
|
|
||||||
|
for msg in recent_user_msgs:
|
||||||
|
content = _extract_message_text(msg).strip()
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||||
|
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||||
|
|
||||||
|
Complements detect_correction() by identifying when the user confirms the
|
||||||
|
agent's approach was correct. This allows the memory system to record what
|
||||||
|
worked well, not just what went wrong.
|
||||||
|
|
||||||
|
The queue keeps only one pending context per thread, so callers pass the
|
||||||
|
latest filtered message list. Checking only recent user turns keeps signal
|
||||||
|
detection conservative while avoiding stale signals from long histories.
|
||||||
|
"""
|
||||||
|
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||||
|
|
||||||
|
for msg in recent_user_msgs:
|
||||||
|
content = _extract_message_text(msg).strip()
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||||
"""Middleware that queues conversation for memory update after agent execution.
|
"""Middleware that queues conversation for memory update after agent execution.
|
||||||
|
|
||||||
@@ -37,17 +184,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
|
|
||||||
state_schema = MemoryMiddlewareState
|
state_schema = MemoryMiddlewareState
|
||||||
|
|
||||||
def __init__(self, agent_name: str | None = None, *, memory_config: "MemoryConfig | None" = None):
|
def __init__(self, agent_name: str | None = None):
|
||||||
"""Initialize the MemoryMiddleware.
|
"""Initialize the MemoryMiddleware.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||||
memory_config: Explicit memory config. When omitted, legacy global
|
|
||||||
config fallback is used.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._agent_name = agent_name
|
self._agent_name = agent_name
|
||||||
self._memory_config = memory_config
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
||||||
@@ -60,7 +204,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
Returns:
|
Returns:
|
||||||
None (no state changes needed from this middleware).
|
None (no state changes needed from this middleware).
|
||||||
"""
|
"""
|
||||||
config = self._memory_config or get_memory_config()
|
config = get_memory_config()
|
||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -80,7 +224,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Filter to only keep user inputs and final assistant responses
|
# Filter to only keep user inputs and final assistant responses
|
||||||
filtered_messages = filter_messages_for_memory(messages)
|
filtered_messages = _filter_messages_for_memory(messages)
|
||||||
|
|
||||||
# Only queue if there's meaningful conversation
|
# Only queue if there's meaningful conversation
|
||||||
# At minimum need one user message and one assistant response
|
# At minimum need one user message and one assistant response
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from langchain.agents import AgentState
|
|||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
|
||||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -64,7 +63,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
|||||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||||
|
|
||||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||||
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
|
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||||
return {"messages": [updated_msg]}
|
return {"messages": [updated_msg]}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -1,374 +1,13 @@
|
|||||||
"""Summarization middleware extensions for DeerFlow."""
|
from typing import override
|
||||||
|
|
||||||
from __future__ import annotations
|
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
|
||||||
|
from langchain_core.messages.human import HumanMessage
|
||||||
import logging
|
|
||||||
from collections.abc import Collection
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Protocol, override, runtime_checkable
|
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
|
||||||
from langchain.agents.middleware import SummarizationMiddleware
|
|
||||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
|
||||||
from langgraph.config import get_config
|
|
||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
|
|
||||||
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
|
||||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
class SummarizationMiddleware(BaseSummarizationMiddleware):
|
||||||
class SummarizationEvent:
|
|
||||||
"""Context emitted before conversation history is summarized away."""
|
|
||||||
|
|
||||||
messages_to_summarize: tuple[AnyMessage, ...]
|
|
||||||
preserved_messages: tuple[AnyMessage, ...]
|
|
||||||
thread_id: str | None
|
|
||||||
agent_name: str | None
|
|
||||||
runtime: Runtime
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class BeforeSummarizationHook(Protocol):
|
|
||||||
"""Hook invoked before summarization removes messages from state."""
|
|
||||||
|
|
||||||
def __call__(self, event: SummarizationEvent) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_thread_id(runtime: Runtime) -> str | None:
|
|
||||||
"""Resolve the current thread ID from runtime context or LangGraph config."""
|
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
|
||||||
if thread_id is None:
|
|
||||||
try:
|
|
||||||
config_data = get_config()
|
|
||||||
except RuntimeError:
|
|
||||||
return None
|
|
||||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
|
||||||
return thread_id
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
|
||||||
"""Resolve the current agent name from runtime context or LangGraph config."""
|
|
||||||
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
|
||||||
if agent_name is None:
|
|
||||||
try:
|
|
||||||
config_data = get_config()
|
|
||||||
except RuntimeError:
|
|
||||||
return None
|
|
||||||
agent_name = config_data.get("configurable", {}).get("agent_name")
|
|
||||||
return agent_name
|
|
||||||
|
|
||||||
|
|
||||||
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
|
|
||||||
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
|
|
||||||
args = tool_call.get("args") or {}
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
return None
|
|
||||||
for key in ("path", "file_path", "filepath"):
|
|
||||||
value = args.get(key)
|
|
||||||
if isinstance(value, str) and value:
|
|
||||||
return value
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _clone_ai_message(
|
|
||||||
message: AIMessage,
|
|
||||||
tool_calls: list[dict[str, Any]],
|
|
||||||
*,
|
|
||||||
content: Any | None = None,
|
|
||||||
) -> AIMessage:
|
|
||||||
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
|
||||||
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _SkillBundle:
|
|
||||||
"""Skill-related tool calls and tool results associated with one AIMessage."""
|
|
||||||
|
|
||||||
ai_index: int
|
|
||||||
skill_tool_indices: tuple[int, ...]
|
|
||||||
skill_tool_call_ids: frozenset[str]
|
|
||||||
skill_tool_tokens: int
|
|
||||||
skill_key: str
|
|
||||||
|
|
||||||
|
|
||||||
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|
||||||
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
skills_container_path: str | None = None,
|
|
||||||
skill_file_read_tool_names: Collection[str] | None = None,
|
|
||||||
before_summarization: list[BeforeSummarizationHook] | None = None,
|
|
||||||
preserve_recent_skill_count: int = 5,
|
|
||||||
preserve_recent_skill_tokens: int = 25_000,
|
|
||||||
preserve_recent_skill_tokens_per_skill: int = 5_000,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._skills_container_path = skills_container_path or "/mnt/skills"
|
|
||||||
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
|
|
||||||
self._before_summarization_hooks = before_summarization or []
|
|
||||||
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
|
|
||||||
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
|
|
||||||
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
|
|
||||||
|
|
||||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
||||||
return self._maybe_summarize(state, runtime)
|
|
||||||
|
|
||||||
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
||||||
return await self._amaybe_summarize(state, runtime)
|
|
||||||
|
|
||||||
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
||||||
messages = state["messages"]
|
|
||||||
self._ensure_message_ids(messages)
|
|
||||||
|
|
||||||
total_tokens = self.token_counter(messages)
|
|
||||||
if not self._should_summarize(messages, total_tokens):
|
|
||||||
return None
|
|
||||||
|
|
||||||
cutoff_index = self._determine_cutoff_index(messages)
|
|
||||||
if cutoff_index <= 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
|
||||||
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
|
||||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
|
||||||
summary = self._create_summary(messages_to_summarize)
|
|
||||||
new_messages = self._build_new_messages(summary)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"messages": [
|
|
||||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
|
||||||
*new_messages,
|
|
||||||
*preserved_messages,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
||||||
messages = state["messages"]
|
|
||||||
self._ensure_message_ids(messages)
|
|
||||||
|
|
||||||
total_tokens = self.token_counter(messages)
|
|
||||||
if not self._should_summarize(messages, total_tokens):
|
|
||||||
return None
|
|
||||||
|
|
||||||
cutoff_index = self._determine_cutoff_index(messages)
|
|
||||||
if cutoff_index <= 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
|
||||||
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
|
||||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
|
||||||
summary = await self._acreate_summary(messages_to_summarize)
|
|
||||||
new_messages = self._build_new_messages(summary)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"messages": [
|
|
||||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
|
||||||
*new_messages,
|
|
||||||
*preserved_messages,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||||
"""Override the base implementation to let the human message with the special name 'summary'.
|
"""Override the base implementation to let the human message with the special name 'summary'.
|
||||||
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
||||||
"""
|
"""
|
||||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||||
|
|
||||||
def _preserve_dynamic_context_reminders(
|
|
||||||
self,
|
|
||||||
messages_to_summarize: list[AnyMessage],
|
|
||||||
preserved_messages: list[AnyMessage],
|
|
||||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
|
||||||
"""Keep hidden dynamic-context reminders out of summary compression.
|
|
||||||
|
|
||||||
These reminders carry the current date and optional memory. If summarization
|
|
||||||
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
|
|
||||||
for the first user message and inject the reminder in the wrong place.
|
|
||||||
"""
|
|
||||||
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
|
|
||||||
if not reminders:
|
|
||||||
return messages_to_summarize, preserved_messages
|
|
||||||
|
|
||||||
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
|
|
||||||
return remaining, reminders + preserved_messages
|
|
||||||
|
|
||||||
def _partition_with_skill_rescue(
|
|
||||||
self,
|
|
||||||
messages: list[AnyMessage],
|
|
||||||
cutoff_index: int,
|
|
||||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
|
||||||
"""Partition like the parent, then rescue recently-loaded skill bundles."""
|
|
||||||
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
|
|
||||||
|
|
||||||
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
|
|
||||||
return to_summarize, preserved
|
|
||||||
|
|
||||||
try:
|
|
||||||
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
|
|
||||||
return to_summarize, preserved
|
|
||||||
|
|
||||||
if not bundles:
|
|
||||||
return to_summarize, preserved
|
|
||||||
|
|
||||||
rescue_bundles = self._select_bundles_to_rescue(bundles)
|
|
||||||
if not rescue_bundles:
|
|
||||||
return to_summarize, preserved
|
|
||||||
|
|
||||||
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
|
|
||||||
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
|
|
||||||
rescued: list[AnyMessage] = []
|
|
||||||
remaining: list[AnyMessage] = []
|
|
||||||
for i, msg in enumerate(to_summarize):
|
|
||||||
bundle = bundles_by_ai_index.get(i)
|
|
||||||
if bundle is not None and isinstance(msg, AIMessage):
|
|
||||||
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
|
|
||||||
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
|
|
||||||
|
|
||||||
if rescued_tool_calls:
|
|
||||||
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
|
|
||||||
if remaining_tool_calls or msg.content:
|
|
||||||
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if i in rescue_tool_indices:
|
|
||||||
rescued.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
remaining.append(msg)
|
|
||||||
|
|
||||||
return remaining, rescued + preserved
|
|
||||||
|
|
||||||
def _find_skill_bundles(
|
|
||||||
self,
|
|
||||||
messages: list[AnyMessage],
|
|
||||||
skills_root: str,
|
|
||||||
) -> list[_SkillBundle]:
|
|
||||||
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
|
|
||||||
bundles: list[_SkillBundle] = []
|
|
||||||
n = len(messages)
|
|
||||||
i = 0
|
|
||||||
while i < n:
|
|
||||||
msg = messages[i]
|
|
||||||
if not (isinstance(msg, AIMessage) and msg.tool_calls):
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_calls = list(msg.tool_calls)
|
|
||||||
skill_paths_by_id: dict[str, str] = {}
|
|
||||||
for tc in tool_calls:
|
|
||||||
if self._is_skill_tool_call(tc, skills_root):
|
|
||||||
tc_id = tc.get("id")
|
|
||||||
path = _tool_call_path(tc)
|
|
||||||
if tc_id and path:
|
|
||||||
skill_paths_by_id[tc_id] = path
|
|
||||||
|
|
||||||
if not skill_paths_by_id:
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
skill_tool_tokens = 0
|
|
||||||
skill_key_parts: list[str] = []
|
|
||||||
skill_tool_indices: list[int] = []
|
|
||||||
matched_skill_call_ids: set[str] = set()
|
|
||||||
|
|
||||||
j = i + 1
|
|
||||||
while j < n and isinstance(messages[j], ToolMessage):
|
|
||||||
j += 1
|
|
||||||
|
|
||||||
for k in range(i + 1, j):
|
|
||||||
tool_msg = messages[k]
|
|
||||||
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
|
|
||||||
skill_tool_tokens += self.token_counter([tool_msg])
|
|
||||||
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
|
|
||||||
skill_tool_indices.append(k)
|
|
||||||
matched_skill_call_ids.add(tool_msg.tool_call_id)
|
|
||||||
|
|
||||||
if not skill_tool_indices:
|
|
||||||
i = j
|
|
||||||
continue
|
|
||||||
|
|
||||||
bundles.append(
|
|
||||||
_SkillBundle(
|
|
||||||
ai_index=i,
|
|
||||||
skill_tool_indices=tuple(skill_tool_indices),
|
|
||||||
skill_tool_call_ids=frozenset(matched_skill_call_ids),
|
|
||||||
skill_tool_tokens=skill_tool_tokens,
|
|
||||||
skill_key="|".join(sorted(skill_key_parts)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
i = j
|
|
||||||
|
|
||||||
return bundles
|
|
||||||
|
|
||||||
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
|
|
||||||
"""Pick bundles to keep, walking newest-first under count/token budgets."""
|
|
||||||
selected: list[_SkillBundle] = []
|
|
||||||
if not bundles:
|
|
||||||
return selected
|
|
||||||
|
|
||||||
seen_skill_keys: set[str] = set()
|
|
||||||
total_tokens = 0
|
|
||||||
kept = 0
|
|
||||||
|
|
||||||
for bundle in reversed(bundles):
|
|
||||||
if kept >= self._preserve_recent_skill_count:
|
|
||||||
break
|
|
||||||
if bundle.skill_key in seen_skill_keys:
|
|
||||||
continue
|
|
||||||
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
|
|
||||||
continue
|
|
||||||
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
|
|
||||||
continue
|
|
||||||
|
|
||||||
selected.append(bundle)
|
|
||||||
total_tokens += bundle.skill_tool_tokens
|
|
||||||
kept += 1
|
|
||||||
seen_skill_keys.add(bundle.skill_key)
|
|
||||||
|
|
||||||
selected.reverse()
|
|
||||||
return selected
|
|
||||||
|
|
||||||
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
|
|
||||||
"""Return True when ``tool_call`` reads a file under the configured skills root."""
|
|
||||||
name = tool_call.get("name") or ""
|
|
||||||
if name not in self._skill_file_read_tool_names:
|
|
||||||
return False
|
|
||||||
path = _tool_call_path(tool_call)
|
|
||||||
if not path:
|
|
||||||
return False
|
|
||||||
normalized_root = skills_root.rstrip("/")
|
|
||||||
return path == normalized_root or path.startswith(normalized_root + "/")
|
|
||||||
|
|
||||||
def _fire_hooks(
|
|
||||||
self,
|
|
||||||
messages_to_summarize: list[AnyMessage],
|
|
||||||
preserved_messages: list[AnyMessage],
|
|
||||||
runtime: Runtime,
|
|
||||||
) -> None:
|
|
||||||
if not self._before_summarization_hooks:
|
|
||||||
return
|
|
||||||
|
|
||||||
event = SummarizationEvent(
|
|
||||||
messages_to_summarize=tuple(messages_to_summarize),
|
|
||||||
preserved_messages=tuple(preserved_messages),
|
|
||||||
thread_id=_resolve_thread_id(runtime),
|
|
||||||
agent_name=_resolve_agent_name(runtime),
|
|
||||||
runtime=runtime,
|
|
||||||
)
|
|
||||||
|
|
||||||
for hook in self._before_summarization_hooks:
|
|
||||||
try:
|
|
||||||
hook(event)
|
|
||||||
except Exception:
|
|
||||||
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
|
|
||||||
logger.exception("before_summarization hook %s failed", hook_name)
|
|
||||||
|
|||||||
@@ -1,22 +1,16 @@
|
|||||||
"""Middleware for automatic thread title generation."""
|
"""Middleware for automatic thread title generation."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import Any, NotRequired, override
|
||||||
from typing import TYPE_CHECKING, Any, NotRequired, override
|
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
|
||||||
from deerflow.config.title_config import get_title_config
|
from deerflow.config.title_config import get_title_config
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
from deerflow.config.title_config import TitleConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -31,18 +25,6 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
|
|
||||||
state_schema = TitleMiddlewareState
|
state_schema = TitleMiddlewareState
|
||||||
|
|
||||||
def __init__(self, *, app_config: "AppConfig | None" = None, title_config: "TitleConfig | None" = None):
|
|
||||||
super().__init__()
|
|
||||||
self._app_config = app_config
|
|
||||||
self._title_config = title_config
|
|
||||||
|
|
||||||
def _get_title_config(self):
|
|
||||||
if self._title_config is not None:
|
|
||||||
return self._title_config
|
|
||||||
if self._app_config is not None:
|
|
||||||
return self._app_config.title
|
|
||||||
return get_title_config()
|
|
||||||
|
|
||||||
def _normalize_content(self, content: object) -> str:
|
def _normalize_content(self, content: object) -> str:
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
@@ -62,13 +44,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_user_message_for_title(message: object) -> bool:
|
|
||||||
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
|
|
||||||
|
|
||||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||||
"""Check if we should generate a title for this thread."""
|
"""Check if we should generate a title for this thread."""
|
||||||
config = self._get_title_config()
|
config = get_title_config()
|
||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -82,7 +60,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Count user and assistant messages
|
# Count user and assistant messages
|
||||||
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
|
user_messages = [m for m in messages if m.type == "human"]
|
||||||
assistant_messages = [m for m in messages if m.type == "ai"]
|
assistant_messages = [m for m in messages if m.type == "ai"]
|
||||||
|
|
||||||
# Generate title after first complete exchange
|
# Generate title after first complete exchange
|
||||||
@@ -93,14 +71,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
|
|
||||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||||
"""
|
"""
|
||||||
config = self._get_title_config()
|
config = get_title_config()
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
|
|
||||||
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
|
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||||
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
||||||
|
|
||||||
user_msg = self._normalize_content(user_msg_content)
|
user_msg = self._normalize_content(user_msg_content)
|
||||||
assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
|
assistant_msg = self._normalize_content(assistant_msg_content)
|
||||||
|
|
||||||
prompt = config.prompt_template.format(
|
prompt = config.prompt_template.format(
|
||||||
max_words=config.max_words,
|
max_words=config.max_words,
|
||||||
@@ -109,20 +87,15 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
)
|
)
|
||||||
return prompt, user_msg
|
return prompt, user_msg
|
||||||
|
|
||||||
def _strip_think_tags(self, text: str) -> str:
|
|
||||||
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
|
|
||||||
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
|
|
||||||
|
|
||||||
def _parse_title(self, content: object) -> str:
|
def _parse_title(self, content: object) -> str:
|
||||||
"""Normalize model output into a clean title string."""
|
"""Normalize model output into a clean title string."""
|
||||||
config = self._get_title_config()
|
config = get_title_config()
|
||||||
title_content = self._normalize_content(content)
|
title_content = self._normalize_content(content)
|
||||||
title_content = self._strip_think_tags(title_content)
|
|
||||||
title = title_content.strip().strip('"').strip("'")
|
title = title_content.strip().strip('"').strip("'")
|
||||||
return title[: config.max_chars] if len(title) > config.max_chars else title
|
return title[: config.max_chars] if len(title) > config.max_chars else title
|
||||||
|
|
||||||
def _fallback_title(self, user_msg: str) -> str:
|
def _fallback_title(self, user_msg: str) -> str:
|
||||||
config = self._get_title_config()
|
config = get_title_config()
|
||||||
fallback_chars = min(config.max_chars, 50)
|
fallback_chars = min(config.max_chars, 50)
|
||||||
if len(user_msg) > fallback_chars:
|
if len(user_msg) > fallback_chars:
|
||||||
return user_msg[:fallback_chars].rstrip() + "..."
|
return user_msg[:fallback_chars].rstrip() + "..."
|
||||||
@@ -139,7 +112,6 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
except Exception:
|
except Exception:
|
||||||
parent = {}
|
parent = {}
|
||||||
config = {**parent}
|
config = {**parent}
|
||||||
config["run_name"] = "title_agent"
|
|
||||||
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -156,17 +128,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
if not self._should_generate_title(state):
|
if not self._should_generate_title(state):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
config = self._get_title_config()
|
config = get_title_config()
|
||||||
prompt, user_msg = self._build_title_prompt(state)
|
prompt, user_msg = self._build_title_prompt(state)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_kwargs = {"thinking_enabled": False}
|
|
||||||
if self._app_config is not None:
|
|
||||||
model_kwargs["app_config"] = self._app_config
|
|
||||||
if config.model_name:
|
if config.model_name:
|
||||||
model = create_chat_model(name=config.model_name, **model_kwargs)
|
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||||
else:
|
else:
|
||||||
model = create_chat_model(**model_kwargs)
|
model = create_chat_model(thinking_enabled=False)
|
||||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||||
title = self._parse_title(response.content)
|
title = self._parse_title(response.content)
|
||||||
if title:
|
if title:
|
||||||
|
|||||||
@@ -1,27 +1,17 @@
|
|||||||
"""Middleware that extends TodoListMiddleware with context-loss detection and premature-exit prevention.
|
"""Middleware that extends TodoListMiddleware with context-loss detection.
|
||||||
|
|
||||||
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
||||||
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
||||||
active context window. This middleware detects that situation and injects a
|
active context window. This middleware detects that situation and injects a
|
||||||
reminder message so the model still knows about the outstanding todo list.
|
reminder message so the model still knows about the outstanding todo list.
|
||||||
|
|
||||||
Additionally, this middleware prevents the agent from exiting the loop while
|
|
||||||
there are still incomplete todo items. When the model produces a final response
|
|
||||||
(no tool calls) but todos are not yet complete, the middleware queues a reminder
|
|
||||||
for the next model request and jumps back to the model node to force continued
|
|
||||||
engagement. The completion reminder is injected via ``wrap_model_call`` instead
|
|
||||||
of being persisted into graph state as a normal user-visible message.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import threading
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any, override
|
from typing import Any, override
|
||||||
|
|
||||||
from langchain.agents.middleware import TodoListMiddleware
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
@@ -44,11 +34,6 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _completion_reminder_count(messages: list[Any]) -> int:
|
|
||||||
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
|
|
||||||
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
|
|
||||||
|
|
||||||
|
|
||||||
def _format_todos(todos: list[Todo]) -> str:
|
def _format_todos(todos: list[Todo]) -> str:
|
||||||
"""Format a list of Todo items into a human-readable string."""
|
"""Format a list of Todo items into a human-readable string."""
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
@@ -59,51 +44,6 @@ def _format_todos(todos: list[Todo]) -> str:
|
|||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _format_completion_reminder(todos: list[Todo]) -> str:
|
|
||||||
"""Format a completion reminder for incomplete todo items."""
|
|
||||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
|
||||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
|
||||||
return (
|
|
||||||
"<system_reminder>\n"
|
|
||||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
|
||||||
f"{incomplete_text}\n\n"
|
|
||||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
|
||||||
"as you finish them, and only respond when all items are done.\n"
|
|
||||||
"</system_reminder>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
|
|
||||||
|
|
||||||
|
|
||||||
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
|
|
||||||
"""Return True when an AIMessage is not a clean final answer.
|
|
||||||
|
|
||||||
Todo completion reminders should only fire when the model has produced a
|
|
||||||
plain final response. Provider/tool parsing details have moved across
|
|
||||||
LangChain versions and integrations, so keep all tool-intent/error signals
|
|
||||||
behind this helper instead of checking one concrete field at the call site.
|
|
||||||
"""
|
|
||||||
if message.tool_calls:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if getattr(message, "invalid_tool_calls", None):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Backward/provider compatibility: some integrations preserve raw or legacy
|
|
||||||
# tool-call intent in additional_kwargs even when structured tool_calls is
|
|
||||||
# empty. If this helper changes, update the matching sentinel test
|
|
||||||
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
|
|
||||||
# if that test fails after a LangChain upgrade, review this helper so new
|
|
||||||
# tool-call/error fields are not silently treated as clean final answers.
|
|
||||||
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
|
|
||||||
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
|
|
||||||
return True
|
|
||||||
|
|
||||||
response_metadata = getattr(message, "response_metadata", {}) or {}
|
|
||||||
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
|
|
||||||
|
|
||||||
|
|
||||||
class TodoMiddleware(TodoListMiddleware):
|
class TodoMiddleware(TodoListMiddleware):
|
||||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||||
|
|
||||||
@@ -117,7 +57,7 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
def before_model(
|
def before_model(
|
||||||
self,
|
self,
|
||||||
state: PlanningState,
|
state: PlanningState,
|
||||||
runtime: Runtime,
|
runtime: Runtime, # noqa: ARG002
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||||||
@@ -138,7 +78,6 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
formatted = _format_todos(todos)
|
formatted = _format_todos(todos)
|
||||||
reminder = HumanMessage(
|
reminder = HumanMessage(
|
||||||
name="todo_reminder",
|
name="todo_reminder",
|
||||||
additional_kwargs={"hide_from_ui": True},
|
|
||||||
content=(
|
content=(
|
||||||
"<system_reminder>\n"
|
"<system_reminder>\n"
|
||||||
"Your todo list from earlier is no longer visible in the current context window, "
|
"Your todo list from earlier is no longer visible in the current context window, "
|
||||||
@@ -159,201 +98,3 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Async version of before_model."""
|
"""Async version of before_model."""
|
||||||
return self.before_model(state, runtime)
|
return self.before_model(state, runtime)
|
||||||
|
|
||||||
# Maximum number of completion reminders before allowing the agent to exit.
|
|
||||||
# This prevents infinite loops when the agent cannot make further progress.
|
|
||||||
_MAX_COMPLETION_REMINDERS = 2
|
|
||||||
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
|
|
||||||
_MAX_COMPLETION_REMINDER_KEYS = 4096
|
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
|
|
||||||
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
|
|
||||||
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
|
|
||||||
self._completion_reminder_next_order = 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_thread_id(runtime: Runtime) -> str:
|
|
||||||
context = getattr(runtime, "context", None)
|
|
||||||
thread_id = context.get("thread_id") if context else None
|
|
||||||
return str(thread_id) if thread_id else "default"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_run_id(runtime: Runtime) -> str:
|
|
||||||
context = getattr(runtime, "context", None)
|
|
||||||
run_id = context.get("run_id") if context else None
|
|
||||||
return str(run_id) if run_id else "default"
|
|
||||||
|
|
||||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
|
||||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
|
||||||
|
|
||||||
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
|
||||||
self._completion_reminder_next_order += 1
|
|
||||||
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
|
|
||||||
|
|
||||||
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
|
|
||||||
keys = set(self._pending_completion_reminders)
|
|
||||||
keys.update(self._completion_reminder_counts)
|
|
||||||
keys.update(self._completion_reminder_touch_order)
|
|
||||||
return keys
|
|
||||||
|
|
||||||
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
|
||||||
self._pending_completion_reminders.pop(key, None)
|
|
||||||
self._completion_reminder_counts.pop(key, None)
|
|
||||||
self._completion_reminder_touch_order.pop(key, None)
|
|
||||||
|
|
||||||
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
|
|
||||||
keys = self._completion_reminder_keys_locked()
|
|
||||||
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
|
|
||||||
if overflow <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
candidates = [key for key in keys if key != protected_key]
|
|
||||||
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
|
|
||||||
for key in candidates[:overflow]:
|
|
||||||
self._drop_completion_reminder_key_locked(key)
|
|
||||||
|
|
||||||
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
self._pending_completion_reminders.setdefault(key, []).append(reminder)
|
|
||||||
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
|
|
||||||
self._touch_completion_reminder_key_locked(key)
|
|
||||||
self._prune_completion_reminder_state_locked(protected_key=key)
|
|
||||||
|
|
||||||
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
return self._completion_reminder_counts.get(key, 0)
|
|
||||||
|
|
||||||
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
reminders = self._pending_completion_reminders.pop(key, [])
|
|
||||||
if reminders or key in self._completion_reminder_counts:
|
|
||||||
self._touch_completion_reminder_key_locked(key)
|
|
||||||
return reminders
|
|
||||||
|
|
||||||
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
|
|
||||||
thread_id, current_run_id = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
for key in self._completion_reminder_keys_locked():
|
|
||||||
if key[0] == thread_id and key[1] != current_run_id:
|
|
||||||
self._drop_completion_reminder_key_locked(key)
|
|
||||||
|
|
||||||
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
|
|
||||||
key = self._pending_key(runtime)
|
|
||||||
with self._lock:
|
|
||||||
self._drop_completion_reminder_key_locked(key)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_other_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_other_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@hook_config(can_jump_to=["model"])
|
|
||||||
@override
|
|
||||||
def after_model(
|
|
||||||
self,
|
|
||||||
state: PlanningState,
|
|
||||||
runtime: Runtime,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Prevent premature agent exit when todo items are still incomplete.
|
|
||||||
|
|
||||||
In addition to the base class check for parallel ``write_todos`` calls,
|
|
||||||
this override intercepts model responses that have no tool calls while
|
|
||||||
there are still incomplete todo items. It injects a reminder
|
|
||||||
``HumanMessage`` and jumps back to the model node so the agent
|
|
||||||
continues working through the todo list.
|
|
||||||
|
|
||||||
A retry cap of ``_MAX_COMPLETION_REMINDERS`` (default 2) prevents
|
|
||||||
infinite loops when the agent cannot make further progress.
|
|
||||||
"""
|
|
||||||
# 1. Preserve base class logic (parallel write_todos detection).
|
|
||||||
base_result = super().after_model(state, runtime)
|
|
||||||
if base_result is not None:
|
|
||||||
return base_result
|
|
||||||
|
|
||||||
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
|
|
||||||
# intent or tool-call parse errors should be handled by the tool path
|
|
||||||
# instead of being masked by todo reminders.
|
|
||||||
messages = state.get("messages") or []
|
|
||||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
|
||||||
if not last_ai or _has_tool_call_intent_or_error(last_ai):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 3. Allow exit when all todos are completed or there are no todos.
|
|
||||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
|
||||||
if not todos or all(t.get("status") == "completed" for t in todos):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
|
||||||
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 5. Queue a reminder for the next model request and jump back. We must
|
|
||||||
# not persist this control prompt as a normal HumanMessage, otherwise it
|
|
||||||
# can leak into user-visible message streams and saved transcripts.
|
|
||||||
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
|
|
||||||
return {"jump_to": "model"}
|
|
||||||
|
|
||||||
@override
|
|
||||||
@hook_config(can_jump_to=["model"])
|
|
||||||
async def aafter_model(
|
|
||||||
self,
|
|
||||||
state: PlanningState,
|
|
||||||
runtime: Runtime,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Async version of after_model."""
|
|
||||||
return self.after_model(state, runtime)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_pending_completion_reminders(reminders: list[str]) -> str:
|
|
||||||
return "\n\n".join(dict.fromkeys(reminders))
|
|
||||||
|
|
||||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
|
||||||
reminders = self._drain_completion_reminders(request.runtime)
|
|
||||||
if not reminders:
|
|
||||||
return request
|
|
||||||
new_messages = [
|
|
||||||
*request.messages,
|
|
||||||
HumanMessage(
|
|
||||||
content=self._format_pending_completion_reminders(reminders),
|
|
||||||
name="todo_completion_reminder",
|
|
||||||
additional_kwargs={"hide_from_ui": True},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return request.override(messages=new_messages)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def wrap_model_call(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
|
||||||
) -> ModelCallResult:
|
|
||||||
return handler(self._augment_request(request))
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def awrap_model_call(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
||||||
) -> ModelCallResult:
|
|
||||||
return await handler(self._augment_request(request))
|
|
||||||
|
|
||||||
@override
|
|
||||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_current_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
|
||||||
self._clear_current_run_completion_reminders(runtime)
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -1,358 +1,37 @@
|
|||||||
"""Middleware for logging token usage and annotating step attribution."""
|
"""Middleware for logging LLM token usage."""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from typing import override
|
||||||
from typing import Any, override
|
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain.agents.middleware.todo import Todo
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
|
|
||||||
|
|
||||||
|
|
||||||
def _string_arg(value: Any) -> str | None:
|
|
||||||
if isinstance(value, str):
|
|
||||||
normalized = value.strip()
|
|
||||||
return normalized or None
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_todos(value: Any) -> list[Todo]:
|
|
||||||
if not isinstance(value, list):
|
|
||||||
return []
|
|
||||||
|
|
||||||
normalized: list[Todo] = []
|
|
||||||
for item in value:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
todo: Todo = {}
|
|
||||||
content = _string_arg(item.get("content"))
|
|
||||||
status = item.get("status")
|
|
||||||
|
|
||||||
if content is not None:
|
|
||||||
todo["content"] = content
|
|
||||||
if status in {"pending", "in_progress", "completed"}:
|
|
||||||
todo["status"] = status
|
|
||||||
|
|
||||||
normalized.append(todo)
|
|
||||||
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
|
|
||||||
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
|
|
||||||
status = current.get("status")
|
|
||||||
previous_content = previous.get("content") if previous else None
|
|
||||||
current_content = current.get("content")
|
|
||||||
|
|
||||||
if previous is None:
|
|
||||||
if status == "completed":
|
|
||||||
return "todo_complete"
|
|
||||||
if status == "in_progress":
|
|
||||||
return "todo_start"
|
|
||||||
return "todo_update"
|
|
||||||
|
|
||||||
if previous_content != current_content:
|
|
||||||
return "todo_update"
|
|
||||||
|
|
||||||
if status == "completed":
|
|
||||||
return "todo_complete"
|
|
||||||
if status == "in_progress":
|
|
||||||
return "todo_start"
|
|
||||||
return "todo_update"
|
|
||||||
|
|
||||||
|
|
||||||
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
|
|
||||||
# This is the single source of truth for precise write_todos token
|
|
||||||
# attribution. The frontend intentionally falls back to a generic
|
|
||||||
# "Update to-do list" label when this metadata is missing or malformed.
|
|
||||||
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
|
|
||||||
matched_previous_indices: set[int] = set()
|
|
||||||
|
|
||||||
for index, todo in enumerate(previous_todos):
|
|
||||||
content = todo.get("content")
|
|
||||||
if isinstance(content, str) and content:
|
|
||||||
previous_by_content[content].append((index, todo))
|
|
||||||
|
|
||||||
actions: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
for index, todo in enumerate(next_todos):
|
|
||||||
content = todo.get("content")
|
|
||||||
if not isinstance(content, str) or not content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
previous_match: Todo | None = None
|
|
||||||
content_matches = previous_by_content.get(content)
|
|
||||||
if content_matches:
|
|
||||||
while content_matches and content_matches[0][0] in matched_previous_indices:
|
|
||||||
content_matches.pop(0)
|
|
||||||
if content_matches:
|
|
||||||
previous_index, previous_match = content_matches.pop(0)
|
|
||||||
matched_previous_indices.add(previous_index)
|
|
||||||
|
|
||||||
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
|
|
||||||
previous_match = previous_todos[index]
|
|
||||||
matched_previous_indices.add(index)
|
|
||||||
|
|
||||||
if previous_match is not None:
|
|
||||||
previous_content = previous_match.get("content")
|
|
||||||
previous_status = previous_match.get("status")
|
|
||||||
if previous_content == content and previous_status == todo.get("status"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
actions.append(
|
|
||||||
{
|
|
||||||
"kind": _todo_action_kind(previous_match, todo),
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
for index, todo in enumerate(previous_todos):
|
|
||||||
if index in matched_previous_indices:
|
|
||||||
continue
|
|
||||||
|
|
||||||
content = todo.get("content")
|
|
||||||
if not isinstance(content, str) or not content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
actions.append(
|
|
||||||
{
|
|
||||||
"kind": "todo_remove",
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return actions
|
|
||||||
|
|
||||||
|
|
||||||
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
|
|
||||||
name = _string_arg(tool_call.get("name")) or "unknown"
|
|
||||||
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
|
|
||||||
tool_call_id = _string_arg(tool_call.get("id"))
|
|
||||||
|
|
||||||
if name == "write_todos":
|
|
||||||
next_todos = _normalize_todos(args.get("todos"))
|
|
||||||
actions = _build_todo_actions(todos, next_todos)
|
|
||||||
if not actions:
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"kind": "tool",
|
|
||||||
"tool_name": name,
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
**action,
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
for action in actions
|
|
||||||
]
|
|
||||||
|
|
||||||
if name == "task":
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"kind": "subagent",
|
|
||||||
"description": _string_arg(args.get("description")),
|
|
||||||
"subagent_type": _string_arg(args.get("subagent_type")),
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
if name in {"web_search", "image_search"}:
|
|
||||||
query = _string_arg(args.get("query"))
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"kind": "search",
|
|
||||||
"tool_name": name,
|
|
||||||
"query": query,
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
if name == "present_files":
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"kind": "present_files",
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
if name == "ask_clarification":
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"kind": "clarification",
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"kind": "tool",
|
|
||||||
"tool_name": name,
|
|
||||||
"description": _string_arg(args.get("description")),
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
|
||||||
if actions:
|
|
||||||
first_kind = actions[0].get("kind")
|
|
||||||
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
|
|
||||||
return "todo_update"
|
|
||||||
if len(actions) == 1 and first_kind == "subagent":
|
|
||||||
return "subagent_dispatch"
|
|
||||||
return "tool_batch"
|
|
||||||
|
|
||||||
if message.content:
|
|
||||||
return "final_answer"
|
|
||||||
return "thinking"
|
|
||||||
|
|
||||||
|
|
||||||
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
|
|
||||||
"""Return True if the AIMessage contains a tool_call with the given id."""
|
|
||||||
for tc in message.tool_calls or []:
|
|
||||||
if isinstance(tc, dict):
|
|
||||||
if tc.get("id") == tool_call_id:
|
|
||||||
return True
|
|
||||||
elif hasattr(tc, "id") and tc.id == tool_call_id:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
|
||||||
tool_calls = getattr(message, "tool_calls", None) or []
|
|
||||||
actions: list[dict[str, Any]] = []
|
|
||||||
current_todos = list(todos)
|
|
||||||
|
|
||||||
for raw_tool_call in tool_calls:
|
|
||||||
if not isinstance(raw_tool_call, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
described_actions = _describe_tool_call(raw_tool_call, current_todos)
|
|
||||||
actions.extend(described_actions)
|
|
||||||
|
|
||||||
if raw_tool_call.get("name") == "write_todos":
|
|
||||||
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
|
|
||||||
current_todos = _normalize_todos(args.get("todos"))
|
|
||||||
|
|
||||||
tool_call_ids: list[str] = []
|
|
||||||
for tool_call in tool_calls:
|
|
||||||
if not isinstance(tool_call, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_call_id = _string_arg(tool_call.get("id"))
|
|
||||||
if tool_call_id is not None:
|
|
||||||
tool_call_ids.append(tool_call_id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
# Schema changes should remain additive where possible so older
|
|
||||||
# frontends can ignore unknown fields and fall back safely.
|
|
||||||
"version": 1,
|
|
||||||
"kind": _infer_step_kind(message, actions),
|
|
||||||
"shared_attribution": len(actions) > 1,
|
|
||||||
"tool_call_ids": tool_call_ids,
|
|
||||||
"actions": actions,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TokenUsageMiddleware(AgentMiddleware):
|
class TokenUsageMiddleware(AgentMiddleware):
|
||||||
"""Logs token usage from model responses and annotates the AI step."""
|
"""Logs token usage from model response usage_metadata."""
|
||||||
|
|
||||||
def _apply(self, state: AgentState) -> dict | None:
|
|
||||||
messages = state.get("messages", [])
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Annotate subagent token usage onto the AIMessage that dispatched it.
|
|
||||||
# When a task tool completes, its usage is cached by tool_call_id. Detect
|
|
||||||
# the ToolMessage → search backward for the corresponding AIMessage → merge.
|
|
||||||
# Walk backward through consecutive ToolMessages before the new AIMessage
|
|
||||||
# so that multiple concurrent task tool calls all get their subagent tokens
|
|
||||||
# written back to the same dispatch message (merging into one update).
|
|
||||||
state_updates: dict[int, AIMessage] = {}
|
|
||||||
if len(messages) >= 2:
|
|
||||||
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
|
|
||||||
|
|
||||||
idx = len(messages) - 2
|
|
||||||
while idx >= 0:
|
|
||||||
tool_msg = messages[idx]
|
|
||||||
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
|
|
||||||
break
|
|
||||||
|
|
||||||
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
|
|
||||||
if subagent_usage:
|
|
||||||
# Search backward from the ToolMessage to find the AIMessage
|
|
||||||
# that dispatched it. A single model response can dispatch
|
|
||||||
# multiple task tool calls, so we can't assume a fixed offset.
|
|
||||||
dispatch_idx = idx - 1
|
|
||||||
while dispatch_idx >= 0:
|
|
||||||
candidate = messages[dispatch_idx]
|
|
||||||
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
|
|
||||||
# Accumulate into an existing update for the same
|
|
||||||
# AIMessage (multiple task calls in one response),
|
|
||||||
# or merge fresh from the original message.
|
|
||||||
existing_update = state_updates.get(dispatch_idx)
|
|
||||||
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
|
|
||||||
merged = {
|
|
||||||
**prev,
|
|
||||||
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
|
|
||||||
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
|
|
||||||
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
|
|
||||||
}
|
|
||||||
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
|
|
||||||
break
|
|
||||||
dispatch_idx -= 1
|
|
||||||
idx -= 1
|
|
||||||
|
|
||||||
last = messages[-1]
|
|
||||||
if not isinstance(last, AIMessage):
|
|
||||||
if state_updates:
|
|
||||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
|
||||||
return None
|
|
||||||
|
|
||||||
usage = getattr(last, "usage_metadata", None)
|
|
||||||
if usage:
|
|
||||||
input_token_details = usage.get("input_token_details") or {}
|
|
||||||
output_token_details = usage.get("output_token_details") or {}
|
|
||||||
detail_parts = []
|
|
||||||
if input_token_details:
|
|
||||||
detail_parts.append(f"input_token_details={input_token_details}")
|
|
||||||
if output_token_details:
|
|
||||||
detail_parts.append(f"output_token_details={output_token_details}")
|
|
||||||
detail_suffix = f" {' '.join(detail_parts)}" if detail_parts else ""
|
|
||||||
logger.info(
|
|
||||||
"LLM token usage: input=%s output=%s total=%s%s",
|
|
||||||
usage.get("input_tokens", "?"),
|
|
||||||
usage.get("output_tokens", "?"),
|
|
||||||
usage.get("total_tokens", "?"),
|
|
||||||
detail_suffix,
|
|
||||||
)
|
|
||||||
|
|
||||||
todos = state.get("todos") or []
|
|
||||||
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
|
|
||||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
|
||||||
|
|
||||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
|
||||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
|
|
||||||
|
|
||||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
|
||||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
|
||||||
state_updates[len(messages) - 1] = updated_msg
|
|
||||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
return self._apply(state)
|
return self._log_usage(state)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
return self._apply(state)
|
return self._log_usage(state)
|
||||||
|
|
||||||
|
def _log_usage(self, state: AgentState) -> None:
|
||||||
|
messages = state.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
last = messages[-1]
|
||||||
|
usage = getattr(last, "usage_metadata", None)
|
||||||
|
if usage:
|
||||||
|
logger.info(
|
||||||
|
"LLM token usage: input=%s output=%s total=%s",
|
||||||
|
usage.get("input_tokens", "?"),
|
||||||
|
usage.get("output_tokens", "?"),
|
||||||
|
usage.get("total_tokens", "?"),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
"""Helpers for keeping AIMessage tool-call metadata consistent."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
|
|
||||||
|
|
||||||
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
|
|
||||||
if not isinstance(raw_tool_call, dict):
|
|
||||||
return None
|
|
||||||
|
|
||||||
raw_id = raw_tool_call.get("id")
|
|
||||||
return raw_id if isinstance(raw_id, str) and raw_id else None
|
|
||||||
|
|
||||||
|
|
||||||
def clone_ai_message_with_tool_calls(
|
|
||||||
message: AIMessage,
|
|
||||||
tool_calls: list[dict[str, Any]],
|
|
||||||
*,
|
|
||||||
content: Any | None = None,
|
|
||||||
) -> AIMessage:
|
|
||||||
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
|
|
||||||
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
|
|
||||||
|
|
||||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
|
||||||
if content is not None:
|
|
||||||
update["content"] = content
|
|
||||||
|
|
||||||
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
|
|
||||||
raw_tool_calls = additional_kwargs.get("tool_calls")
|
|
||||||
if isinstance(raw_tool_calls, list):
|
|
||||||
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
|
|
||||||
if synced_raw_tool_calls:
|
|
||||||
additional_kwargs["tool_calls"] = synced_raw_tool_calls
|
|
||||||
else:
|
|
||||||
additional_kwargs.pop("tool_calls", None)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
additional_kwargs.pop("function_call", None)
|
|
||||||
|
|
||||||
update["additional_kwargs"] = additional_kwargs
|
|
||||||
|
|
||||||
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
|
|
||||||
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
|
|
||||||
response_metadata["finish_reason"] = "stop"
|
|
||||||
update["response_metadata"] = response_metadata
|
|
||||||
|
|
||||||
return message.model_copy(update=update)
|
|
||||||
+7
-31
@@ -11,8 +11,6 @@ from langgraph.errors import GraphBubbleUp
|
|||||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||||
@@ -69,7 +67,6 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
def _build_runtime_middlewares(
|
def _build_runtime_middlewares(
|
||||||
*,
|
*,
|
||||||
app_config: AppConfig,
|
|
||||||
include_uploads: bool,
|
include_uploads: bool,
|
||||||
include_dangling_tool_call_patch: bool,
|
include_dangling_tool_call_patch: bool,
|
||||||
lazy_init: bool = True,
|
lazy_init: bool = True,
|
||||||
@@ -94,10 +91,12 @@ def _build_runtime_middlewares(
|
|||||||
|
|
||||||
middlewares.append(DanglingToolCallMiddleware())
|
middlewares.append(DanglingToolCallMiddleware())
|
||||||
|
|
||||||
middlewares.append(LLMErrorHandlingMiddleware(app_config=app_config))
|
middlewares.append(LLMErrorHandlingMiddleware())
|
||||||
|
|
||||||
# Guardrail middleware (if configured)
|
# Guardrail middleware (if configured)
|
||||||
guardrails_config = app_config.guardrails
|
from deerflow.config.guardrails_config import get_guardrails_config
|
||||||
|
|
||||||
|
guardrails_config = get_guardrails_config()
|
||||||
if guardrails_config.enabled and guardrails_config.provider:
|
if guardrails_config.enabled and guardrails_config.provider:
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@@ -126,42 +125,19 @@ def _build_runtime_middlewares(
|
|||||||
return middlewares
|
return middlewares
|
||||||
|
|
||||||
|
|
||||||
def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
|
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||||
return _build_runtime_middlewares(
|
return _build_runtime_middlewares(
|
||||||
app_config=app_config,
|
|
||||||
include_uploads=True,
|
include_uploads=True,
|
||||||
include_dangling_tool_call_patch=True,
|
include_dangling_tool_call_patch=True,
|
||||||
lazy_init=lazy_init,
|
lazy_init=lazy_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_subagent_runtime_middlewares(
|
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||||
*,
|
|
||||||
app_config: AppConfig | None = None,
|
|
||||||
model_name: str | None = None,
|
|
||||||
lazy_init: bool = True,
|
|
||||||
) -> list[AgentMiddleware]:
|
|
||||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||||
if app_config is None:
|
return _build_runtime_middlewares(
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
app_config = get_app_config()
|
|
||||||
|
|
||||||
middlewares = _build_runtime_middlewares(
|
|
||||||
app_config=app_config,
|
|
||||||
include_uploads=False,
|
include_uploads=False,
|
||||||
include_dangling_tool_call_patch=True,
|
include_dangling_tool_call_patch=True,
|
||||||
lazy_init=lazy_init,
|
lazy_init=lazy_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_name is None and app_config.models:
|
|
||||||
model_name = app_config.models[0].name
|
|
||||||
|
|
||||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
|
||||||
if model_config is not None and model_config.supports_vision:
|
|
||||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
|
||||||
|
|
||||||
middlewares.append(ViewImageMiddleware())
|
|
||||||
|
|
||||||
return middlewares
|
|
||||||
|
|||||||
@@ -263,25 +263,21 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
files_message = self._create_files_message(new_files, historical_files)
|
files_message = self._create_files_message(new_files, historical_files)
|
||||||
|
|
||||||
# Extract original content - handle both string and list formats
|
# Extract original content - handle both string and list formats
|
||||||
original_content = last_message.content
|
original_content = ""
|
||||||
if isinstance(original_content, str):
|
if isinstance(last_message.content, str):
|
||||||
# Simple case: string content, just prepend files message
|
original_content = last_message.content
|
||||||
updated_content = f"{files_message}\n\n{original_content}"
|
elif isinstance(last_message.content, list):
|
||||||
elif isinstance(original_content, list):
|
text_parts = []
|
||||||
# Complex case: list content (multimodal), preserve all blocks
|
for block in last_message.content:
|
||||||
# Prepend files message as the first text block
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
files_block = {"type": "text", "text": f"{files_message}\n\n"}
|
text_parts.append(block.get("text", ""))
|
||||||
# Keep all original blocks (including images)
|
original_content = "\n".join(text_parts)
|
||||||
updated_content = [files_block, *original_content]
|
|
||||||
else:
|
|
||||||
# Other types, preserve as-is
|
|
||||||
updated_content = original_content
|
|
||||||
|
|
||||||
# Create new message with combined content.
|
# Create new message with combined content.
|
||||||
# Preserve additional_kwargs (including files metadata) so the frontend
|
# Preserve additional_kwargs (including files metadata) so the frontend
|
||||||
# can read structured file info from the streamed message.
|
# can read structured file info from the streamed message.
|
||||||
updated_message = HumanMessage(
|
updated_message = HumanMessage(
|
||||||
content=updated_content,
|
content=f"{files_message}\n\n{original_content}",
|
||||||
id=last_message.id,
|
id=last_message.id,
|
||||||
name=last_message.name,
|
name=last_message.name,
|
||||||
additional_kwargs=last_message.additional_kwargs,
|
additional_kwargs=last_message.additional_kwargs,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
|
|||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.installer import install_skill_from_archive
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
claim_unique_filename,
|
claim_unique_filename,
|
||||||
delete_file_safe,
|
delete_file_safe,
|
||||||
@@ -264,35 +264,25 @@ class DeerFlowClient:
|
|||||||
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
|
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
|
||||||
"""Copy message additional_kwargs when present."""
|
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
|
||||||
additional_kwargs = getattr(msg, "additional_kwargs", None)
|
|
||||||
if isinstance(additional_kwargs, dict) and additional_kwargs:
|
|
||||||
return dict(additional_kwargs)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
|
||||||
"""Build a ``messages-tuple`` AI text event."""
|
|
||||||
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||||
if usage:
|
if usage:
|
||||||
data["usage_metadata"] = usage
|
data["usage_metadata"] = usage
|
||||||
if additional_kwargs:
|
|
||||||
data["additional_kwargs"] = additional_kwargs
|
|
||||||
return StreamEvent(type="messages-tuple", data=data)
|
return StreamEvent(type="messages-tuple", data=data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
|
||||||
"""Build a ``messages-tuple`` AI tool-calls event."""
|
"""Build a ``messages-tuple`` AI tool-calls event."""
|
||||||
data: dict[str, Any] = {
|
return StreamEvent(
|
||||||
"type": "ai",
|
type="messages-tuple",
|
||||||
"content": "",
|
data={
|
||||||
"id": msg_id,
|
"type": "ai",
|
||||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
"content": "",
|
||||||
}
|
"id": msg_id,
|
||||||
if additional_kwargs:
|
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||||
data["additional_kwargs"] = additional_kwargs
|
},
|
||||||
return StreamEvent(type="messages-tuple", data=data)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
|
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
|
||||||
@@ -317,30 +307,19 @@ class DeerFlowClient:
|
|||||||
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
||||||
if getattr(msg, "usage_metadata", None):
|
if getattr(msg, "usage_metadata", None):
|
||||||
d["usage_metadata"] = msg.usage_metadata
|
d["usage_metadata"] = msg.usage_metadata
|
||||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
|
||||||
d["additional_kwargs"] = additional_kwargs
|
|
||||||
return d
|
return d
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
d = {
|
return {
|
||||||
"type": "tool",
|
"type": "tool",
|
||||||
"content": DeerFlowClient._extract_text(msg.content),
|
"content": DeerFlowClient._extract_text(msg.content),
|
||||||
"name": getattr(msg, "name", None),
|
"name": getattr(msg, "name", None),
|
||||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||||
"id": getattr(msg, "id", None),
|
"id": getattr(msg, "id", None),
|
||||||
}
|
}
|
||||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
|
||||||
d["additional_kwargs"] = additional_kwargs
|
|
||||||
return d
|
|
||||||
if isinstance(msg, HumanMessage):
|
if isinstance(msg, HumanMessage):
|
||||||
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
|
||||||
d["additional_kwargs"] = additional_kwargs
|
|
||||||
return d
|
|
||||||
if isinstance(msg, SystemMessage):
|
if isinstance(msg, SystemMessage):
|
||||||
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
|
||||||
d["additional_kwargs"] = additional_kwargs
|
|
||||||
return d
|
|
||||||
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -563,7 +542,6 @@ class DeerFlowClient:
|
|||||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
|
|
||||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||||
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
||||||
"""
|
"""
|
||||||
@@ -586,7 +564,6 @@ class DeerFlowClient:
|
|||||||
# in both the final ``messages`` chunk and the values snapshot —
|
# in both the final ``messages`` chunk and the values snapshot —
|
||||||
# count it only on whichever arrives first.
|
# count it only on whichever arrives first.
|
||||||
counted_usage_ids: set[str] = set()
|
counted_usage_ids: set[str] = set()
|
||||||
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
|
|
||||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||||
|
|
||||||
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
||||||
@@ -616,20 +593,6 @@ class DeerFlowClient:
|
|||||||
"total_tokens": total_tokens,
|
"total_tokens": total_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
|
|
||||||
if not additional_kwargs:
|
|
||||||
return None
|
|
||||||
if not msg_id:
|
|
||||||
return additional_kwargs
|
|
||||||
|
|
||||||
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
|
|
||||||
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
|
|
||||||
if not delta:
|
|
||||||
return None
|
|
||||||
|
|
||||||
sent.update(delta)
|
|
||||||
return delta
|
|
||||||
|
|
||||||
for item in self._agent.stream(
|
for item in self._agent.stream(
|
||||||
state,
|
state,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -657,31 +620,17 @@ class DeerFlowClient:
|
|||||||
|
|
||||||
if isinstance(msg_chunk, AIMessage):
|
if isinstance(msg_chunk, AIMessage):
|
||||||
text = self._extract_text(msg_chunk.content)
|
text = self._extract_text(msg_chunk.content)
|
||||||
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
|
|
||||||
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
||||||
sent_additional_kwargs = False
|
|
||||||
|
|
||||||
if text:
|
if text:
|
||||||
if msg_id:
|
if msg_id:
|
||||||
streamed_ids.add(msg_id)
|
streamed_ids.add(msg_id)
|
||||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||||
yield self._ai_text_event(
|
|
||||||
msg_id,
|
|
||||||
text,
|
|
||||||
counted_usage,
|
|
||||||
additional_kwargs_delta,
|
|
||||||
)
|
|
||||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
|
||||||
|
|
||||||
if msg_chunk.tool_calls:
|
if msg_chunk.tool_calls:
|
||||||
if msg_id:
|
if msg_id:
|
||||||
streamed_ids.add(msg_id)
|
streamed_ids.add(msg_id)
|
||||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
|
||||||
yield self._ai_tool_calls_event(
|
|
||||||
msg_id,
|
|
||||||
msg_chunk.tool_calls,
|
|
||||||
additional_kwargs_delta,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(msg_chunk, ToolMessage):
|
elif isinstance(msg_chunk, ToolMessage):
|
||||||
if msg_id:
|
if msg_id:
|
||||||
@@ -704,45 +653,17 @@ class DeerFlowClient:
|
|||||||
if msg_id and msg_id in streamed_ids:
|
if msg_id and msg_id in streamed_ids:
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
|
||||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
|
||||||
if additional_kwargs_delta:
|
|
||||||
# Metadata-only follow-up: ``messages-tuple`` has no
|
|
||||||
# dedicated attribution event, so clients should
|
|
||||||
# merge this empty-content AI event by message id
|
|
||||||
# and ignore it for text rendering.
|
|
||||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
||||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
|
||||||
sent_additional_kwargs = False
|
|
||||||
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
|
||||||
yield self._ai_tool_calls_event(
|
|
||||||
msg_id,
|
|
||||||
msg.tool_calls,
|
|
||||||
additional_kwargs_delta,
|
|
||||||
)
|
|
||||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
|
||||||
|
|
||||||
text = self._extract_text(msg.content)
|
text = self._extract_text(msg.content)
|
||||||
if text:
|
if text:
|
||||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||||
yield self._ai_text_event(
|
|
||||||
msg_id,
|
|
||||||
text,
|
|
||||||
counted_usage,
|
|
||||||
additional_kwargs_delta,
|
|
||||||
)
|
|
||||||
elif msg_id:
|
|
||||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
|
||||||
if not additional_kwargs_delta:
|
|
||||||
continue
|
|
||||||
# See the metadata-only follow-up convention above.
|
|
||||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
|
||||||
|
|
||||||
elif isinstance(msg, ToolMessage):
|
elif isinstance(msg, ToolMessage):
|
||||||
yield self._tool_message_event(msg)
|
yield self._tool_message_event(msg)
|
||||||
@@ -802,10 +723,6 @@ class DeerFlowClient:
|
|||||||
Dict with "models" key containing list of model info dicts,
|
Dict with "models" key containing list of model info dicts,
|
||||||
matching the Gateway API ``ModelsListResponse`` schema.
|
matching the Gateway API ``ModelsListResponse`` schema.
|
||||||
"""
|
"""
|
||||||
token_usage_enabled = getattr(getattr(self._app_config, "token_usage", None), "enabled", False)
|
|
||||||
if not isinstance(token_usage_enabled, bool):
|
|
||||||
token_usage_enabled = False
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"models": [
|
"models": [
|
||||||
{
|
{
|
||||||
@@ -817,8 +734,7 @@ class DeerFlowClient:
|
|||||||
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
|
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
|
||||||
}
|
}
|
||||||
for model in self._app_config.models
|
for model in self._app_config.models
|
||||||
],
|
]
|
||||||
"token_usage": {"enabled": token_usage_enabled},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def list_skills(self, enabled_only: bool = False) -> dict:
|
def list_skills(self, enabled_only: bool = False) -> dict:
|
||||||
@@ -831,6 +747,8 @@ class DeerFlowClient:
|
|||||||
Dict with "skills" key containing list of skill info dicts,
|
Dict with "skills" key containing list of skill info dicts,
|
||||||
matching the Gateway API ``SkillsListResponse`` schema.
|
matching the Gateway API ``SkillsListResponse`` schema.
|
||||||
"""
|
"""
|
||||||
|
from deerflow.skills.loader import load_skills
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"skills": [
|
"skills": [
|
||||||
{
|
{
|
||||||
@@ -840,7 +758,7 @@ class DeerFlowClient:
|
|||||||
"category": s.category,
|
"category": s.category,
|
||||||
"enabled": s.enabled,
|
"enabled": s.enabled,
|
||||||
}
|
}
|
||||||
for s in get_or_new_skill_storage().load_skills(enabled_only=enabled_only)
|
for s in load_skills(enabled_only=enabled_only)
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -949,9 +867,9 @@ class DeerFlowClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Skill info dict, or None if not found.
|
Skill info dict, or None if not found.
|
||||||
"""
|
"""
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.loader import load_skills
|
||||||
|
|
||||||
skill = next((s for s in get_or_new_skill_storage().load_skills(enabled_only=False) if s.name == name), None)
|
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
|
||||||
if skill is None:
|
if skill is None:
|
||||||
return None
|
return None
|
||||||
return {
|
return {
|
||||||
@@ -976,9 +894,9 @@ class DeerFlowClient:
|
|||||||
ValueError: If the skill is not found.
|
ValueError: If the skill is not found.
|
||||||
OSError: If the config file cannot be written.
|
OSError: If the config file cannot be written.
|
||||||
"""
|
"""
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.loader import load_skills
|
||||||
|
|
||||||
skills = get_or_new_skill_storage().load_skills(enabled_only=False)
|
skills = load_skills(enabled_only=False)
|
||||||
skill = next((s for s in skills if s.name == name), None)
|
skill = next((s for s in skills if s.name == name), None)
|
||||||
if skill is None:
|
if skill is None:
|
||||||
raise ValueError(f"Skill '{name}' not found")
|
raise ValueError(f"Skill '{name}' not found")
|
||||||
@@ -1001,7 +919,7 @@ class DeerFlowClient:
|
|||||||
self._agent_config_key = None
|
self._agent_config_key = None
|
||||||
reload_extensions_config()
|
reload_extensions_config()
|
||||||
|
|
||||||
updated = next((s for s in get_or_new_skill_storage().load_skills(enabled_only=False) if s.name == name), None)
|
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
|
||||||
if updated is None:
|
if updated is None:
|
||||||
raise RuntimeError(f"Skill '{name}' disappeared after update")
|
raise RuntimeError(f"Skill '{name}' disappeared after update")
|
||||||
return {
|
return {
|
||||||
@@ -1025,7 +943,7 @@ class DeerFlowClient:
|
|||||||
FileNotFoundError: If the file does not exist.
|
FileNotFoundError: If the file does not exist.
|
||||||
ValueError: If the file is invalid.
|
ValueError: If the file is invalid.
|
||||||
"""
|
"""
|
||||||
return get_or_new_skill_storage().install_skill_from_archive(skill_path)
|
return install_skill_from_archive(skill_path)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Public API — memory management
|
# Public API — memory management
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import base64
|
import base64
|
||||||
import errno
|
|
||||||
import logging
|
import logging
|
||||||
import shlex
|
import shlex
|
||||||
import threading
|
import threading
|
||||||
@@ -7,14 +6,11 @@ import uuid
|
|||||||
|
|
||||||
from agent_sandbox import Sandbox as AioSandboxClient
|
from agent_sandbox import Sandbox as AioSandboxClient
|
||||||
|
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
|
||||||
from deerflow.sandbox.sandbox import Sandbox
|
from deerflow.sandbox.sandbox import Sandbox
|
||||||
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
|
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MAX_DOWNLOAD_SIZE = 100 * 1024 * 1024 # 100 MB
|
|
||||||
|
|
||||||
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
|
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
|
||||||
|
|
||||||
|
|
||||||
@@ -52,12 +48,6 @@ class AioSandbox(Sandbox):
|
|||||||
self._home_dir = context.home_dir
|
self._home_dir = context.home_dir
|
||||||
return self._home_dir
|
return self._home_dir
|
||||||
|
|
||||||
# Default no_change_timeout for exec_command (seconds). Matches the
|
|
||||||
# client-level timeout so that long-running commands which produce no
|
|
||||||
# output are not prematurely terminated by the sandbox's built-in 120 s
|
|
||||||
# default.
|
|
||||||
_DEFAULT_NO_CHANGE_TIMEOUT = 600
|
|
||||||
|
|
||||||
def execute_command(self, command: str) -> str:
|
def execute_command(self, command: str) -> str:
|
||||||
"""Execute a shell command in the sandbox.
|
"""Execute a shell command in the sandbox.
|
||||||
|
|
||||||
@@ -76,13 +66,13 @@ class AioSandbox(Sandbox):
|
|||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
try:
|
try:
|
||||||
result = self._client.shell.exec_command(command=command, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
|
result = self._client.shell.exec_command(command=command)
|
||||||
output = result.data.output if result.data else ""
|
output = result.data.output if result.data else ""
|
||||||
|
|
||||||
if output and _ERROR_OBSERVATION_SIGNATURE in output:
|
if output and _ERROR_OBSERVATION_SIGNATURE in output:
|
||||||
logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
|
logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
|
||||||
fresh_id = str(uuid.uuid4())
|
fresh_id = str(uuid.uuid4())
|
||||||
result = self._client.shell.exec_command(command=command, id=fresh_id, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
|
result = self._client.shell.exec_command(command=command, id=fresh_id)
|
||||||
output = result.data.output if result.data else ""
|
output = result.data.output if result.data else ""
|
||||||
|
|
||||||
return output if output else "(no output)"
|
return output if output else "(no output)"
|
||||||
@@ -106,49 +96,6 @@ class AioSandbox(Sandbox):
|
|||||||
logger.error(f"Failed to read file in sandbox: {e}")
|
logger.error(f"Failed to read file in sandbox: {e}")
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
def download_file(self, path: str) -> bytes:
|
|
||||||
"""Download file bytes from the sandbox.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
PermissionError: If the path contains '..' traversal segments or is
|
|
||||||
outside ``VIRTUAL_PATH_PREFIX``.
|
|
||||||
OSError: If the file cannot be retrieved from the sandbox.
|
|
||||||
"""
|
|
||||||
# Reject path traversal before sending to the container API.
|
|
||||||
# LocalSandbox gets this implicitly via _resolve_path;
|
|
||||||
# here the path is forwarded verbatim so we must check explicitly.
|
|
||||||
normalised = path.replace("\\", "/")
|
|
||||||
for segment in normalised.split("/"):
|
|
||||||
if segment == "..":
|
|
||||||
logger.error(f"Refused download due to path traversal: {path}")
|
|
||||||
raise PermissionError(f"Access denied: path traversal detected in '{path}'")
|
|
||||||
|
|
||||||
stripped_path = normalised.lstrip("/")
|
|
||||||
allowed_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
|
||||||
if stripped_path != allowed_prefix and not stripped_path.startswith(f"{allowed_prefix}/"):
|
|
||||||
logger.error("Refused download outside allowed directory: path=%s, allowed_prefix=%s", path, VIRTUAL_PATH_PREFIX)
|
|
||||||
raise PermissionError(f"Access denied: path must be under '{VIRTUAL_PATH_PREFIX}': '{path}'")
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
chunks: list[bytes] = []
|
|
||||||
total = 0
|
|
||||||
for chunk in self._client.file.download_file(path=path):
|
|
||||||
total += len(chunk)
|
|
||||||
if total > _MAX_DOWNLOAD_SIZE:
|
|
||||||
raise OSError(
|
|
||||||
errno.EFBIG,
|
|
||||||
f"File exceeds maximum download size of {_MAX_DOWNLOAD_SIZE} bytes",
|
|
||||||
path,
|
|
||||||
)
|
|
||||||
chunks.append(chunk)
|
|
||||||
return b"".join(chunks)
|
|
||||||
except OSError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to download file in sandbox: {e}")
|
|
||||||
raise OSError(f"Failed to download file '{path}' from sandbox: {e}") from e
|
|
||||||
|
|
||||||
def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
|
def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
|
||||||
"""List the contents of a directory in the sandbox.
|
"""List the contents of a directory in the sandbox.
|
||||||
|
|
||||||
@@ -161,7 +108,7 @@ class AioSandbox(Sandbox):
|
|||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
try:
|
try:
|
||||||
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500", no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
|
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
|
||||||
output = result.data.output if result.data else ""
|
output = result.data.output if result.data else ""
|
||||||
if output:
|
if output:
|
||||||
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
||||||
|
|||||||
@@ -120,16 +120,6 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
||||||
self._start_idle_checker()
|
self._start_idle_checker()
|
||||||
|
|
||||||
@property
|
|
||||||
def uses_thread_data_mounts(self) -> bool:
|
|
||||||
"""Whether thread workspace/uploads/outputs are visible via mounts.
|
|
||||||
|
|
||||||
Local container backends bind-mount the thread data directories, so files
|
|
||||||
written by the gateway are already visible when the sandbox starts.
|
|
||||||
Remote backends may require explicit file sync.
|
|
||||||
"""
|
|
||||||
return isinstance(self._backend, LocalContainerBackend)
|
|
||||||
|
|
||||||
# ── Factory methods ──────────────────────────────────────────────────
|
# ── Factory methods ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def _create_backend(self) -> SandboxBackend:
|
def _create_backend(self) -> SandboxBackend:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shlex
|
|
||||||
import subprocess
|
import subprocess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -87,88 +86,6 @@ def _format_container_mount(runtime: str, host_path: str, container_path: str, r
|
|||||||
return ["-v", mount_spec]
|
return ["-v", mount_spec]
|
||||||
|
|
||||||
|
|
||||||
def _redact_container_command_for_log(cmd: list[str]) -> list[str]:
|
|
||||||
"""Return a Docker/Container command with environment values redacted."""
|
|
||||||
redacted: list[str] = []
|
|
||||||
redact_next_env = False
|
|
||||||
|
|
||||||
for arg in cmd:
|
|
||||||
if redact_next_env:
|
|
||||||
if "=" in arg:
|
|
||||||
key = arg.split("=", 1)[0]
|
|
||||||
redacted.append(f"{key}=<redacted>" if key else "<redacted>")
|
|
||||||
else:
|
|
||||||
redacted.append(arg)
|
|
||||||
redact_next_env = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if arg in {"-e", "--env"}:
|
|
||||||
redacted.append(arg)
|
|
||||||
redact_next_env = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
if arg.startswith("--env="):
|
|
||||||
value = arg.removeprefix("--env=")
|
|
||||||
if "=" in value:
|
|
||||||
key = value.split("=", 1)[0]
|
|
||||||
redacted.append(f"--env={key}=<redacted>" if key else "--env=<redacted>")
|
|
||||||
else:
|
|
||||||
redacted.append(arg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
redacted.append(arg)
|
|
||||||
|
|
||||||
return redacted
|
|
||||||
|
|
||||||
|
|
||||||
def _format_container_command_for_log(cmd: list[str]) -> str:
|
|
||||||
if os.name == "nt":
|
|
||||||
return subprocess.list2cmdline(cmd)
|
|
||||||
return shlex.join(cmd)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_sandbox_host(host: str) -> str:
|
|
||||||
return host.strip().lower()
|
|
||||||
|
|
||||||
|
|
||||||
def _is_ipv6_loopback_sandbox_host(host: str) -> bool:
|
|
||||||
return _normalize_sandbox_host(host) in {"::1", "[::1]"}
|
|
||||||
|
|
||||||
|
|
||||||
def _is_loopback_sandbox_host(host: str) -> bool:
|
|
||||||
return _normalize_sandbox_host(host) in {"", "localhost", "127.0.0.1", "::1", "[::1]"}
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str | None = None) -> str:
|
|
||||||
"""Choose the host interface for legacy Docker ``-p`` sandbox publishing.
|
|
||||||
|
|
||||||
Bare-metal/local runs talk to sandboxes through localhost and should not
|
|
||||||
expose the sandbox HTTP API on every host interface. Docker-outside-of-
|
|
||||||
Docker deployments commonly use ``host.docker.internal`` from another
|
|
||||||
container; keep their legacy broad bind unless operators opt into a
|
|
||||||
narrower bind with ``DEER_FLOW_SANDBOX_BIND_HOST``. When operators choose
|
|
||||||
an IPv6 loopback sandbox host, bind Docker to IPv6 loopback as well so the
|
|
||||||
advertised sandbox URL and published socket use the same address family.
|
|
||||||
"""
|
|
||||||
explicit_bind = bind_host if bind_host is not None else os.environ.get("DEER_FLOW_SANDBOX_BIND_HOST")
|
|
||||||
if explicit_bind is not None:
|
|
||||||
explicit_bind = explicit_bind.strip()
|
|
||||||
if explicit_bind:
|
|
||||||
logger.debug("Docker sandbox bind: %s (explicit bind host override)", explicit_bind)
|
|
||||||
return explicit_bind
|
|
||||||
|
|
||||||
host = sandbox_host if sandbox_host is not None else os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
|
|
||||||
if _is_ipv6_loopback_sandbox_host(host):
|
|
||||||
logger.debug("Docker sandbox bind: [::1] (IPv6 loopback sandbox host)")
|
|
||||||
return "[::1]"
|
|
||||||
if _is_loopback_sandbox_host(host):
|
|
||||||
logger.debug("Docker sandbox bind: 127.0.0.1 (loopback default)")
|
|
||||||
return "127.0.0.1"
|
|
||||||
|
|
||||||
logger.debug("Docker sandbox bind: 0.0.0.0 (non-loopback sandbox host compatibility)")
|
|
||||||
return "0.0.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
class LocalContainerBackend(SandboxBackend):
|
class LocalContainerBackend(SandboxBackend):
|
||||||
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
||||||
|
|
||||||
@@ -507,17 +424,12 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
if self._runtime == "docker":
|
if self._runtime == "docker":
|
||||||
cmd.extend(["--security-opt", "seccomp=unconfined"])
|
cmd.extend(["--security-opt", "seccomp=unconfined"])
|
||||||
|
|
||||||
if self._runtime == "docker":
|
|
||||||
port_mapping = f"{_resolve_docker_bind_host()}:{port}:8080"
|
|
||||||
else:
|
|
||||||
port_mapping = f"{port}:8080"
|
|
||||||
|
|
||||||
cmd.extend(
|
cmd.extend(
|
||||||
[
|
[
|
||||||
"--rm",
|
"--rm",
|
||||||
"-d",
|
"-d",
|
||||||
"-p",
|
"-p",
|
||||||
port_mapping,
|
f"{port}:8080",
|
||||||
"--name",
|
"--name",
|
||||||
container_name,
|
container_name,
|
||||||
]
|
]
|
||||||
@@ -552,8 +464,7 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
|
|
||||||
cmd.append(self._image)
|
cmd.append(self._image)
|
||||||
|
|
||||||
log_cmd = _format_container_command_for_log(_redact_container_command_for_log(cmd))
|
logger.info(f"Starting container using {self._runtime}: {' '.join(cmd)}")
|
||||||
logger.info(f"Starting container using {self._runtime}: {log_cmd}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ import logging
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
from .backend import SandboxBackend
|
from .backend import SandboxBackend
|
||||||
from .sandbox_info import SandboxInfo
|
from .sandbox_info import SandboxInfo
|
||||||
|
|
||||||
@@ -86,52 +84,8 @@ class RemoteSandboxBackend(SandboxBackend):
|
|||||||
"""
|
"""
|
||||||
return self._provisioner_discover(sandbox_id)
|
return self._provisioner_discover(sandbox_id)
|
||||||
|
|
||||||
def list_running(self) -> list[SandboxInfo]:
|
|
||||||
"""Return all sandboxes currently managed by the provisioner.
|
|
||||||
|
|
||||||
Calls ``GET /api/sandboxes`` so that ``AioSandboxProvider._reconcile_orphans()``
|
|
||||||
can adopt pods that were created by a previous process and were never
|
|
||||||
explicitly destroyed.
|
|
||||||
Without this, a process restart silently orphans all existing k8s Pods —
|
|
||||||
they stay running forever because the idle checker only
|
|
||||||
tracks in-process state.
|
|
||||||
"""
|
|
||||||
return self._provisioner_list()
|
|
||||||
|
|
||||||
# ── Provisioner API calls ─────────────────────────────────────────────
|
# ── Provisioner API calls ─────────────────────────────────────────────
|
||||||
|
|
||||||
def _provisioner_list(self) -> list[SandboxInfo]:
|
|
||||||
"""GET /api/sandboxes → list all running sandboxes."""
|
|
||||||
try:
|
|
||||||
resp = requests.get(f"{self._provisioner_url}/api/sandboxes", timeout=10)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
logger.warning("Provisioner list_running returned non-dict payload: %r", type(data))
|
|
||||||
return []
|
|
||||||
|
|
||||||
sandboxes = data.get("sandboxes", [])
|
|
||||||
if not isinstance(sandboxes, list):
|
|
||||||
logger.warning("Provisioner list_running returned non-list sandboxes: %r", type(sandboxes))
|
|
||||||
return []
|
|
||||||
|
|
||||||
infos: list[SandboxInfo] = []
|
|
||||||
for sandbox in sandboxes:
|
|
||||||
if not isinstance(sandbox, dict):
|
|
||||||
logger.warning("Provisioner list_running entry is not a dict: %r", type(sandbox))
|
|
||||||
continue
|
|
||||||
|
|
||||||
sandbox_id = sandbox.get("sandbox_id")
|
|
||||||
sandbox_url = sandbox.get("sandbox_url")
|
|
||||||
if isinstance(sandbox_id, str) and sandbox_id and isinstance(sandbox_url, str) and sandbox_url:
|
|
||||||
infos.append(SandboxInfo(sandbox_id=sandbox_id, sandbox_url=sandbox_url))
|
|
||||||
|
|
||||||
logger.info("Provisioner list_running: %d sandbox(es) found", len(infos))
|
|
||||||
return infos
|
|
||||||
except requests.RequestException as exc:
|
|
||||||
logger.warning("Provisioner list_running failed: %s", exc)
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||||
"""POST /api/sandboxes → create Pod + Service."""
|
"""POST /api/sandboxes → create Pod + Service."""
|
||||||
try:
|
try:
|
||||||
@@ -140,7 +94,6 @@ class RemoteSandboxBackend(SandboxBackend):
|
|||||||
json={
|
json={
|
||||||
"sandbox_id": sandbox_id,
|
"sandbox_id": sandbox_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"user_id": get_effective_user_id(),
|
|
||||||
},
|
},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ class JinaClient:
|
|||||||
|
|
||||||
return response.text
|
return response.text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Request to Jina API failed: {type(e).__name__}: {e}"
|
error_message = f"Request to Jina API failed: {str(e)}"
|
||||||
logger.warning(error_message)
|
logger.exception(error_message)
|
||||||
return f"Error: {error_message}"
|
return f"Error: {error_message}"
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
|
|
||||||
from deerflow.community.jina_ai.jina_client import JinaClient
|
from deerflow.community.jina_ai.jina_client import JinaClient
|
||||||
@@ -28,5 +26,5 @@ async def web_fetch_tool(url: str) -> str:
|
|||||||
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||||
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||||
return html_content
|
return html_content
|
||||||
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
|
article = readability_extractor.extract_article(html_content)
|
||||||
return article.to_markdown()[:4096]
|
return article.to_markdown()[:4096]
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
from .tools import web_search_tool
|
|
||||||
|
|
||||||
__all__ = ["web_search_tool"]
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user