Compare commits

..

2 Commits

Author SHA1 Message Date
rayhpeng 9ed83c84dc fix(runtime): use pass for protocol stubs 2026-06-01 15:31:46 +08:00
rayhpeng 30bb2d5149 refactor(runtime): add run DDD boundary skeleton 2026-06-01 09:22:32 +08:00
258 changed files with 3503 additions and 22146 deletions
-1
View File
@@ -21,7 +21,6 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# DEEPSEEK_API_KEY=your-deepseek-api-key
# NOVITA_API_KEY=your-novita-api-key # OpenAI-compatible, see https://novita.ai
# MINIMAX_API_KEY=your-minimax-api-key # OpenAI-compatible, see https://platform.minimax.io
# STEPFUN_API_KEY=your-stepfun-api-key # OpenAI-compatible, see https://platform.stepfun.com
# VLLM_API_KEY=your-vllm-api-key # OpenAI-compatible
# FEISHU_APP_ID=your-feishu-app-id
# FEISHU_APP_SECRET=your-feishu-app-secret
-159
View File
@@ -1,159 +0,0 @@
name: 🐛 Bug report
description: Report something that isn't working so maintainers can reproduce and fix it.
title: "[bug] "
labels: ["bug"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to file a bug. A clear, reproducible report is the
single biggest factor in how fast it gets fixed.
Please fill in every required field — especially **reproduction steps** and **logs**.
- type: checkboxes
id: preflight
attributes:
label: Before you start
options:
- label: I searched [existing issues](https://github.com/bytedance/deer-flow/issues?q=is%3Aissue) and this is not a duplicate.
required: true
- label: I can reproduce this on the latest `main`.
required: false
- type: input
id: summary
attributes:
label: Problem summary
description: One sentence describing the bug.
placeholder: e.g. make dev fails to start the gateway service
validations:
required: true
- type: dropdown
id: areas
attributes:
label: Affected area(s)
description: Which part of DeerFlow does this touch? Select all that apply.
multiple: true
options:
- Frontend (UI / Next.js)
- Backend API (gateway / endpoints / SSE)
- Agents / LangGraph (graph, prompts, langgraph.json)
- Sandbox / Docker
- Skills
- MCP
- Config / setup (make, config.yaml, env)
- Docs
- Not sure
validations:
required: true
- type: textarea
id: actual
attributes:
label: What happened?
description: The actual behavior. Include the key error lines verbatim.
placeholder: When I do X, I expected Y but I got Z.
validations:
required: true
- type: textarea
id: expected
attributes:
label: Expected behavior
placeholder: What did you expect to happen instead?
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: Exact commands and sequence. Minimal steps that reliably reproduce the problem.
placeholder: |
1. make check
2. make install
3. make dev
4. ...
validations:
required: true
- type: textarea
id: logs
attributes:
label: Relevant logs
description: Paste key lines from logs (for example `logs/gateway.log`, `logs/frontend.log`). Redact secrets.
render: shell
validations:
required: true
- type: dropdown
id: run_mode
attributes:
label: How are you running DeerFlow?
options:
- Local (make dev)
- Docker (make docker-start)
- CI
- Other
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
options:
- macOS
- Linux
- Windows
- Other
validations:
required: true
- type: input
id: platform_details
attributes:
label: Platform details
description: Architecture and shell, if relevant.
placeholder: e.g. arm64, zsh
- type: input
id: python_version
attributes:
label: Python version
placeholder: e.g. Python 3.12.9
- type: input
id: node_version
attributes:
label: Node.js version
placeholder: e.g. v22.11.0
- type: input
id: pnpm_version
attributes:
label: pnpm version
placeholder: e.g. 10.26.2
- type: input
id: uv_version
attributes:
label: uv version
placeholder: e.g. 0.7.20
- type: textarea
id: git_info
attributes:
label: Git state
description: Output of `git branch --show-current` and the latest commit SHA.
placeholder: |
branch: feature/my-branch
commit: abcdef1
- type: textarea
id: additional
attributes:
label: Additional context
description: Screenshots, related issues, config snippets (redacted), or anything else that helps triage.
-11
View File
@@ -1,11 +0,0 @@
blank_issues_enabled: false
contact_links:
- name: 💬 Questions & usage help
url: https://github.com/bytedance/deer-flow/discussions/categories/q-a
about: "How do I use X? Why does Y behave like that? Ask in Discussions — it gets answered faster and stays searchable."
- name: 💡 Ideas & proposals
url: https://github.com/bytedance/deer-flow/discussions/categories/ideas
about: Have a half-formed idea? Float it in Discussions before opening a formal feature request.
- name: 🔒 Report a security vulnerability
url: https://github.com/bytedance/deer-flow/security/policy
about: Do not open a public issue for security problems. Follow the security policy instead.
@@ -1,67 +0,0 @@
name: 💡 Feature request
description: Propose a new capability or an improvement to an existing one.
title: "[feat] "
labels: ["enhancement"]
body:
- type: markdown
attributes:
value: |
Thanks for the suggestion. For non-trivial features, please open a
[Discussion](https://github.com/bytedance/deer-flow/discussions/categories/ideas)
first to align on scope before writing code.
- type: checkboxes
id: preflight
attributes:
label: Before you start
options:
- label: I searched [existing issues](https://github.com/bytedance/deer-flow/issues?q=is%3Aissue) and this is not a duplicate.
required: true
- type: textarea
id: problem
attributes:
label: Problem / motivation
description: What problem does this solve? What is painful today, or what does it unblock?
placeholder: "I'm always frustrated when ..."
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: Describe the change from a user's / caller's perspective.
validations:
required: true
- type: dropdown
id: areas
attributes:
label: Affected area(s)
description: Which part of DeerFlow would this touch? Select all that apply.
multiple: true
options:
- Frontend (UI / Next.js)
- Backend API (gateway / endpoints / SSE)
- Agents / LangGraph (graph, prompts, langgraph.json)
- Sandbox / Docker
- Skills
- MCP
- Config / setup
- Docs
- Not sure
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives considered
description: Other approaches you weighed and why you discarded them.
- type: textarea
id: additional
attributes:
label: Additional context
description: Mockups, links, related issues, or anything else that helps.
@@ -0,0 +1,128 @@
name: Runtime Information
description: Report runtime/environment details to help reproduce an issue.
title: "[runtime] "
labels:
- needs-triage
body:
- type: markdown
attributes:
value: |
Thanks for sharing runtime details.
Complete this form so maintainers can quickly reproduce and diagnose the problem.
- type: input
id: summary
attributes:
label: Problem summary
description: Short summary of the issue.
placeholder: e.g. make dev fails to start gateway service
validations:
required: true
- type: textarea
id: expected
attributes:
label: Expected behavior
placeholder: What did you expect to happen?
validations:
required: true
- type: textarea
id: actual
attributes:
label: Actual behavior
placeholder: What happened instead? Include key error lines.
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
options:
- macOS
- Linux
- Windows
- Other
validations:
required: true
- type: input
id: platform_details
attributes:
label: Platform details
description: Add architecture and shell if relevant.
placeholder: e.g. arm64, zsh
- type: input
id: python_version
attributes:
label: Python version
placeholder: e.g. Python 3.12.9
- type: input
id: node_version
attributes:
label: Node.js version
placeholder: e.g. v23.11.0
- type: input
id: pnpm_version
attributes:
label: pnpm version
placeholder: e.g. 10.26.2
- type: input
id: uv_version
attributes:
label: uv version
placeholder: e.g. 0.7.20
- type: dropdown
id: run_mode
attributes:
label: How are you running DeerFlow?
options:
- Local (make dev)
- Docker (make docker-dev)
- CI
- Other
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Reproduction steps
description: Provide exact commands and sequence.
placeholder: |
1. make check
2. make install
3. make dev
4. ...
validations:
required: true
- type: textarea
id: logs
attributes:
label: Relevant logs
description: Paste key lines from logs (for example logs/gateway.log, logs/frontend.log).
render: shell
validations:
required: true
- type: textarea
id: git_info
attributes:
label: Git state
description: Share output of git branch and latest commit SHA.
placeholder: |
branch: feature/my-branch
commit: abcdef1
- type: textarea
id: additional
attributes:
label: Additional context
description: Add anything else that might help triage.
-119
View File
@@ -1,119 +0,0 @@
# Declarative label source of truth for DeerFlow.
#
# This file is the single source of truth for repository labels used by the
# auto-labeling workflows (.github/workflows/pr-labeler.yml, pr-triage.yml,
# issue-triage.yml). Auto-labelers can only apply labels that already exist,
# so every label referenced by a workflow MUST be declared here.
#
# Apply with: uv run --with pyyaml python scripts/sync_labels.py [--repo OWNER/NAME]
# CI keeps it in sync via .github/workflows/label-sync.yml (runs on changes here).
#
# Sync is additive/update-only: it creates or updates the labels listed below
# and never deletes labels that are not listed.
#
# Color = 6-digit hex without the leading '#'.
labels:
# ── Type ─────────────────────────────────────────────────────────────────
# Mostly GitHub defaults; declared here so colors/descriptions stay stable
# and so issue templates can rely on them existing.
- name: bug
color: d73a4a
description: Something isn't working
- name: enhancement
color: a2eeef
description: New feature or request
- name: documentation
color: 0075ca
description: Improvements or additions to documentation
- name: question
color: d876e3
description: Further information is requested
# ── Area (auto, by changed paths — see .github/labeler.yml) ───────────────
# Mirrors the "Surface area" section of the pull request template.
- name: "area:frontend"
color: c5def5
description: Next.js frontend under frontend/
- name: "area:backend"
color: c5def5
description: Gateway / runtime / core backend under backend/
- name: "area:agents"
color: c5def5
description: Agents, subagents, graph wiring, prompts, langgraph.json
- name: "area:sandbox"
color: c5def5
description: Sandboxed execution and docker/
- name: "area:skills"
color: c5def5
description: Skills under skills/ or the skills harness
- name: "area:mcp"
color: c5def5
description: Model Context Protocol integration
- name: "area:ci"
color: c5def5
description: GitHub Actions, CI config, repo tooling
- name: "area:docs"
color: c5def5
description: Documentation and Markdown only
- name: "area:deps"
color: c5def5
description: Dependency manifests / lockfiles
# ── Size (auto, by additions + deletions — see pr-triage.yml) ─────────────
- name: "size/XS"
color: "009900"
description: PR changes < 20 lines
- name: "size/S"
color: 77bb00
description: PR changes 20-100 lines
- name: "size/M"
color: eebb00
description: PR changes 100-300 lines
- name: "size/L"
color: ee9900
description: PR changes 300-700 lines
- name: "size/XL"
color: ee5500
description: PR changes 700+ lines
# ── Risk (auto, by changed paths — see pr-triage.yml) ─────────────────────
- name: "risk:low"
color: 0e8a16
description: "Low risk: docs / i18n / assets only"
- name: "risk:medium"
color: fbca04
description: "Medium risk: regular code changes"
- name: "risk:high"
color: b60205
description: "High risk: backend API, agents, sandbox, auth, deps, CI"
# ── Priority (manual) ─────────────────────────────────────────────────────
- name: P0
color: b60205
description: Critical priority
- name: P1
color: d93f0b
description: Major priority
- name: P2
color: e99695
description: Normal priority
# ── Status (auto + manual) ────────────────────────────────────────────────
- name: needs-triage
color: fef2c0
description: Awaiting maintainer triage
- name: needs-validation
color: d4c5f9
description: Touches front/back contract surface; needs real-path validation
- name: skip-validation
color: cccccc
description: "Maintainer override: do not auto-add needs-validation on this PR"
- name: reviewing
color: 5319e7
description: A maintainer is reviewing this PR
# ── Contributor ───────────────────────────────────────────────────────────
- name: first-time-contributor
color: c2e0c6
description: First contribution to this repository — be welcoming
-14
View File
@@ -59,17 +59,3 @@ Fixes #
Frontend: cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test
Frontend E2E (if you touched frontend/): cd frontend && make test-e2e -->
## AI assistance
<!-- DeerFlow is an AI project — most PRs here use AI coding tools, and that's
welcome. Disclosing it just helps reviewers calibrate how closely to read the
diff. Please fill all three; don't delete the section. -->
**Tool(s) used:** <!-- e.g. Claude Code, Cursor, GitHub Copilot, Codex, Windsurf, or "none" -->
**How you used it:** <!-- e.g. "generated the module from a spec", "autocomplete only",
"AI wrote tests, I wrote the impl". A prompt or conversation link is great too. -->
- [ ] I've read and understand every line of this change and take responsibility for it — it's not unreviewed AI output.
-38
View File
@@ -1,38 +0,0 @@
name: Label Sync
# Keeps repository labels in sync with the declarative source of truth
# (.github/labels.yml). Runs whenever that file changes on main, and can be
# triggered manually. Additive/update-only — never deletes labels.
on:
push:
branches: [main]
paths:
- ".github/labels.yml"
- "scripts/sync_labels.py"
- ".github/workflows/label-sync.yml"
workflow_dispatch:
permissions:
contents: read
issues: write
concurrency:
group: label-sync
cancel-in-progress: false
jobs:
sync:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Sync labels
run: uv run --with pyyaml python scripts/sync_labels.py
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
-108
View File
@@ -1,108 +0,0 @@
name: Replay E2E (front-back contract)
# Guards the front-back contract via record/replay (no API key in CI):
# Layer 1 — backend golden: replay a recorded trace through the real gateway,
# assert the SSE event sequence matches the committed golden.
# Layer 2 — full-stack render: real Next.js frontend + real gateway (replay
# model) + Chromium; assert the replayed turns render in the browser.
# Triggered by changes on EITHER side of the contract so a backend change can no
# longer pass without the frontend-facing checks running.
on:
push:
branches: ["main"]
paths:
- "frontend/**"
- "backend/app/gateway/**"
- "backend/packages/harness/**"
- "backend/tests/fixtures/replay/**"
- "backend/tests/replay_provider.py"
- "backend/tests/_replay_fixture.py"
- "backend/tests/seed_runs_router.py"
- "backend/tests/test_replay_golden.py"
- "backend/scripts/run_replay_gateway.py"
- ".github/workflows/replay-e2e.yml"
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- "frontend/**"
- "backend/app/gateway/**"
- "backend/packages/harness/**"
- "backend/tests/fixtures/replay/**"
- "backend/tests/replay_provider.py"
- "backend/tests/_replay_fixture.py"
- "backend/tests/seed_runs_router.py"
- "backend/tests/test_replay_golden.py"
- "backend/scripts/run_replay_gateway.py"
- ".github/workflows/replay-e2e.yml"
concurrency:
group: replay-e2e-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
backend-replay-golden:
name: Layer 1 — backend golden (no API key)
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install backend dependencies
working-directory: backend
run: uv sync --group dev
- name: Replay golden (backend SSE contract)
working-directory: backend
run: PYTHONPATH=. uv run pytest tests/test_replay_golden.py -v
fullstack-replay-render:
name: Layer 2 — full-stack render (no API key)
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 25
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install backend dependencies (replay gateway)
working-directory: backend
run: uv sync --group dev
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Install Playwright Chromium
working-directory: frontend
run: npx playwright install chromium --with-deps
- name: Full-stack replay render (DOM assertions are the gate)
working-directory: frontend
run: pnpm exec playwright test -c playwright.real-backend.config.ts
- name: Upload report + render artifact
uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: replay-render
path: |
frontend/playwright-report/
frontend/test-results/
retention-days: 7
-223
View File
@@ -1,223 +0,0 @@
name: Triage
# One workflow for all event-driven PR/issue labeling. Replaces the former
# pr-labeler / pr-triage / issue-triage workflows (and drops actions/labeler).
#
# Design notes:
# * All jobs are pure-metadata: they read changed-file lists / PR fields / the
# review payload via the API and write labels. PR code is NEVER checked out
# or executed, so pull_request_target is safe here.
# * Each job only reconciles labels in namespaces IT owns
# (area:* / size/* / risk:* / needs-validation). It never touches labels
# applied by maintainers or other tools (bug, priority, etc.). first-time-
# contributor and reviewing are add-only.
# * State is read LIVE (listFiles + listLabelsOnIssue) at run time, not from
# the (stale) event payload, so rapid synchronize events converge instead
# of thrashing.
on:
pull_request_target:
types: [opened, synchronize, reopened, ready_for_review]
pull_request_review:
types: [submitted]
issues:
types: [opened]
permissions:
contents: read
pull-requests: write
issues: write
jobs:
# ── PR: area / size / risk / needs-validation / first-time ─────────────────
pr-labels:
if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false
runs-on: ubuntu-latest
concurrency:
group: triage-pr-${{ github.event.pull_request.number }}
cancel-in-progress: true
steps:
- name: Apply PR labels from live state
uses: actions/github-script@v8
with:
script: |
const pr = context.payload.pull_request;
const { owner, repo } = context.repo;
const num = pr.number;
// ---- live changed files ----
const files = await github.paginate(github.rest.pulls.listFiles, {
owner, repo, pull_number: num, per_page: 100,
});
const paths = files.map(f => f.filename);
const m = (re) => paths.some(p => re.test(p));
// ---- area: replaces .github/labeler.yml (path -> area) ----
const AREA_RULES = [
['area:frontend', [/^frontend\//]],
['area:backend', [/^backend\/app\//, /^backend\/packages\/harness\/deerflow\/(runtime|persistence|config|tools|guardrails|tracing|models|utils|uploads)\//]],
['area:agents', [/^backend\/packages\/harness\/deerflow\/(agents|subagents|reflection)\//, /(^|\/)langgraph\.json$/, /^backend\/.*\/prompts\//]],
['area:sandbox', [/^docker\//, /^backend\/packages\/harness\/deerflow\/sandbox\//, /(^|\/)Dockerfile$/]],
['area:skills', [/^skills\//, /^backend\/packages\/harness\/deerflow\/skills\//, /^frontend\/src\/core\/skills\//]],
['area:mcp', [/^backend\/packages\/harness\/deerflow\/mcp\//, /^frontend\/src\/core\/mcp\//]],
['area:ci', [/^\.github\//, /^scripts\//]],
['area:docs', [/^docs\//, /\.mdx?$/]],
['area:deps', [/(^|\/)(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/]],
];
const areaLabels = AREA_RULES
.filter(([, res]) => res.some(re => m(re)))
.map(([label]) => label);
// ---- size: additions+deletions, excluding lockfiles/snapshots ----
const EXCLUDE_SIZE = /(^|\/)(uv\.lock|pnpm-lock\.yaml|package-lock\.json)$|\.snap$/;
const churn = files
.filter(f => !EXCLUDE_SIZE.test(f.filename))
.reduce((s, f) => s + (f.additions || 0) + (f.deletions || 0), 0);
const sizeLabel =
churn < 20 ? 'size/XS' :
churn < 100 ? 'size/S' :
churn < 300 ? 'size/M' :
churn < 700 ? 'size/L' : 'size/XL';
// ---- risk ----
const docsOnly = paths.length > 0 && paths.every(p =>
/\.(md|mdx|txt)$/i.test(p) || p.startsWith('docs/') ||
/\.(png|jpe?g|gif|svg|webp|ico)$/i.test(p));
const highRisk =
m(/^backend\/app\/gateway\//) ||
m(/^backend\/packages\/harness\/deerflow\/(agents|subagents|sandbox)\//) ||
m(/(^|\/)langgraph\.json$/) ||
m(/(^|\/)(auth|authz|security)/i) ||
m(/(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/) ||
m(/^docker\//) ||
m(/^\.github\/workflows\//);
const riskLabel = docsOnly ? 'risk:low' : (highRisk ? 'risk:high' : 'risk:medium');
// ---- needs-validation: front/back contract surface ----
const contract =
m(/^backend\/app\/gateway\//) ||
m(/^backend\/packages\/harness\/deerflow\/(agents|subagents)\//) ||
m(/(^|\/)langgraph\.json$/) ||
m(/^frontend\/src\/core\/(api|threads|messages)\//);
// ---- live current labels (NOT the stale event payload) ----
const current = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
owner, repo, issue_number: num, per_page: 100,
})).map(l => l.name);
const hasSkip = current.includes('skip-validation');
// Reconcile ONLY namespaces we own; never touch others.
const owned = (n) =>
n.startsWith('area:') || n.startsWith('size/') ||
n.startsWith('risk:') || n === 'needs-validation';
const desired = new Set([...areaLabels, sizeLabel, riskLabel]);
if (contract && !hasSkip) desired.add('needs-validation');
const toRemove = current.filter(n => owned(n) && !desired.has(n));
const toAdd = [...desired].filter(n => !current.includes(n));
// first-time-contributor: add-only, on opened, real users only.
if (context.payload.action === 'opened' &&
pr.user.type === 'User' &&
['FIRST_TIME_CONTRIBUTOR', 'FIRST_TIMER'].includes(pr.author_association) &&
!current.includes('first-time-contributor')) {
toAdd.push('first-time-contributor');
}
for (const name of toRemove) {
try {
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name });
} catch (e) {
if (e.status !== 404) throw e;
}
}
if (toAdd.length) {
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: toAdd });
}
core.info(`area=[${areaLabels.join(',')}] ${sizeLabel} ${riskLabel} churn=${churn} ` +
`validation=${desired.has('needs-validation')} ` +
`(+${toAdd.join(',') || '-'} / -${toRemove.join(',') || '-'})`);
# ── PR: reviewing label on a maintainer's human review ─────────────────────
reviewing:
if: github.event_name == 'pull_request_review'
runs-on: ubuntu-latest
concurrency:
group: triage-review-${{ github.event.pull_request.number }}
cancel-in-progress: false
steps:
- name: Add reviewing label for maintainer reviews
uses: actions/github-script@v8
with:
script: |
const { owner, repo } = context.repo;
const num = context.payload.pull_request.number;
const review = context.payload.review;
const assoc = review.author_association; // payload field; no API call
const type = review.user && review.user.type;
// author_association is NONE for every automated reviewer
// (Copilot, CodeRabbit, Codex, Sourcery, ...), so this allowlist
// drops them all without a denylist — and never calls the
// collaborators API that 404s on "Copilot is not a user".
// user.type === 'User' guards the rare bot-added-as-collaborator case.
if (!['OWNER', 'MEMBER', 'COLLABORATOR'].includes(assoc) || type !== 'User') {
core.info(`reviewer ${review.user && review.user.login} assoc=${assoc} type=${type}; skipping.`);
return;
}
const labels = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
owner, repo, issue_number: num, per_page: 100,
})).map(l => l.name);
if (labels.includes('reviewing')) {
core.info('Already labeled reviewing; skipping.');
return;
}
try {
await github.rest.issues.addLabels({
owner, repo, issue_number: num, labels: ['reviewing'],
});
core.info('Added "reviewing".');
} catch (e) {
if (e.status === 403) core.info('No permission to label (expected on some fork PRs).');
else throw e;
}
# ── Issue: needs-triage on every new issue ────────────────────────────────
issue-triage:
if: github.event_name == 'issues'
runs-on: ubuntu-latest
concurrency:
group: triage-issue-${{ github.event.issue.number }}
cancel-in-progress: false
steps:
- name: Add needs-triage label
uses: actions/github-script@v8
with:
script: |
const { owner, repo } = context.repo;
const issue_number = context.payload.issue.number;
// Read live labels (not the event payload) so labels added at creation
// time via the API or by another automation are seen — consistent with
// the live-state reads in the PR jobs above.
const current = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
owner, repo, issue_number, per_page: 100,
})).map(l => l.name);
if (current.includes('needs-triage')) {
core.info('Issue already has needs-triage; nothing to do.');
return;
}
// Self-heal: create the label if it does not exist yet.
try {
await github.rest.issues.createLabel({
owner, repo, name: 'needs-triage', color: 'fef2c0',
description: 'Awaiting maintainer triage',
});
} catch (e) {
if (e.status !== 422) throw e; // 422 = already exists
}
await github.rest.issues.addLabels({
owner, repo, issue_number, labels: ['needs-triage'],
});
core.info(`Added needs-triage to #${issue_number}.`);
-15
View File
@@ -287,21 +287,6 @@ Nginx (port 2026) ← Unified entry point
git push origin feature/your-feature-name
```
## AI assistance disclosure
DeerFlow is an AI project and we welcome AI-assisted contributions. To help
reviewers calibrate how closely to read a change, **every pull request must
complete the "AI assistance" section of the
[PR template](.github/pull_request_template.md)**:
- which tool(s) you used (or `none`),
- how you used them, and
- a confirmation that a human has read, understands, and takes responsibility
for the change.
Please don't delete the section. PRs that ignore it may be asked to fill it in
before review.
## Testing
```bash
+31 -1
View File
@@ -89,7 +89,36 @@ install:
# Pre-pull sandbox Docker image (optional but recommended)
setup-sandbox:
@$(RUN_WITH_GIT_BASH) ./scripts/setup-sandbox.sh
@echo "=========================================="
@echo " Pre-pulling Sandbox Container Image"
@echo "=========================================="
@echo ""
@IMAGE=$$(grep -A 20 "# sandbox:" config.yaml 2>/dev/null | grep "image:" | awk '{print $$2}' | head -1); \
if [ -z "$$IMAGE" ]; then \
IMAGE="enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"; \
echo "Using default image: $$IMAGE"; \
else \
echo "Using configured image: $$IMAGE"; \
fi; \
echo ""; \
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
echo "Detected Apple Container on macOS, pulling image..."; \
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
fi; \
if command -v docker >/dev/null 2>&1; then \
echo "Pulling image using Docker..."; \
if docker pull "$$IMAGE"; then \
echo ""; \
echo "✓ Sandbox image pulled successfully"; \
else \
echo ""; \
echo "⚠ Failed to pull sandbox image (this is OK for local sandbox mode)"; \
fi; \
else \
echo "✗ Neither Docker nor Apple Container is available"; \
echo " Please install Docker: https://docs.docker.com/get-docker/"; \
exit 1; \
fi
# Start all services in development mode (with hot-reloading)
dev:
@@ -119,6 +148,7 @@ stop:
clean: stop
@echo "Cleaning up..."
@-rm -rf backend/.deer-flow 2>/dev/null || true
@-rm -rf backend/.langgraph_api 2>/dev/null || true
@-rm -rf logs/*.log 2>/dev/null || true
@echo "✓ Cleanup complete"
-2
View File
@@ -585,8 +585,6 @@ A standard Agent Skill is a structured capability module — a Markdown file tha
Skills are loaded progressively — only when the task needs them, not all at once. This keeps the context window lean and makes DeerFlow work well even with token-sensitive models.
Users can explicitly activate an enabled skill for a single turn by starting the request with `/skill-name`, for example `/data-analysis analyze uploads/foo.csv`. DeerFlow loads that skill's `SKILL.md` as hidden current-turn context while leaving the base prompt limited to skill metadata. Slash activation respects disabled skills, custom-agent skill whitelists, and existing channel commands such as `/new` and `/help`.
When you install `.skill` archives through the Gateway, DeerFlow accepts standard optional frontmatter metadata such as `version`, `author`, and `compatibility` instead of rejecting otherwise valid external skills.
Tools follow the same philosophy. DeerFlow comes with a core toolset — web search, web fetch, file operations, bash execution — and supports custom tools via MCP servers and Python functions. Swap anything. Add anything.
-5
View File
@@ -24,10 +24,5 @@ config.yaml
# Langgraph
.langgraph_api
# Sandbox runtime working dir — pre-created and excluded from uvicorn reload
# (scripts/serve.sh, docker/dev-entrypoint.sh). Anchored so it does not match
# the source package backend/packages/harness/deerflow/sandbox/.
/sandbox/
# Claude Code settings
.claude/settings.local.json
+23 -18
View File
@@ -192,7 +192,7 @@ from deerflow.config import get_app_config
### Middleware Chain
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`build_middlewares`):
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
@@ -202,17 +202,16 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SkillActivationMiddleware** - Detects strict `/skill-name task` syntax on the latest real user message, resolves only enabled and runtime-allowed skills, reads `SKILL.md` from trusted skill storage, injects the skill body as hidden current-turn model context, and records a `middleware:skill_activation` audit event with skill name, category, path, and content hash
10. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
11. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
12. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
13. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
14. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
15. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
16. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`)
17. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
18. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
19. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
### Configuration System
@@ -224,9 +223,17 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`.
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
| Field | Why a restart is required |
|---|---|
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
Configuration priority:
1. Explicit `config_path` argument
@@ -264,7 +271,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
| **Uploads** (`/api/threads/{id}/uploads`) | `POST /` - upload files (auto-converts PDF/PPT/Excel/Word); `GET /list` - list; `DELETE /{filename}` - delete |
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized and inline reasoning (`<think>...</think>`, including unclosed/truncated blocks from reasoning models like MiniMax-M3) is stripped before JSON parsing |
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
@@ -306,7 +313,6 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` enforced by `SubagentLimitMiddleware` (truncates excess tool calls in `after_model`), 15-minute timeout
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
**Deferred MCP tools** (if `tool_search.enabled`): `SubagentExecutor._build_initial_state` assembles deferral after policy filtering via the shared `assemble_deferred_tools` (fail-closed), appends the `tool_search` tool, injects the `<available-deferred-tools>` section into the subagent's `SystemMessage`, and threads the setup to `_create_agent`, which attaches `DeferredToolFilterMiddleware` through `build_subagent_runtime_middlewares(deferred_setup=...)`. Subagents thus withhold full MCP schemas until promotion, same as the lead agent; each task run gets a fresh `ThreadState` so promotion is isolated per run
### Tool System (`packages/harness/deerflow/tools/`)
@@ -349,7 +355,6 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
- **Format**: Directory with `SKILL.md` (YAML frontmatter: name, description, license, allowed-tools)
- **Loading**: `load_skills()` recursively scans `skills/{public,custom}` for `SKILL.md`, parses metadata, and reads enabled state from extensions_config.json
- **Injection**: Enabled skills listed in agent system prompt with container paths
- **Slash activation**: `/skill-name task` loads that enabled skill's `SKILL.md` for the current model call only. The resolver rejects leading whitespace, missing separators, reserved channel commands (`/new`, `/help`, `/bootstrap`, `/status`, `/models`, `/memory`), disabled skills, and skills outside a custom agent's whitelist.
- **Installation**: `POST /api/skills/install` extracts .skill ZIP archive to custom/ directory
### Model Factory (`packages/harness/deerflow/models/factory.py`)
@@ -495,7 +500,7 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
- `"custom"` — forwarded from `StreamWriter`
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
- Agent created lazily via `create_agent()` + `build_middlewares()`, same as `make_lead_agent`
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
- Supports `checkpointer` parameter for state persistence across turns
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
- See [docs/STREAMING.md](docs/STREAMING.md) for the full design: why Gateway and DeerFlowClient are parallel paths, LangGraph's `stream_mode` semantics, the per-id dedup invariants, and regression testing strategy
+3 -3
View File
@@ -64,7 +64,7 @@ FROM builder AS dev
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
EXPOSE 8001
EXPOSE 8001 2024
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
@@ -94,8 +94,8 @@ WORKDIR /app
# Copy backend with pre-built virtualenv from builder
COPY --from=builder /app/backend ./backend
# Expose Gateway API port.
EXPOSE 8001
# Expose ports (gateway: 8001, langgraph: 2024)
EXPOSE 8001 2024
# Default command (can be overridden in docker-compose)
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run --no-sync uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
-7
View File
@@ -18,10 +18,3 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
"/help",
}
)
def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command."""
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
+4 -2
View File
@@ -14,7 +14,7 @@ from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -59,7 +59,9 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
def _is_dingtalk_command(text: str) -> bool:
return is_known_channel_command(text)
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
def _extract_text_from_rich_text(rich_text_list: list) -> str:
+2 -3
View File
@@ -10,7 +10,6 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -301,7 +300,7 @@ class DiscordChannel(Channel):
# If this is a known active thread, process normally
if thread_id in self._active_thread_ids:
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
@@ -408,7 +407,7 @@ class DiscordChannel(Channel):
chat_id = channel_id
typing_target = message.channel # Type into the channel
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
+13 -190
View File
@@ -7,30 +7,22 @@ import json
import logging
import re
import threading
import time
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
logger = logging.getLogger(__name__)
PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60
def _is_feishu_command(text: str) -> bool:
return is_known_channel_command(text)
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
class FeishuChannel(Channel):
@@ -64,7 +56,6 @@ class FeishuChannel(Channel):
self._background_tasks: set[asyncio.Task] = set()
self._running_card_ids: dict[str, str] = {}
self._running_card_tasks: dict[str, asyncio.Task] = {}
self._pending_clarifications: dict[tuple[str, str], list[dict[str, Any]]] = {}
self._CreateFileRequest = None
self._CreateFileRequestBody = None
self._CreateImageRequest = None
@@ -72,16 +63,6 @@ class FeishuChannel(Channel):
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
@staticmethod
def _non_empty_str(value: Any) -> str | None:
if isinstance(value, str) and value.strip():
return value.strip()
return None
@staticmethod
def _pending_key(chat_id: str, user_id: str) -> tuple[str, str]:
return (chat_id, user_id)
@property
def supports_streaming(self) -> bool:
return True
@@ -550,25 +531,18 @@ class FeishuChannel(Channel):
"[Feishu] failed to patch running card %s, falling back to final reply",
running_card_id,
)
fallback_card_id = await self._reply_card(source_message_id, msg.text)
self._remember_thread_mapping(msg, source_message_id, fallback_card_id)
self._remember_pending_clarification(msg, fallback_card_id)
await self._reply_card(source_message_id, msg.text)
else:
self._remember_thread_mapping(msg, source_message_id, running_card_id)
self._remember_pending_clarification(msg, running_card_id)
logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id)
elif msg.is_final:
final_card_id = await self._reply_card(source_message_id, msg.text)
self._remember_thread_mapping(msg, source_message_id, final_card_id)
self._remember_pending_clarification(msg, final_card_id)
await self._reply_card(source_message_id, msg.text)
elif awaited_running_card_task:
logger.warning(
"[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation",
source_message_id,
)
else:
created_card_id = await self._ensure_running_card(source_message_id, msg.text)
self._remember_thread_mapping(msg, source_message_id, created_card_id)
await self._ensure_running_card(source_message_id, msg.text)
if msg.is_final:
self._running_card_ids.pop(source_message_id, None)
@@ -579,129 +553,6 @@ class FeishuChannel(Channel):
# -- internal ----------------------------------------------------------
def _remember_thread_mapping(self, msg: OutboundMessage, *topic_ids: str | None) -> None:
store = self.config.get("channel_store")
if store is None or not msg.thread_id:
return
metadata_topic_ids = [
msg.metadata.get("message_id"),
msg.metadata.get("root_id"),
msg.metadata.get("parent_id"),
msg.metadata.get("thread_id"),
msg.metadata.get("topic_id"),
]
user_id = ""
raw_user_id = msg.metadata.get("user_id")
if isinstance(raw_user_id, str):
user_id = raw_user_id
seen: set[str] = set()
for topic_id in [*topic_ids, *metadata_topic_ids]:
topic_id = self._non_empty_str(topic_id)
if not topic_id or topic_id in seen:
continue
seen.add(topic_id)
try:
store.set_thread_id(
self.name,
msg.chat_id,
msg.thread_id,
topic_id=topic_id,
user_id=user_id,
)
except Exception:
logger.exception("[Feishu] failed to remember thread mapping for topic_id=%s", topic_id)
def _remember_pending_clarification(self, msg: OutboundMessage, card_message_id: str | None) -> None:
if not msg.is_final or msg.metadata.get(PENDING_CLARIFICATION_METADATA_KEY) is not True:
return
user_id = self._non_empty_str(msg.metadata.get("user_id"))
topic_id = self._non_empty_str(msg.metadata.get("topic_id"))
source_message_id = self._non_empty_str(msg.thread_ts) or self._non_empty_str(msg.metadata.get("message_id"))
if not (user_id and topic_id and msg.thread_id and source_message_id and card_message_id):
return
key = self._pending_key(msg.chat_id, user_id)
pending = {
"thread_id": msg.thread_id,
"topic_id": topic_id,
"source_message_id": source_message_id,
"card_message_id": card_message_id,
"created_at": time.time(),
}
with self._thread_lock:
# Plain-message clarification continuity is a short-lived in-memory
# hint; explicit Feishu replies are still covered by persisted
# message-id mappings.
self._pending_clarifications.setdefault(key, []).append(pending)
logger.info(
"[Feishu] pending clarification remembered: chat_id=%s user_id=%s topic_id=%s thread_id=%s",
msg.chat_id,
user_id,
topic_id,
msg.thread_id,
)
def _consume_pending_clarification(self, chat_id: str, user_id: str) -> dict[str, Any] | None:
key = self._pending_key(chat_id, user_id)
with self._thread_lock:
pending_items = self._pending_clarifications.get(key)
if not pending_items:
return None
now = time.time()
while pending_items:
pending = pending_items.pop(0)
created_at = pending.get("created_at")
if isinstance(created_at, (int, float)) and now - created_at <= PENDING_CLARIFICATION_TTL_SECONDS:
if pending_items:
self._pending_clarifications[key] = pending_items
else:
self._pending_clarifications.pop(key, None)
return pending
logger.info("[Feishu] pending clarification expired: chat_id=%s user_id=%s", chat_id, user_id)
self._pending_clarifications.pop(key, None)
return None
def _ensure_pending_thread_mapping(self, chat_id: str, user_id: str, pending: dict[str, Any]) -> None:
store = self.config.get("channel_store")
topic_id = self._non_empty_str(pending.get("topic_id"))
thread_id = self._non_empty_str(pending.get("thread_id"))
if store is None or not topic_id or not thread_id:
return
try:
store.set_thread_id(self.name, chat_id, thread_id, topic_id=topic_id, user_id=user_id)
except Exception:
logger.exception("[Feishu] failed to restore pending clarification mapping for topic_id=%s", topic_id)
def _resolve_topic_id(
self,
chat_id: str,
msg_id: str,
*,
root_id: str | None,
parent_id: str | None,
thread_id: str | None,
) -> tuple[str, bool]:
store = self.config.get("channel_store")
candidates = [root_id, parent_id, thread_id]
if store is not None:
for candidate in candidates:
candidate = self._non_empty_str(candidate)
if not candidate:
continue
try:
if store.get_thread_id(self.name, chat_id, topic_id=candidate):
return candidate, True
except Exception:
logger.exception("[Feishu] failed to resolve stored topic mapping for topic_id=%s", candidate)
return root_id or msg_id, False
@staticmethod
def _log_future_error(fut, name: str, msg_id: str) -> None:
"""Callback for run_coroutine_threadsafe futures to surface errors."""
@@ -742,9 +593,7 @@ class FeishuChannel(Channel):
# root_id is set when the message is a reply within a Feishu thread.
# Use it as topic_id so all replies share the same DeerFlow thread.
root_id = self._non_empty_str(getattr(message, "root_id", None))
parent_id = self._non_empty_str(getattr(message, "parent_id", None))
feishu_thread_id = self._non_empty_str(getattr(message, "thread_id", None))
root_id = getattr(message, "root_id", None) or None
# Parse message content
content = json.loads(message.content)
@@ -805,12 +654,10 @@ class FeishuChannel(Channel):
text = text.strip()
logger.info(
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r",
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, sender=%s, text=%r",
chat_id,
msg_id,
root_id,
parent_id,
feishu_thread_id,
sender_id,
text[:100] if text else "",
)
@@ -826,24 +673,8 @@ class FeishuChannel(Channel):
else:
msg_type = InboundMessageType.CHAT
# Prefer any platform message id that already maps to a DeerFlow
# thread. This keeps replies to bot clarification cards in the
# original conversation even when Feishu reports the card as root.
topic_id, resolved_from_stored_mapping = self._resolve_topic_id(
chat_id,
msg_id,
root_id=root_id,
parent_id=parent_id,
thread_id=feishu_thread_id,
)
resolved_from_pending = False
if msg_type == InboundMessageType.CHAT and not resolved_from_stored_mapping:
pending = self._consume_pending_clarification(chat_id, sender_id)
pending_topic_id = self._non_empty_str(pending.get("topic_id")) if pending else None
if pending_topic_id:
topic_id = pending_topic_id
self._ensure_pending_thread_mapping(chat_id, sender_id, pending)
resolved_from_pending = True
# topic_id: use root_id for replies (same topic), msg_id for new messages (new topic)
topic_id = root_id or msg_id
inbound = self._make_inbound(
chat_id=chat_id,
@@ -852,15 +683,7 @@ class FeishuChannel(Channel):
msg_type=msg_type,
thread_ts=msg_id,
files=files_list,
metadata={
"message_id": msg_id,
"root_id": root_id,
"parent_id": parent_id,
"thread_id": feishu_thread_id,
"topic_id": topic_id,
"user_id": sender_id,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY: resolved_from_pending,
},
metadata={"message_id": msg_id, "root_id": root_id},
)
inbound.topic_id = topic_id
+20 -200
View File
@@ -8,7 +8,6 @@ import mimetypes
import re
import time
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@@ -16,24 +15,11 @@ import httpx
from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.config.agents_config import load_agent_config
from deerflow.config.paths import make_safe_user_id
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.slash import parse_slash_skill_reference
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY
logger = logging.getLogger(__name__)
@@ -130,16 +116,6 @@ class InvalidChannelSessionConfigError(ValueError):
"""Raised when IM channel session overrides contain invalid agent config."""
class SlashSkillCommandResolutionError(RuntimeError):
"""Raised when IM slash-skill command resolution cannot complete safely."""
@dataclass(frozen=True, slots=True)
class _SlashSkillCommandResolution:
route_to_chat: bool = False
failure_message: str | None = None
def _is_thread_busy_error(exc: BaseException | None) -> bool:
if exc is None:
return False
@@ -226,54 +202,6 @@ def _extract_response_text(result: dict | list) -> str:
return ""
def _messages_from_result(result: dict | list) -> list[Any]:
if isinstance(result, list):
return result
if isinstance(result, dict):
messages = result.get("messages", [])
if isinstance(messages, list):
return messages
return []
def _current_turn_messages(result: dict | list) -> list[dict[str, Any]]:
messages = _messages_from_result(result)
current_turn: list[dict[str, Any]] = []
for msg in reversed(messages):
if not isinstance(msg, dict):
continue
if msg.get("type") == "human":
break
current_turn.append(msg)
current_turn.reverse()
return current_turn
def _has_current_turn_clarification(result: dict | list) -> bool:
"""Return True only when the current turn's final result is clarification."""
for msg in reversed(_current_turn_messages(result)):
msg_type = msg.get("type")
if msg_type == "tool":
return msg.get("name") == "ask_clarification"
if msg_type == "ai":
content = msg.get("content")
if isinstance(content, str):
if content:
return False
elif content:
return False
if msg.get("tool_calls"):
return False
return False
def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification: bool = False) -> dict[str, Any]:
metadata = _slim_metadata(base_metadata)
if pending_clarification:
metadata[PENDING_CLARIFICATION_METADATA_KEY] = True
return metadata
def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field."""
if isinstance(content, str):
@@ -426,46 +354,6 @@ def _format_artifact_text(artifacts: list[str]) -> str:
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
def _unknown_command_reply(command: str | None = None) -> str:
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
if command:
return f"Unknown command: /{command}. Available commands: {available}"
return f"Unknown command. Available commands: {available}"
def _human_input_message(content: str, *, original_content: str | None = None) -> dict[str, Any]:
message: dict[str, Any] = {"role": "human", "content": content}
if original_content is not None and original_content != content:
message["additional_kwargs"] = {ORIGINAL_USER_CONTENT_KEY: original_content}
return message
def _resolve_slash_skill_command(
text: str,
available_skills: set[str] | None = None,
storage: SkillStorage | Callable[[], SkillStorage] | None = None,
) -> _SlashSkillCommandResolution | None:
reference = parse_slash_skill_reference(text)
if reference is None:
return None
try:
resolved_storage = storage() if callable(storage) else storage or get_or_new_skill_storage()
skills = resolved_storage.load_skills(enabled_only=False)
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
if skill is None:
return None
if not skill.enabled:
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
if available_skills is not None and reference.name not in available_skills:
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
return _SlashSkillCommandResolution(route_to_chat=True)
except Exception as exc:
logger.exception("[Manager] failed to resolve slash skill command")
raise SlashSkillCommandResolutionError("Failed to resolve slash skill command. Please check the skill configuration.") from exc
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
"""Resolve virtual artifact paths to host filesystem paths with metadata.
@@ -680,7 +568,6 @@ class ChannelManager:
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._client = None # lazy init — langgraph_sdk async client
self._skill_storage: SkillStorage | None = None
self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None
self._running = False
@@ -728,20 +615,12 @@ class ChannelManager:
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
# ``user_id`` drives user-scoped filesystem buckets that only accept
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
# under ``channel_user_id`` for platform-facing lookups.
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id
run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
channel_layer.get("context"),
user_layer.get("context"),
run_context_identity,
{"thread_id": thread_id},
)
# Custom agents are implemented as lead_agent + agent_name context.
@@ -753,21 +632,6 @@ class ChannelManager:
return assistant_id, run_config, run_context
def _resolve_available_skill_names(self, msg: InboundMessage) -> set[str] | None:
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or ""
_, _, run_context = self._resolve_run_params(msg, thread_id)
if run_context.get("is_bootstrap"):
return {"bootstrap"}
agent_name = run_context.get("agent_name")
if not isinstance(agent_name, str) or not agent_name.strip():
return None
agent_config = load_agent_config(_normalize_custom_agent_name(agent_name))
if agent_config and agent_config.skills is not None:
return set(agent_config.skills)
return None
# -- LangGraph SDK client (lazy) ----------------------------------------
def _get_client(self):
@@ -785,11 +649,6 @@ class ChannelManager:
)
return self._client
def _get_skill_storage(self) -> SkillStorage:
if self._skill_storage is None:
self._skill_storage = get_or_new_skill_storage()
return self._skill_storage
# -- lifecycle ---------------------------------------------------------
async def start(self) -> None:
@@ -859,14 +718,6 @@ class ChannelManager:
exc,
)
await self._send_error(msg, str(exc))
except SlashSkillCommandResolutionError as exc:
logger.warning(
"Slash skill command resolution failed for %s (chat=%s): %s",
msg.channel_name,
msg.chat_id,
exc,
)
await self._send_error(msg, str(exc))
except Exception:
logger.exception(
"Error handling message from %s (chat=%s)",
@@ -921,11 +772,9 @@ class ChannelManager:
if extra_context:
run_context.update(extra_context)
original_text = msg.text
uploaded = await _ingest_inbound_files(thread_id, msg)
if uploaded:
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
human_message = _human_input_message(msg.text, original_content=original_text)
if self._channel_supports_streaming(msg.channel_name):
await self._handle_streaming_chat(
@@ -935,7 +784,6 @@ class ChannelManager:
assistant_id,
run_config,
run_context,
human_message,
)
return
@@ -944,7 +792,7 @@ class ChannelManager:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [human_message]},
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
multitask_strategy="reject",
@@ -958,7 +806,6 @@ class ChannelManager:
raise
response_text = _extract_response_text(result)
pending_clarification = _has_current_turn_clarification(result)
artifacts = _extract_artifacts(result)
logger.info(
@@ -984,7 +831,7 @@ class ChannelManager:
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
metadata=_slim_metadata(msg.metadata),
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
await self.bus.publish_outbound(outbound)
@@ -997,7 +844,6 @@ class ChannelManager:
assistant_id: str,
run_config: dict[str, Any],
run_context: dict[str, Any],
human_message: dict[str, Any],
) -> None:
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
@@ -1013,7 +859,7 @@ class ChannelManager:
async for chunk in client.runs.stream(
thread_id,
assistant_id,
input={"messages": [human_message]},
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
@@ -1047,7 +893,7 @@ class ChannelManager:
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
metadata=_response_metadata(msg.metadata),
metadata=_slim_metadata(msg.metadata),
)
)
last_published_text = latest_text
@@ -1061,7 +907,6 @@ class ChannelManager:
finally:
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
response_text = _extract_response_text(result)
pending_clarification = _has_current_turn_clarification(result)
artifacts = _extract_artifacts(result)
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
@@ -1093,27 +938,18 @@ class ChannelManager:
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
metadata=_slim_metadata(msg.metadata),
)
)
# -- command handling --------------------------------------------------
async def _handle_command(self, msg: InboundMessage) -> None:
raw_text = msg.text
text = raw_text.strip()
text = msg.text.strip()
parts = text.split(maxsplit=1)
reply: str | None = None
if not parts:
command = None
reply = _unknown_command_reply()
else:
command = parts[0].lower().removeprefix("/")
command = parts[0].lower().lstrip("/")
if reply is None and not raw_text.startswith("/"):
reply = _unknown_command_reply(command)
if reply is None and command == "bootstrap":
if command == "bootstrap":
from dataclasses import replace as _dc_replace
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
@@ -1121,7 +957,7 @@ class ChannelManager:
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
return
if reply is None and command == "new":
if command == "new":
# Create a new thread through Gateway
client = self._get_client()
thread = await client.threads.create()
@@ -1134,14 +970,14 @@ class ChannelManager:
user_id=msg.user_id,
)
reply = "New conversation started."
elif reply is None and command == "status":
elif command == "status":
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif reply is None and command == "models":
elif command == "models":
reply = await self._fetch_gateway("/api/models", "models")
elif reply is None and command == "memory":
elif command == "memory":
reply = await self._fetch_gateway("/api/memory", "memory")
elif reply is None and command == "help":
elif command == "help":
reply = (
"Available commands:\n"
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
@@ -1149,32 +985,16 @@ class ChannelManager:
"/status — Show current thread info\n"
"/models — List available models\n"
"/memory — Show memory status\n"
"/<skill-name> <task> — Activate an enabled skill for one turn\n"
"/help — Show this help"
)
elif reply is None:
slash_resolution = await asyncio.to_thread(
lambda: _resolve_slash_skill_command(
raw_text,
self._resolve_available_skill_names(msg),
self._get_skill_storage,
)
)
if slash_resolution and slash_resolution.failure_message:
reply = slash_resolution.failure_message
elif slash_resolution and slash_resolution.route_to_chat:
from dataclasses import replace as _dc_replace
chat_msg = _dc_replace(msg, msg_type=InboundMessageType.CHAT)
await self._handle_chat(chat_msg)
return
else:
reply = _unknown_command_reply(command)
else:
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
reply = f"Unknown command: /{command}. Available commands: {available}"
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=reply,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
@@ -1212,7 +1032,7 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=error_text,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
-3
View File
@@ -13,9 +13,6 @@ from typing import Any
logger = logging.getLogger(__name__)
PENDING_CLARIFICATION_METADATA_KEY = "pending_clarification"
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY = "resolved_from_pending_clarification"
# ---------------------------------------------------------------------------
# Message types
+1 -37
View File
@@ -9,7 +9,6 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -33,20 +32,6 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
return {str(user_id) for user_id in values if str(user_id)}
def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
if not bot_user_id:
return text
if not text.startswith("<@"):
return text
end = text.find(">")
if end <= 2:
return text
mentioned_user_id = text[2:end].split("|", 1)[0].lstrip("!")
if mentioned_user_id != bot_user_id:
return text
return text[end + 1 :].lstrip()
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
@@ -64,8 +49,6 @@ class SlackChannel(Channel):
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
configured_bot_user_id = config.get("bot_user_id")
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
async def start(self) -> None:
if self._running:
@@ -89,17 +72,6 @@ class SlackChannel(Channel):
return
self._web_client = WebClient(token=bot_token)
if self._bot_user_id is None:
try:
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
self._socket_client = SocketModeClient(
app_token=app_token,
web_client=self._web_client,
@@ -238,12 +210,6 @@ class SlackChannel(Channel):
if event_type != "events_api":
return
if self._bot_user_id is None:
authorization = next((item for item in req.payload.get("authorizations", []) if isinstance(item, dict)), None)
user_id = authorization.get("user_id") if authorization else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
event = req.payload.get("event", {})
etype = event.get("type", "")
@@ -267,15 +233,13 @@ class SlackChannel(Channel):
return
text = event.get("text", "").strip()
if event.get("type") == "app_mention":
text = _strip_leading_slack_bot_mention(text, self._bot_user_id)
if not text:
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
if is_known_channel_command(text):
if text.startswith("/"):
msg_type = InboundMessageType.COMMAND
else:
msg_type = InboundMessageType.CHAT
+2 -34
View File
@@ -60,17 +60,12 @@ class TelegramChannel(Channel):
# Command handlers
app.add_handler(CommandHandler("start", self._cmd_start))
app.add_handler(CommandHandler("bootstrap", self._cmd_generic))
app.add_handler(CommandHandler("new", self._cmd_generic))
app.add_handler(CommandHandler("status", self._cmd_generic))
app.add_handler(CommandHandler("models", self._cmd_generic))
app.add_handler(CommandHandler("memory", self._cmd_generic))
app.add_handler(CommandHandler("help", self._cmd_generic))
# Slash skill commands are dynamic and cannot all be pre-registered
# with Telegram, so route unknown slash commands through chat handling.
app.add_handler(MessageHandler(filters.TEXT & filters.COMMAND, self._on_text))
# General message handler
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
@@ -233,33 +228,6 @@ class TelegramChannel(Channel):
return True
return user_id in self._allowed_users
def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None)
username = getattr(bot, "username", None)
if not username and self._application is not None:
username = getattr(getattr(self._application, "bot", None), "username", None)
return str(username) if username else None
@staticmethod
def _strip_bot_username_from_leading_command(text: str, bot_username: str | None) -> str:
username = (bot_username or "").lstrip("@").lower()
if not username or not text.startswith("/"):
return text
parts = text.split(maxsplit=1)
command_token = parts[0]
if "@" not in command_token:
return text
command_name, addressed_username = command_token[1:].rsplit("@", 1)
if not command_name or addressed_username.lower() != username:
return text
normalized = f"/{command_name}"
if len(parts) > 1:
normalized = f"{normalized} {parts[1]}"
return normalized
async def _cmd_start(self, update, context) -> None:
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
@@ -275,7 +243,7 @@ class TelegramChannel(Channel):
if not self._check_user(update.effective_user.id):
return
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
text = update.message.text
chat_id = str(update.effective_chat.id)
user_id = str(update.effective_user.id)
msg_id = str(update.message.message_id)
@@ -311,7 +279,7 @@ class TelegramChannel(Channel):
if not self._check_user(update.effective_user.id):
return
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
text = update.message.text.strip()
if not text:
return
+1 -2
View File
@@ -22,7 +22,6 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -621,7 +620,7 @@ class WechatChannel(Channel):
chat_id=chat_id,
user_id=chat_id,
text=text,
msg_type=InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT,
msg_type=InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT,
thread_ts=thread_ts,
files=files,
metadata={
+1 -2
View File
@@ -8,7 +8,6 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import (
InboundMessageType,
MessageBus,
@@ -271,7 +270,7 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid")
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
user_id=user_id,
-19
View File
@@ -179,25 +179,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
config = get_gateway_config()
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
# Pre-warm tiktoken encoding cache so the first memory-injection request
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
# that may be unreachable in restricted networks — see issue #3402).
try:
from deerflow.agents.memory.prompt import warm_tiktoken_cache
warmed = await asyncio.wait_for(
asyncio.to_thread(warm_tiktoken_cache),
timeout=5,
)
if warmed:
logger.info("tiktoken encoding cache warmed successfully")
else:
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
except TimeoutError:
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
except Exception:
logger.warning("tiktoken warm-up skipped", exc_info=True)
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app, startup_config):
logger.info("LangGraph runtime initialised")
-56
View File
@@ -17,7 +17,6 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
@@ -34,43 +33,6 @@ from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
# Upper bound (seconds) for draining in-flight runs during shutdown, before the
# AsyncExitStack tears down the checkpointer (and its connection pool). Kept
# local to avoid an app -> deps -> app import cycle. This is a *separate* budget
# from ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s,
# which bounds channel-service stop): the two govern independent teardown steps
# and may diverge, but both count toward the lifespan shutdown window — revisit
# them together if their sum must stay within the server's graceful-shutdown
# timeout.
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
async def _drain_inflight_runs(run_manager: RunManager) -> None:
"""Drain in-flight runs before the checkpointer is torn down (issue #3373).
Shields the (internally-bounded) drain so that even if the lifespan
coroutine is itself cancelled mid-shutdown — a second SIGINT or the server's
graceful-shutdown timeout, i.e. the same signal storm behind #3373 — the
checkpointer pool is not closed while run tasks are still writing
checkpoints. On such a cancellation we let the already-running drain finish
(it is bounded by ``RunManager.shutdown``'s own timeout) and then propagate
the cancellation.
"""
drain = asyncio.create_task(run_manager.shutdown(timeout=_RUN_DRAIN_TIMEOUT_SECONDS))
try:
await asyncio.shield(drain)
except asyncio.CancelledError:
# Re-shield so this second wait does not abandon the in-flight drain;
# it is bounded, so this cannot hang. Then re-raise to honour shutdown.
try:
await asyncio.shield(drain)
except Exception:
logger.exception("In-flight run drain failed after shutdown cancellation")
raise
except Exception:
logger.exception("Failed to drain in-flight runs during shutdown")
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
@@ -119,16 +81,6 @@ def get_config() -> AppConfig:
split-brain where the worker / lead-agent thread saw a stale startup
snapshot.
Hot-reload boundary: fields backed by startup-time singletons
(engines, sandbox provider, IM channels, logging handler) require a
process restart to change at runtime. The authoritative list lives in
:mod:`deerflow.config.reload_boundary` and is mirrored by the
standardised ``"startup-only:"`` prefix on the matching
``Field(description=...)`` in :class:`AppConfig` — IDE hover on those
fields will surface the boundary inline. See
``backend/CLAUDE.md`` "Config Hot-Reload Boundary" for the operator
summary.
Any failure to materialise the config (missing file, permission denied,
YAML parse error, validation error) is reported as 503 — semantically
"the gateway cannot serve requests without a usable configuration" — and
@@ -225,14 +177,6 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
try:
yield
finally:
# Drain in-flight run tasks BEFORE the AsyncExitStack tears down the
# checkpointer (and its connection pool). A run still mid-graph would
# otherwise leak into asyncio.run() shutdown, where langgraph's
# _checkpointer_put_after_previous aput races the closed pool and
# raises PoolClosed (issue #3373).
run_manager = getattr(app.state, "run_manager", None)
if run_manager is not None:
await _drain_inflight_runs(run_manager)
await close_engine()
+1 -2
View File
@@ -10,7 +10,6 @@ from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
INTERNAL_SYSTEM_ROLE = "internal"
def _load_internal_auth_token() -> str:
@@ -35,4 +34,4 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
-15
View File
@@ -1,15 +0,0 @@
"""Shared pagination helpers for gateway routers."""
from __future__ import annotations
def trim_run_message_page(rows: list[dict], *, limit: int, after_seq: int | None) -> tuple[list[dict], bool]:
"""Trim a ``limit + 1`` run-message page while preserving page boundaries."""
has_more = len(rows) > limit
if not has_more:
return rows, False
if after_seq is not None:
return rows[:limit], True
return rows[-limit:], True
+48 -73
View File
@@ -1,6 +1,5 @@
"""CRUD API for custom agents."""
import asyncio
import logging
import re
import shutil
@@ -214,60 +213,47 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
user_id = get_effective_user_id()
paths = get_paths()
def _create_agent() -> AgentResponse | None:
# Worker thread: base-dir resolution, existence checks, directory/file
# creation, read-back, and failure cleanup are all blocking filesystem
# IO that must stay off the event loop.
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
if legacy_dir.exists():
return None # signals 409 to the caller
try:
try:
agent_dir.mkdir(parents=True, exist_ok=False)
except FileExistsError:
return None # signals 409 to the caller
# Write config.yaml
config_data: dict = {"name": normalized_name}
if request.description:
config_data["description"] = request.description
if request.model is not None:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
# Write SOUL.md
soul_file = agent_dir / "SOUL.md"
soul_file.write_text(request.soul, encoding="utf-8")
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except Exception:
# Clean up partial state on failure before surfacing the error.
if agent_dir.exists():
shutil.rmtree(agent_dir)
raise
try:
response = await asyncio.to_thread(_create_agent)
except Exception as e:
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
if response is None:
if agent_dir.exists() or legacy_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
return response
try:
agent_dir.mkdir(parents=True, exist_ok=True)
# Write config.yaml
config_data: dict = {"name": normalized_name}
if request.description:
config_data["description"] = request.description
if request.model is not None:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
# Write SOUL.md
soul_file = agent_dir / "SOUL.md"
soul_file.write_text(request.soul, encoding="utf-8")
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except HTTPException:
raise
except Exception as e:
# Clean up on failure
if agent_dir.exists():
shutil.rmtree(agent_dir)
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
@router.put(
@@ -442,30 +428,19 @@ async def delete_agent(name: str) -> None:
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
def _remove_agent_dir() -> tuple[str, str]:
# Runs in a worker thread: resolving the base dir, probing the directory
# (`exists`), and removing it (`rmtree`) are all blocking filesystem IO
# that must stay off the event loop.
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists():
outcome = "legacy" if paths.agent_dir(name).exists() else "missing"
return outcome, str(agent_dir)
shutil.rmtree(agent_dir)
return "deleted", str(agent_dir)
if not agent_dir.exists():
if paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
try:
outcome, agent_dir = await asyncio.to_thread(_remove_agent_dir)
shutil.rmtree(agent_dir)
logger.info(f"Deleted agent '{name}' from {agent_dir}")
except Exception as e:
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
if outcome == "legacy":
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
if outcome == "missing":
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
logger.info(f"Deleted agent '{name}' from {agent_dir}")
+4 -89
View File
@@ -1,10 +1,9 @@
import json
import logging
import os
from pathlib import Path
from typing import Literal
from fastapi import APIRouter, HTTPException, Request, status
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
@@ -13,11 +12,6 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
_MCP_STDIO_COMMAND_ALLOWLIST_ENV = "DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST"
_DEFAULT_MCP_STDIO_COMMAND_ALLOWLIST = frozenset({"npx", "uvx"})
_SHELL_METACHARS = frozenset(";|&`$<>\n\r")
class McpOAuthConfigResponse(BaseModel):
"""OAuth configuration for an MCP server."""
@@ -72,78 +66,6 @@ class McpConfigUpdateRequest(BaseModel):
_MASKED_VALUE = "***"
async def _require_admin_user(request: Request) -> None:
"""Require the authenticated caller to be an admin user.
``AuthMiddleware`` normally stamps ``request.state.user`` before the
request reaches this router. Falling back to the strict dependency keeps
this route safe even in tests or alternative ASGI compositions that mount
the router without the global middleware.
"""
user = getattr(request.state, "user", None)
if user is None:
from app.gateway.deps import get_current_user_from_request
user = await get_current_user_from_request(request)
if getattr(user, "system_role", None) != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required to manage MCP configuration.",
)
def _allowed_stdio_commands() -> set[str]:
"""Return executable names allowed for API-managed stdio MCP servers."""
raw = os.environ.get(_MCP_STDIO_COMMAND_ALLOWLIST_ENV)
base = set(_DEFAULT_MCP_STDIO_COMMAND_ALLOWLIST)
if raw is None:
return base
extra = {item.strip() for item in raw.split(",") if item.strip()}
return base | extra
def _stdio_command_name(command: str | None, *, server_name: str) -> str:
"""Normalize and validate a stdio command field from the API boundary."""
if command is None or not command.strip():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"MCP server '{server_name}' with stdio transport requires a command.",
)
stripped = command.strip()
has_path_separator = "/" in stripped or "\\" in stripped
if stripped != command or has_path_separator or any(ch.isspace() for ch in stripped) or any(ch in stripped for ch in _SHELL_METACHARS):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(f"MCP server '{server_name}' command must be a single executable name; put parameters in args instead."),
)
return stripped
def _validate_mcp_update_request(request: McpConfigUpdateRequest) -> None:
"""Validate API-submitted MCP config before it is persisted.
Local config files can still express arbitrary advanced setups, but the
HTTP API is an untrusted boundary. Restricting stdio commands here reduces
the blast radius of a compromised authenticated browser session.
"""
allowed_commands = _allowed_stdio_commands()
for name, server in request.mcp_servers.items():
transport_type = (server.type or "stdio").lower()
if transport_type != "stdio":
continue
command_name = _stdio_command_name(server.command, server_name=name)
if command_name not in allowed_commands:
allowed = ", ".join(sorted(allowed_commands)) or "<none>"
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(f"MCP server '{name}' uses disallowed stdio command '{command_name}'. Allowed commands: {allowed}. Configure {_MCP_STDIO_COMMAND_ALLOWLIST_ENV} to extend this list."),
)
def _mask_server_config(server: McpServerConfigResponse) -> McpServerConfigResponse:
"""Return a copy of server config with sensitive fields masked.
@@ -240,7 +162,7 @@ def _merge_preserving_secrets(
summary="Get MCP Configuration",
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
)
async def get_mcp_configuration(request: Request) -> McpConfigResponse:
async def get_mcp_configuration() -> McpConfigResponse:
"""Get the current MCP configuration.
Returns:
@@ -261,8 +183,6 @@ async def get_mcp_configuration(request: Request) -> McpConfigResponse:
}
```
"""
await _require_admin_user(request)
config = get_extensions_config()
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()}
@@ -275,7 +195,7 @@ async def get_mcp_configuration(request: Request) -> McpConfigResponse:
summary="Update MCP Configuration",
description="Update Model Context Protocol (MCP) server configurations and save to file.",
)
async def update_mcp_configuration(request: Request, body: McpConfigUpdateRequest) -> McpConfigResponse:
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
"""Update the MCP configuration.
This will:
@@ -308,9 +228,6 @@ async def update_mcp_configuration(request: Request, body: McpConfigUpdateReques
```
"""
try:
await _require_admin_user(request)
_validate_mcp_update_request(body)
# Get the current config path (or determine where to save it)
config_path = ExtensionsConfig.resolve_config_path()
@@ -338,7 +255,7 @@ async def update_mcp_configuration(request: Request, body: McpConfigUpdateReques
# Merge incoming server configs with raw on-disk secrets
merged_servers: dict[str, McpServerConfigResponse] = {}
for name, incoming in body.mcp_servers.items():
for name, incoming in request.mcp_servers.items():
raw_server = raw_servers.get(name)
if raw_server is not None:
merged_servers[name] = _merge_preserving_secrets(
@@ -366,8 +283,6 @@ async def update_mcp_configuration(request: Request, body: McpConfigUpdateReques
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
return McpConfigResponse(mcp_servers=servers)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")
+2 -2
View File
@@ -15,7 +15,6 @@ from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.pagination import trim_run_message_page
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import serialize_channel_values
@@ -130,7 +129,8 @@ async def run_messages(
before_seq=before_seq,
after_seq=after_seq,
)
data, has_more = trim_run_message_page(rows, limit=limit, after_seq=after_seq)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
+1 -28
View File
@@ -1,6 +1,5 @@
import json
import logging
import re
from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage
@@ -31,31 +30,6 @@ class SuggestionsResponse(BaseModel):
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
# Matches a complete <think>...</think> block (case-insensitive, spans newlines).
_THINK_BLOCK_RE = re.compile(r"<think\b[^>]*>.*?</think\s*>", re.IGNORECASE | re.DOTALL)
# Matches a dangling, unclosed <think> (model truncated at max_tokens mid-thought).
_OPEN_THINK_RE = re.compile(r"<think\b[^>]*>", re.IGNORECASE)
def _strip_think_blocks(text: str) -> str:
"""Remove reasoning-model ``<think>...</think>`` blocks from the response.
Reasoning models such as MiniMax-M3 inline their chain-of-thought into the
message ``content`` wrapped in ``<think>...</think>`` (``reasoning_split``
defaults to false), rather than exposing a separate ``reasoning_content``
field. The thinking text frequently contains ``[`` / ``]`` characters, which
corrupted the downstream ``find('[')`` / ``rfind(']')`` JSON extraction and
produced empty suggestions. We strip the reasoning before parsing so only
the actual answer remains.
"""
text = _THINK_BLOCK_RE.sub("", text)
# Drop any unclosed <think> (and everything after it) left by truncation.
open_match = _OPEN_THINK_RE.search(text)
if open_match:
text = text[: open_match.start()]
return text.strip()
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
@@ -67,8 +41,7 @@ def _strip_markdown_code_fence(text: str) -> str:
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_think_blocks(text)
candidate = _strip_markdown_code_fence(candidate)
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
+2 -2
View File
@@ -21,7 +21,6 @@ from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.pagination import trim_run_message_page
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
@@ -403,7 +402,8 @@ async def list_run_messages(
before_seq=before_seq,
after_seq=after_seq,
)
data, has_more = trim_run_message_page(rows, limit=limit, after_seq=after_seq)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
+4 -16
View File
@@ -17,7 +17,7 @@ import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from langgraph.checkpoint.base import empty_checkpoint, uuid6
from langgraph.checkpoint.base import empty_checkpoint
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
@@ -536,21 +536,9 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
# Assign a new checkpoint ID so aput performs an INSERT rather than an
# in-place REPLACE of the existing row. Use uuid6 (time-ordered) rather
# than uuid4 (random) so the new ID is always lexicographically greater
# than the previous one — LangGraph's checkpointers determine the "latest"
# checkpoint by max(checkpoint_ids) string order, matching the uuid6 epoch.
checkpoint["id"] = str(uuid6())
# aput requires checkpoint_ns in the config — use the same config used for the
# read (which always includes checkpoint_ns=""). The fresh checkpoint ID is
# assigned above via checkpoint["id"]; keep checkpoint_id out of the config so
# the write is keyed by the new checkpoint payload rather than the prior read.
# All supported savers (InMemorySaver, AsyncSqliteSaver, AsyncPostgresSaver)
# persist and echo back checkpoint["id"] verbatim — none mint their own — so
# the new_config below carries the uuid6 we assigned here. (Regression-locked
# by test_update_thread_state_inserts_new_checkpoint_each_call.)
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
# so that aput generates a fresh checkpoint ID for the new snapshot.
write_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
@@ -569,7 +557,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
# reflects them immediately in both sqlite and memory backends.
if thread_store and body.values and "title" in body.values:
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title: # Skip empty strings and None
try:
+5 -29
View File
@@ -39,39 +39,15 @@ DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
class UploadedFileInfo(BaseModel):
"""Uploaded file metadata exposed by upload and list APIs."""
filename: str
size: int
path: str
virtual_path: str
artifact_url: str
extension: str | None = None
modified: float | None = None
original_filename: str | None = None
markdown_file: str | None = None
markdown_path: str | None = None
markdown_virtual_path: str | None = None
markdown_artifact_url: str | None = None
class UploadResponse(BaseModel):
"""Response model for file upload."""
success: bool
files: list[UploadedFileInfo]
files: list[dict[str, str]]
message: str
skipped_files: list[str] = Field(default_factory=list)
class UploadListResponse(BaseModel):
"""Response model for uploaded file listing."""
files: list[UploadedFileInfo]
count: int
class UploadLimits(BaseModel):
"""Application-level upload limits exposed to clients."""
@@ -280,7 +256,7 @@ async def upload_files(
file_info = {
"filename": safe_filename,
"size": file_size,
"size": str(file_size),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
@@ -357,9 +333,9 @@ async def get_upload_limits(
return _get_upload_limits(config)
@router.get("/list", response_model=UploadListResponse)
@router.get("/list", response_model=dict)
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> UploadListResponse:
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
@@ -373,7 +349,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> UploadListRes
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
return UploadListResponse(**result)
return result
@router.delete("/{filename}")
+1 -14
View File
@@ -19,7 +19,6 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config
from deerflow.runtime import (
@@ -141,14 +140,7 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
and ``config['context']`` so they are visible to legacy configurable readers and
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool
see issue #2677).
``user_id`` is intentionally propagated into ``config['context']`` in addition to
the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in
``body.context`` keep it on ``ToolRuntime.context``. It is merged with
``setdefault`` so a server-authenticated id stamped by
:func:`inject_authenticated_user_context` always wins over the client-supplied one.
"""
see issue #2677)."""
if not context:
return
configurable = config.setdefault("configurable", {})
@@ -159,8 +151,6 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
configurable.setdefault(key, context[key])
if isinstance(runtime_context, dict):
runtime_context.setdefault(key, context[key])
if "user_id" in context and isinstance(runtime_context, dict):
runtime_context.setdefault("user_id", context["user_id"])
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
@@ -176,9 +166,6 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request)
if user_id is None:
return
if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
return
runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id)
+4 -22
View File
@@ -228,13 +228,10 @@ Get current MCP server configurations.
GET /api/mcp/config
```
Requires an authenticated admin session. Sensitive env/header/OAuth secret
values are masked in the response.
**Response:**
```json
{
"mcp_servers": {
"mcpServers": {
"github": {
"enabled": true,
"type": "stdio",
@@ -258,15 +255,10 @@ PUT /api/mcp/config
Content-Type: application/json
```
Requires an authenticated admin session. API-managed `stdio` MCP servers may
only use allowed executable names for `command` (default: `npx`, `uvx`). Set
`DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST` to a comma-separated list when a
deployment needs additional trusted launchers.
**Request Body:**
```json
{
"mcp_servers": {
"mcpServers": {
"github": {
"enabled": true,
"type": "stdio",
@@ -284,18 +276,8 @@ deployment needs additional trusted launchers.
**Response:**
```json
{
"mcp_servers": {
"github": {
"enabled": true,
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_TOKEN": "***"
},
"description": "GitHub operations"
}
}
"success": true,
"message": "MCP configuration updated"
}
```
+2 -2
View File
@@ -29,7 +29,7 @@ All other test plan sections were executed against either:
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` |
| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Docker deploy uses Gateway embedded runtime | `./scripts/deploy.sh` produces a Gateway + frontend + nginx topology (no `langgraph` container); same auth flow as local `make dev` | needs `docker compose up` |
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
## Coverage already provided by non-Docker tests
@@ -43,7 +43,7 @@ the test cases that ran on sg_dev or local:
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies |
| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (Gateway embedded runtime container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (Gateway auth flow on sg_dev) — same Gateway code, container is just a packaging change |
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
## Reproduction steps when Docker becomes available
+2 -2
View File
@@ -124,8 +124,8 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
## 兼容性
- **本地开发**`make dev`):Gateway embedded runtime 完全兼容;无 admin 时访问 `/setup` 初始化
- **Gateway embedded runtime**:标准脚本、Docker dev 和生产部署均通过 Gateway 提供认证与 LangGraph-compatible API
- **标准模式**`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化
- **Gateway 模式**`make dev-pro`):完全兼容
- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
+7 -17
View File
@@ -95,35 +95,25 @@ models:
thinking:
type: enabled
- name: minimax-m3
display_name: MiniMax M3
- name: minimax-m2.5
display_name: MiniMax M2.5
use: langchain_openai:ChatOpenAI
model: MiniMax-M3
model: MiniMax-M2.5
api_key: $MINIMAX_API_KEY
base_url: https://api.minimax.io/v1
max_tokens: 4096
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
supports_vision: true
- name: minimax-m2.7
display_name: MiniMax M2.7
- name: minimax-m2.5-highspeed
display_name: MiniMax M2.5 Highspeed
use: langchain_openai:ChatOpenAI
model: MiniMax-M2.7
model: MiniMax-M2.5-highspeed
api_key: $MINIMAX_API_KEY
base_url: https://api.minimax.io/v1
max_tokens: 4096
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
supports_vision: false # M2.7 is text-only; M3 supports vision
- name: minimax-m2.7-highspeed
display_name: MiniMax M2.7 Highspeed
use: langchain_openai:ChatOpenAI
model: MiniMax-M2.7-highspeed
api_key: $MINIMAX_API_KEY
base_url: https://api.minimax.io/v1
max_tokens: 4096
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
supports_vision: false # M2.7 is text-only; M3 supports vision
supports_vision: true
- name: openrouter-gemini-2.5-flash
display_name: Gemini 2.5 Flash (OpenRouter)
use: langchain_openai:ChatOpenAI
-1
View File
@@ -19,7 +19,6 @@ This directory contains detailed documentation for the DeerFlow backend.
| [STREAMING.md](STREAMING.md) | Token-level streaming design: Gateway vs DeerFlowClient paths, `stream_mode` semantics, per-id dedup |
| [FILE_UPLOAD.md](FILE_UPLOAD.md) | File upload functionality |
| [PATH_EXAMPLES.md](PATH_EXAMPLES.md) | Path types and usage examples |
| [SANDBOX_MEMORY_PROFILING.md](SANDBOX_MEMORY_PROFILING.md) | Sandbox memory baseline and runtime comparison guide |
| [summarization.md](summarization.md) | Context summarization feature |
| [plan_mode_usage.md](plan_mode_usage.md) | Plan mode with TodoList |
| [AUTO_TITLE_GENERATION.md](AUTO_TITLE_GENERATION.md) | Automatic title generation |
-120
View File
@@ -1,120 +0,0 @@
# Record/Replay E2E — front-back contract verification
Deterministic, **key-free** end-to-end checks that a backend change can't
silently break the frontend (and vice-versa). Two complementary layers, fed by a
single recording.
## Why
The mock-based frontend e2e hand-writes the backend's JSON/SSE, so a backend
schema or SSE change passes green ("fake green"). These layers replay a recorded
**real** run against the **real** backend (and, for Layer 2, the real frontend),
so contract drift turns the build red instead.
## The two layers
- **Layer 1 — backend golden** (`tests/test_replay_golden.py`): replays a fixture
through the real FastAPI gateway with `ReplayChatModel` and asserts the streamed
SSE event sequence equals a committed golden. Fast, no browser. Guards protocol
*shape*.
- **Layer 2 — full-stack render** (`frontend/tests/e2e-real-backend/`): real
Next.js + real gateway (replay model) + Chromium; asserts the replayed
auto-title and a follow-up suggestion render in the browser. Guards semantic
*render*. (Complementary to Layer 1 — neither subsumes the other.)
Layer 2 also hosts **cross-stack contract scenarios** — the dangerous class
where a backend change silently breaks a frontend assumption and *both sides'
unit tests stay green*. See below.
## Cross-stack scenario: multi-run render order (`multi-run-order.spec.ts`)
Regression guard for issue **#3352** (after context compression, refreshing a
thread rendered history out of order). Root cause was a front-back desync:
backend `RunManager.list_by_thread` returns runs **newest-first** (PR #2932),
while the frontend (`core/threads/hooks.ts`) iterated runs and **prepended** each
loaded page — inverting chronological order once the checkpoint no longer held
the older messages. The backend ordering test was green throughout, and the
frontend regression unit test hardcodes "backend returns newest-first" in a mock,
so only a *real frontend against a real backend* catches the desync.
This scenario does **not** record a conversation. It uses a **test-only seeder**
(`tests/seed_runs_router.py`, mounted on the replay gateway only when
`DEERFLOW_ENABLE_TEST_SEED=1`) to stand up a thread with ≥2 runs and per-run
message events — and deliberately **no checkpoint**, which is the #3352
precondition: it forces the frontend's per-run reload path to be the sole source
of truth so the ordering bug becomes observable. The seeder writes through the
gateway's own run/event stores using the request's auth context, so the real
`list_by_thread``/runs/{id}/messages` → prepend path runs live. Reverting the
#3354 frontend fix turns this spec red.
## How replay works
`tests/replay_provider.py::ReplayChatModel` returns recorded assistant turns keyed
by a **normalized hash of the model caller + conversation**. The conversation is
human / ai / tool messages — role, text, tool-call name+args; with
`<system-reminder>`, dates, UUIDs, tmp paths stripped. The caller is the stable
source of the model call (`lead_agent`, `middleware:title`, `suggest_agent`,
`subagent:*`, etc.). A miss raises loudly rather than passing silently.
**The system prompt is excluded from the match key.** The lead-agent system
prompt is a living, frequently-edited implementation detail — its wording changes
across PRs (e.g. #3195 added a "File Editing Workflow" section). Hashing it would
make every fixture go stale and red-fail unrelated PRs the moment anyone edits the
prompt. The conversation flow (user input → tool calls → results → answer) is the
stable contract that identifies a recorded turn. The caller still stays in the
key so two different model users with identical conversation text do not compete
for the same replay bucket. (This mirrors how open-design's mock picker keys on
the user prompt, not the system internals.) Combined with pinning skills +
extensions empty and disabling memory/summarization
(`tests/_replay_fixture.py::build_config_yaml`), a fixture replays the same across
machines, days, prompt edits, and CI. Replaying needs **no API key**.
A swallowed hash-miss keeps the SSE *event shapes* identical (the gateway wraps it
into a normal assistant error message), so the Layer-1 golden can't catch a miss
by shape alone — it inspects `replay_provider.replay_misses()` and fails loud
instead. Layer-2 already fails on a miss (the recorded turns never render).
## Record a new scenario (needs a real key — dev machine only)
Recording drives the **real frontend** so captured inputs match exactly what the
browser sends; fixtures contain no API key.
```bash
# 1. drive the real frontend against a real-model gateway, capturing model calls
OPENAI_API_KEY=... OPENAI_API_BASE=<openai-compatible-endpoint>/v1 \
DEERFLOW_RECORD_OUT=/tmp/rec/turns.jsonl RECORD_MODEL=<model> \
bash -c 'cd frontend && pnpm exec playwright test -c playwright.record.config.ts'
# 2. stitch the capture into a fixture
cd backend && uv run python scripts/build_fixture_from_jsonl.py \
--jsonl /tmp/rec/turns.jsonl --meta /tmp/rec/turns.jsonl.meta.json \
--out tests/fixtures/replay/<scenario>.<mode>.json --model <model>
# 3. regenerate the committed golden
DEERFLOW_WRITE_GOLDEN=1 PYTHONPATH=. uv run pytest tests/test_replay_golden.py
```
## Run (no key)
```bash
cd backend && PYTHONPATH=. uv run pytest tests/test_replay_golden.py # Layer 1
cd frontend && pnpm exec playwright test -c playwright.real-backend.config.ts # Layer 2
```
## CI
`.github/workflows/replay-e2e.yml` runs both layers on changes to **either** side
of the contract (`frontend/**`, `backend/app/gateway/**`,
`backend/packages/harness/**`, fixtures). DOM assertions are the gate; the rendered
screenshot + Playwright HTML report are uploaded as a CI artifact.
## Known limitations
- Visual regression baselines are OS-specific, so they are a **local dev gate
only** (gitignored); CI uploads the render as an artifact for human review
instead of hard-asserting a cross-OS baseline.
- Fixtures are coupled to the recording-time prompt; if new
environment-dependent content enters the system prompt, extend the
normalization in `replay_provider.py` (or pin it in `build_config_yaml`).
- Re-record a scenario if the agent graph changes how many model calls it makes
— the replay raises loudly on a hash miss pointing at the divergence.
-81
View File
@@ -1,81 +0,0 @@
# Sandbox Memory Profiling
This guide records a repeatable baseline before changing the sandbox runtime.
Issue #3213 reports per-sandbox memory near 1 GiB in Kubernetes. Before adding
or recommending a new provider, capture the current AIO sandbox baseline and
compare candidates with the same DeerFlow workload.
## What to Measure
Measure at least these samples:
1. Empty sandbox after it becomes ready.
2. After a simple bash command.
3. After a Python task that imports common packages.
4. After a Node task when Node-based workloads are expected.
5. After generating files under `/mnt/user-data/outputs`.
6. After release and warm reuse.
7. At the target concurrency level, for example 10, 50, or 100 sandboxes.
`kubectl top` reports Kubernetes/container working set memory. Treat it as a
capacity signal, not exclusive RSS/PSS. Pod-level memory includes every
container in the Pod and may include cache charged to the cgroup. If a result
looks surprising, inspect the sandbox processes and cgroup metrics on the node
before drawing conclusions.
## Capture a Snapshot
Run this from the repository root:
```bash
python scripts/sandbox_memory_profile.py \
--namespace deer-flow \
--selector app=deer-flow-sandbox \
--sample empty \
--include-processes \
--format markdown
```
Use a descriptive `--sample` value for each phase:
```bash
python scripts/sandbox_memory_profile.py --sample after-bash --format json
python scripts/sandbox_memory_profile.py --sample after-python --format json
python scripts/sandbox_memory_profile.py --sample after-artifact --format json
```
`--include-processes` runs `kubectl exec ... ps` in each sandbox Pod and adds
the highest-RSS processes to the report. This helps distinguish Pod-level cgroup
memory from process RSS. The two numbers will not match exactly because cgroup
memory can include cache and other kernel-accounted memory.
Save the raw JSON when comparing backends so totals, pod names, images,
requests, limits, and timestamps can be audited later.
## Candidate Runtime Matrix
For AIO, CubeSandbox, OpenSandbox, gVisor, Kata, or another candidate, compare
the same workload and record:
| Area | Required Evidence |
| --- | --- |
| Capacity | Pod or instance count, total memory, average memory, max memory |
| Startup | Ready latency at 1, 10, 50, and 100 concurrent sandboxes |
| Commands | Bash output, timeout behavior, failure shape |
| Files | `read_file`, `write_file`, binary `update_file`, `list_dir`, `glob`, `grep` |
| Uploads | Files uploaded by the gateway are visible inside the sandbox |
| Artifacts | Files written to `/mnt/user-data/outputs` are readable by the backend artifact API |
| Paths | `/mnt/user-data/workspace`, `/mnt/user-data/uploads`, `/mnt/user-data/outputs`, `/mnt/acp-workspace`, and skills paths keep their expected semantics |
| Isolation | Different users and threads cannot read each other's data |
| Cleanup | Release, idle timeout, process restart, and orphan cleanup free resources |
| Operations | Deployment prerequisites, privileged components, networking, storage, and upgrade path |
## PR Guidance
Do not claim that a new provider fixes high-concurrency memory usage until the
same DeerFlow workload has been measured on both the current AIO sandbox and the
candidate backend.
For an experimental provider PR, prefer `Related to #3213` unless the PR also
includes reproducible DeerFlow workload data that demonstrates the target memory
reduction and preserves uploads, outputs, artifacts, and isolation behavior.
+4 -4
View File
@@ -127,8 +127,8 @@ complex_agent = create_agent_for_task("high")
## How It Works
1. When `make_lead_agent(config)` is called, it extracts `is_plan_mode` from `config.configurable`
2. The config is passed to `build_middlewares(config)`
3. `build_middlewares()` reads `is_plan_mode` and calls `_create_todo_list_middleware(is_plan_mode)`
2. The config is passed to `_build_middlewares(config)`
3. `_build_middlewares()` reads `is_plan_mode` and calls `_create_todo_list_middleware(is_plan_mode)`
4. If `is_plan_mode=True`, a `TodoListMiddleware` instance is created and added to the middleware chain
5. The middleware automatically adds a `write_todos` tool to the agent's toolset
6. The agent can use this tool to manage tasks during execution
@@ -141,7 +141,7 @@ make_lead_agent(config)
├─> Extracts: is_plan_mode = config.configurable.get("is_plan_mode", False)
└─> build_middlewares(config)
└─> _build_middlewares(config)
├─> ThreadDataMiddleware
├─> SandboxMiddleware
@@ -156,7 +156,7 @@ make_lead_agent(config)
### Agent Module
- **Location**: `packages/harness/deerflow/agents/lead_agent/agent.py`
- **Function**: `_create_todo_list_middleware(is_plan_mode: bool)` - Creates TodoListMiddleware if plan mode is enabled
- **Function**: `build_middlewares(config: RunnableConfig)` - Builds middleware chain based on runtime config
- **Function**: `_build_middlewares(config: RunnableConfig)` - Builds middleware chain based on runtime config
- **Function**: `make_lead_agent(config: RunnableConfig)` - Creates agent with appropriate middlewares
### Runtime Configuration
@@ -18,8 +18,6 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
``create_chat_model`` call must add to this list and pass the flag.
"""
from __future__ import annotations
import logging
from langchain.agents import create_agent
@@ -49,8 +47,6 @@ from deerflow.tracing import build_tracing_callbacks
logger = logging.getLogger(__name__)
_BOOTSTRAP_SKILL_NAMES = {"bootstrap"}
def _get_runtime_config(config: RunnableConfig) -> dict:
"""Merge legacy configurable options with LangGraph runtime context."""
@@ -267,31 +263,20 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def build_middlewares(
def _build_middlewares(
config: RunnableConfig,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
*,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
deferred_setup=None,
):
"""Build the lead-agent middleware chain based on runtime configuration.
Public entry point for the lead agent's full middleware composition. Used by
``make_lead_agent`` and by the embedded ``DeerFlowClient`` (a lead-agent variant
that needs the identical chain). Keep this name stable: it is imported across a
module boundary, so renames/signature changes ripple into ``client.py``.
"""Build middleware chain based on runtime configuration.
Args:
config: Runtime configuration containing configurable options like is_plan_mode.
model_name: Resolved runtime model name; gates vision-only middleware.
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
custom_middlewares: Optional list of custom middlewares to inject into the chain.
app_config: Explicit AppConfig; falls back to ``get_app_config()`` when omitted.
deferred_setup: Optional deferred-MCP-tool setup that attaches
``DeferredToolFilterMiddleware`` when ``tool_search`` is enabled.
Returns:
List of middleware instances.
@@ -305,13 +290,6 @@ def build_middlewares(
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
# Deterministically load a full SKILL.md when the user starts the turn with
# /skill-name. This keeps the base system prompt metadata-only while giving
# explicit user activation priority over model-side relevance guessing.
from deerflow.agents.middlewares.skill_activation_middleware import SkillActivationMiddleware
middlewares.append(SkillActivationMiddleware(available_skills=available_skills, app_config=resolved_app_config))
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
@@ -340,13 +318,11 @@ def build_middlewares(
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
# Hide deferred tool schemas from model binding until tool_search promotes them.
# The deferred set + catalog hash come from the build-time setup (assembled
# after tool-policy filtering); promotion is read from graph state.
if deferred_setup is not None and deferred_setup.deferred_names:
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if resolved_app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware(deferred_setup.deferred_names, deferred_setup.catalog_hash))
middlewares.append(DeferredToolFilterMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = cfg.get("subagent_enabled", False)
@@ -379,7 +355,7 @@ def build_middlewares(
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
if is_bootstrap:
return set(_BOOTSTRAP_SKILL_NAMES)
return {"bootstrap"}
if agent_config and agent_config.skills is not None:
return set(agent_config.skills)
return None
@@ -410,7 +386,6 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent, update_agent
from deerflow.tools.builtins.tool_search import assemble_deferred_tools
cfg = _get_runtime_config(config)
resolved_app_config = app_config
@@ -485,27 +460,16 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
# Keep the bootstrap skill set intentionally narrow so agent creation
# remains deterministic before the custom agent's own config exists.
raw_tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
filtered = filter_tools_by_skill_allowed_tools(raw_tools, skills_for_tool_policy)
final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled)
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False),
tools=final_tools,
middleware=build_middlewares(
config,
model_name=model_name,
available_skills=set(_BOOTSTRAP_SKILL_NAMES),
app_config=resolved_app_config,
deferred_setup=setup,
),
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
available_skills=set(_BOOTSTRAP_SKILL_NAMES),
available_skills=set(["bootstrap"]),
app_config=resolved_app_config,
deferred_names=setup.deferred_names,
),
state_schema=ThreadState,
)
@@ -514,27 +478,17 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# The default agent (no agent_name) does not see this tool.
extra_tools = [update_agent] if agent_name else []
# Default lead agent (unchanged behavior)
raw_tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
filtered = filter_tools_by_skill_allowed_tools(raw_tools + extra_tools, skills_for_tool_policy)
final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled)
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False),
tools=final_tools,
middleware=build_middlewares(
config,
model_name=model_name,
agent_name=agent_name,
available_skills=available_skills,
app_config=resolved_app_config,
deferred_setup=setup,
),
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=agent_name,
available_skills=available_skills,
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
app_config=resolved_app_config,
deferred_names=setup.deferred_names,
),
state_schema=ThreadState,
)
@@ -10,7 +10,6 @@ from deerflow.config.agents_config import load_agent_soul
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill, SkillCategory
from deerflow.subagents import get_available_subagent_names
from deerflow.tools.builtins.tool_search import get_deferred_tools_prompt_section
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
@@ -543,14 +542,6 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
- Progressive Loading: Load resources incrementally as referenced in skills
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
- File Editing Workflow: When revising an existing file, prefer
`str_replace` over `write_file` it sends only the diff and avoids
re-emitting the whole file (mirrors Claude Code's Edit and Codex's
apply_patch). When writing long new content from scratch, split it
into sections: the first `write_file` call creates the file, then use
`write_file` with append=True to extend it section by section. This
keeps each tool call small and avoids mid-stream chunk-gap timeouts
on oversized single-shot writes. (See issue #3189.)
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `![Image Description](image_path)\n\n` or "```mermaid" to display images in response or Markdown files
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
@@ -625,11 +616,6 @@ You have access to skills that provide optimized workflows for specific tasks. E
4. Load referenced resources only when needed during execution
5. Follow the skill's instructions precisely
**Explicit Slash Skill Activation:**
- If the user starts a request with `/<skill-name>`, that skill was explicitly requested for the current turn.
- Follow the activated skill before choosing a general workflow.
- The runtime injects the activated skill content for explicit slash activations; do not call `read_file` for that SKILL.md again unless the injected skill references supporting resources you need.
**Skills are located at:** {container_base_path}
{skill_evolution_section}
{skills_list}
@@ -692,13 +678,42 @@ SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace a
Rules:
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
- Only pass the fields that should change. Omit the others to preserve them.
- Never pass literal strings like `"null"`, `"none"`, or `"undefined"` for unchanged fields.
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
</self_update>
"""
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
and can use tool_search to load them.
Returns empty string when tool_search is disabled or no tools are deferred.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
if app_config is None:
try:
from deerflow.config import get_app_config
config = get_app_config()
except Exception:
return ""
else:
config = app_config
if not config.tool_search.enabled:
return ""
registry = get_deferred_registry()
if not registry:
return ""
names = "\n".join(e.name for e in registry.entries)
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
if app_config is None:
@@ -757,7 +772,6 @@ def apply_prompt_template(
agent_name: str | None = None,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
deferred_names: frozenset[str] = frozenset(),
) -> str:
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
@@ -785,7 +799,7 @@ def apply_prompt_template(
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section(deferred_names=deferred_names)
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section(app_config=app_config)
@@ -1,14 +1,9 @@
"""Prompt templates for memory update and injection."""
from __future__ import annotations
import logging
import math
import re
from typing import Any
logger = logging.getLogger(__name__)
try:
import tiktoken
@@ -165,39 +160,6 @@ Rules:
Return ONLY valid JSON."""
# Module-level tiktoken encoding cache. Populated lazily on first use;
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
# (potentially slow) first ``get_encoding`` call.
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
"""Return a cached tiktoken encoding, or ``None`` on failure / unavailability.
On the very first call for a given *encoding_name*, tiktoken may need to
download the BPE data from ``openaipublic.blob.core.windows.net``. In
network-restricted environments (e.g. deployments behind the GFW) this
download can block for tens of minutes before the OS TCP timeout kicks in.
The caller must therefore be prepared for this to block and should run it
off the event loop (e.g. via ``asyncio.to_thread``).
"""
if not TIKTOKEN_AVAILABLE:
return None
cached = _tiktoken_encoding_cache.get(encoding_name)
if cached is not None:
return cached
try:
encoding = tiktoken.get_encoding(encoding_name)
_tiktoken_encoding_cache[encoding_name] = encoding
return encoding
except Exception:
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
return None
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
"""Count tokens in text using tiktoken.
@@ -208,30 +170,18 @@ def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
Returns:
The number of tokens in the text.
"""
encoding = _get_tiktoken_encoding(encoding_name)
if encoding is None:
if not TIKTOKEN_AVAILABLE:
# Fallback to character-based estimation if tiktoken is not available
# or the encoding failed to load.
return len(text) // 4
try:
encoding = tiktoken.get_encoding(encoding_name)
return len(encoding.encode(text))
except Exception:
# Fallback to character-based estimation on error
return len(text) // 4
def warm_tiktoken_cache() -> bool:
"""Pre-warm the tiktoken encoding cache.
Call at startup (off the event loop) so the first request never blocks
on the BPE download. Returns ``True`` if the encoding was loaded
successfully (or was already cached), ``False`` if tiktoken is
unavailable or the download failed.
"""
return _get_tiktoken_encoding("cl100k_base") is not None
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
"""Coerce a confidence-like value to a bounded float in [0, 1].
@@ -1,15 +1,12 @@
"""Middleware to filter deferred tool schemas from model binding.
When tool_search is enabled, MCP tools are still passed to ToolNode for
execution, but their schemas must NOT be sent to the LLM via bind_tools until
the model has discovered them via tool_search. This middleware removes the
still-deferred tools from request.tools before model binding, and blocks tool
calls to tools that have not been promoted yet.
When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
and passed to ToolNode for execution, but their schemas should NOT be sent to the
LLM via bind_tools (that's the whole point of deferral — saving context tokens).
The deferred name set and the catalog hash are injected at construction time
(no ContextVar). Promotion state is read from graph state (``state["promoted"]``),
scoped by catalog hash so a stale persisted promotion cannot expose a renamed
or drifted tool.
This middleware intercepts wrap_model_call and removes deferred tools from
request.tools so that model.bind_tools only receives active tool schemas.
The agent discovers deferred tools at runtime via the tool_search tool.
"""
import logging
@@ -27,49 +24,47 @@ logger = logging.getLogger(__name__)
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
"""Hide deferred tool schemas from the bound model until promoted.
"""Remove deferred tools from request.tools before model binding.
ToolNode still holds all tools (including deferred) for execution routing,
but the LLM only sees active tool schemas plus tools that have already been
promoted (recorded in ``state["promoted"]`` under the current catalog hash).
but the LLM only sees active tool schemas deferred tools are discoverable
via tool_search at runtime.
"""
def __init__(self, deferred_names: frozenset[str], catalog_hash: str | None):
super().__init__()
self._deferred = deferred_names
self._catalog_hash = catalog_hash
def _promoted(self, state) -> set[str]:
promoted = (state or {}).get("promoted")
if promoted and promoted.get("catalog_hash") == self._catalog_hash:
return set(promoted.get("names") or [])
return set()
def _hidden(self, state) -> set[str]:
return set(self._deferred) - self._promoted(state)
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
if not self._deferred:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return request
hide = self._hidden(request.state)
if not hide:
return request
active = [t for t in request.tools if getattr(t, "name", None) not in hide]
if len(active) < len(request.tools):
logger.debug("Filtered %d deferred tool schema(s) from model binding", len(request.tools) - len(active))
return request.override(tools=active)
deferred_names = registry.deferred_names
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools):
logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
if not self._deferred:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None
name = str(request.tool_call.get("name") or "")
if not name or name not in self._hidden(request.state):
tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage(
content=(f"Error: Tool '{name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id,
name=name,
name=tool_name,
status="error",
)
@@ -28,7 +28,6 @@ Date-update format:
from __future__ import annotations
import asyncio
import logging
import re
import uuid
@@ -44,12 +43,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Upper bound (seconds) for a single _inject() offload. If the warm-up at
# gateway startup failed silently, the first request may still hit a cold
# tiktoken BPE download that blocks until the OS TCP timeout (~26 min).
# This cap ensures the request degrades gracefully instead of hanging.
_INJECT_TIMEOUT_SECONDS = 5.0
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
_SUMMARY_MESSAGE_NAME = "summary"
@@ -208,25 +201,4 @@ class DynamicContextMiddleware(AgentMiddleware):
@override
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
# _inject() performs synchronous file I/O (memory JSON loading) and
# potentially blocking network calls (tiktoken encoding download on
# first use). Offload to a thread so the event loop is never blocked
# — a blocking call here starves all concurrent HTTP handlers (auth,
# SSE heartbeats, etc.). See issue #3402.
#
# Bounded timeout: if startup warm-up failed silently (e.g. network
# blip during deploy), the first request's cold tiktoken download can
# block for tens of minutes (OS TCP timeout). Time-box injection so
# the request degrades gracefully (no memory context) rather than
# hanging.
try:
return await asyncio.wait_for(
asyncio.to_thread(self._inject, state),
timeout=_INJECT_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"DynamicContextMiddleware: injection timed out (%.1fs); skipping memory/date injection for this turn",
_INJECT_TIMEOUT_SECONDS,
)
return None
return self._inject(state)
@@ -62,41 +62,6 @@ _AUTH_PATTERNS = (
"未授权",
)
# Per-exception retry budget overrides.
#
# Some transient errors are retriable in principle but expensive to retry at
# the default budget. StreamChunkTimeoutError in particular fires after the
# upstream provider has already stalled for `stream_chunk_timeout` seconds
# (typically 120-240s); a full 3-attempt loop can therefore stack 6-12 minutes
# of dead air before surfacing the failure to the user. We keep exactly one
# retry (cheap reconnect that catches genuine transient TCP blips) and then
# fail fast — the same buffered payload is overwhelmingly likely to fail
# again at the upstream provider for the same reason.
#
# Keys are exception class *names* (not classes) so we don't introduce
# import-time coupling on optional dependencies like langchain-openai. The
# value is the absolute max attempt count, NOT additional retries — so a
# value of 2 means "1 first attempt + 1 retry" (the CR-requested
# "keep one retry" behavior).
_RETRY_BUDGET_OVERRIDES: dict[str, int] = {
"StreamChunkTimeoutError": 2,
}
# Exception class names that indicate the upstream stream-chunk watchdog
# fired because the model stalled mid-flight. These deserve a more specific
# user-facing message than the generic "temporarily unavailable" copy,
# because the typical root cause is a long tool-call serialization stalling
# the upstream stream — and the most actionable advice we can give the user
# is "ask for a shorter / split output" rather than "wait and retry".
# Generic connection drops (httpx RemoteProtocolError / ReadError) are
# intentionally excluded: they routinely fire on transient network blips
# with normal payloads, where the "split the work" guidance is misleading.
_STREAM_DROP_EXCEPTIONS: frozenset[str] = frozenset(
{
"StreamChunkTimeoutError",
}
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages."""
@@ -118,18 +83,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _max_attempts_for(self, exc: BaseException) -> int:
"""Return the effective max attempt count for this exception.
Falls back to `self.retry_max_attempts` unless the exception class name
appears in the per-exception override table.
"""
override = _RETRY_BUDGET_OVERRIDES.get(type(exc).__name__)
if override is None:
return self.retry_max_attempts
return min(override, self.retry_max_attempts)
def _check_circuit(self) -> bool:
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
with self._circuit_lock:
@@ -200,7 +153,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"InternalServerError",
"ReadError", # httpx.ReadError: connection dropped mid-stream
"RemoteProtocolError", # httpx: server closed connection unexpectedly
"StreamChunkTimeoutError", # langchain-openai: chunk gap exceeded stream_chunk_timeout
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
@@ -225,24 +177,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_error_fallback_message(
self,
content: str,
*,
error_type: str,
reason: str,
detail: str,
) -> AIMessage:
return AIMessage(
content=content,
additional_kwargs={
"deerflow_error_fallback": True,
"error_type": error_type,
"error_reason": reason,
"error_detail": detail,
},
)
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
@@ -250,31 +184,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}:
# Stream-drop failures (chunk-gap timeout, peer-closed connection,
# raw read error) almost always point at a single oversized
# tool-call payload — the model spent so long serializing JSON
# arguments that the upstream provider buffered and the stream
# gap exceeded `stream_chunk_timeout`. Surfacing this distinct
# cause lets the user split or shorten their next request
# instead of helplessly retrying the same prompt.
if type(exc).__name__ in _STREAM_DROP_EXCEPTIONS:
return (
"The model's streaming response was interrupted before it could "
"finish. This usually happens when a single response or tool call "
"is very large — please ask the assistant to split the work into "
"smaller steps, or shorten the requested output, and try again."
)
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _build_user_fallback_message(self, exc: BaseException, reason: str) -> AIMessage:
return self._build_error_fallback_message(
self._build_user_message(exc, reason),
error_type=type(exc).__name__,
reason=reason,
detail=_extract_error_detail(exc),
)
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
@@ -300,12 +212,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
if self._check_circuit():
return self._build_error_fallback_message(
self._build_circuit_breaker_message(),
error_type="CircuitBreakerOpen",
reason="circuit_open",
detail="LLM circuit breaker is open",
)
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
@@ -321,8 +228,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
max_attempts = self._max_attempts_for(exc)
if retriable and attempt < max_attempts:
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
@@ -343,7 +249,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
)
if retriable:
self._record_failure()
return self._build_user_fallback_message(exc, reason)
return AIMessage(content=self._build_user_message(exc, reason))
@override
async def awrap_model_call(
@@ -352,12 +258,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
if self._check_circuit():
return self._build_error_fallback_message(
self._build_circuit_breaker_message(),
error_type="CircuitBreakerOpen",
reason="circuit_open",
detail="LLM circuit breaker is open",
)
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
@@ -373,8 +274,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
max_attempts = self._max_attempts_for(exc)
if retriable and attempt < max_attempts:
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
@@ -395,7 +295,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
)
if retriable:
self._record_failure()
return self._build_user_fallback_message(exc, reason)
return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
@@ -1,289 +0,0 @@
"""Middleware for explicit slash skill activation."""
from __future__ import annotations
import asyncio
import hashlib
import html
import logging
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, override
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.skills.slash import parse_slash_skill_reference, resolve_slash_skill
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
from deerflow.utils.messages import get_original_user_content_text
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_SLASH_SKILL_ACTIVATION_KEY = "slash_skill_activation"
_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY = "slash_skill_activation_target_id"
_SUMMARY_MESSAGE_NAME = "summary"
@dataclass(frozen=True, slots=True)
class _Activation:
skill_name: str
category: str
container_file_path: str
skill_content: str
content_hash: str
remaining_text: str
@dataclass(frozen=True, slots=True)
class _ActivationResolution:
activation: _Activation | None = None
failure_message: str | None = None
def is_slash_skill_activation_reminder(message: object) -> bool:
"""Return whether a message is hidden slash-skill activation context."""
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_KEY))
def _is_user_activation_target(message: object) -> bool:
if not isinstance(message, HumanMessage):
return False
if message.name == _SUMMARY_MESSAGE_NAME:
return False
if message.additional_kwargs.get("hide_from_ui"):
return False
return True
class SkillActivationMiddleware(AgentMiddleware):
"""Inject full SKILL.md content when the user explicitly types /skill-name."""
def __init__(
self,
*,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> None:
super().__init__()
self._available_skills = set(available_skills) if available_skills is not None else None
self._app_config = app_config
def _storage(self) -> SkillStorage:
if self._app_config is not None:
return get_or_new_skill_storage(app_config=self._app_config)
return get_or_new_skill_storage()
@staticmethod
def _read_skill_content(skill_file: Path, skills_root: Path) -> str:
if skill_file.name != SKILL_MD_FILE:
raise ValueError(f"Expected {SKILL_MD_FILE}, got {skill_file.name}")
resolved_root = skills_root.resolve()
resolved_file = skill_file.resolve()
try:
resolved_file.relative_to(resolved_root)
except ValueError as exc:
raise ValueError("Resolved skill file must stay within the configured skills root.") from exc
if not resolved_file.is_file():
raise FileNotFoundError(resolved_file)
return resolved_file.read_text(encoding="utf-8")
def _resolve_activation(self, text: str) -> _ActivationResolution | None:
reference = parse_slash_skill_reference(text)
if reference is None:
return None
storage = self._storage()
skills = storage.load_skills(enabled_only=False)
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
if skill is None:
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not installed.")
if not skill.enabled:
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
if self._available_skills is not None and reference.name not in self._available_skills:
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
resolved = resolve_slash_skill(
text,
skills,
available_skills=self._available_skills,
container_base_path=storage.get_container_root(),
)
if resolved is None:
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be resolved.")
try:
skill_content = self._read_skill_content(resolved.skill.skill_file, storage.get_skills_root_path())
except (OSError, ValueError):
logger.exception("Failed to read slash-activated skill %s", resolved.skill.name)
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be loaded safely. Please check the skill installation.")
content_hash = hashlib.sha256(skill_content.encode("utf-8")).hexdigest()
return _ActivationResolution(
activation=_Activation(
skill_name=resolved.skill.name,
category=str(resolved.skill.category),
container_file_path=resolved.container_file_path,
skill_content=skill_content,
content_hash=content_hash,
remaining_text=resolved.remaining_text,
)
)
@staticmethod
def _build_activation_reminder(activation: _Activation) -> str:
user_request = activation.remaining_text or ("No additional task text was provided after the slash skill command. Ask the user what they want to do with this skill if the next step is unclear.")
escaped_user_request = html.escape(user_request, quote=False)
escaped_skill_content = html.escape(activation.skill_content, quote=False)
escaped_skill_name = html.escape(activation.skill_name, quote=True)
escaped_category = html.escape(activation.category, quote=True)
escaped_path = html.escape(activation.container_file_path, quote=True)
escaped_content_hash = html.escape(activation.content_hash, quote=True)
return f"""<slash_skill_activation>
The user explicitly activated the `{activation.skill_name}` skill for this turn.
Treat the task text as:
<user_request>
{escaped_user_request}
</user_request>
Follow this skill before choosing a general workflow. Load supporting resources from the same skill directory only when needed.
<skill name="{escaped_skill_name}" category="{escaped_category}" path="{escaped_path}" sha256="{escaped_content_hash}">
<skill_content encoding="xml-escaped">
{escaped_skill_content}
</skill_content>
</skill>
</slash_skill_activation>"""
@staticmethod
def _has_existing_activation_for_target(messages: list, target_index: int, target: HumanMessage) -> bool:
if target_index <= 0:
return False
if target.id:
for previous in messages[:target_index]:
if not is_slash_skill_activation_reminder(previous):
continue
target_id = previous.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY)
if target_id == target.id or previous.id == f"{target.id}__slash_activation":
return True
previous = messages[target_index - 1]
return is_slash_skill_activation_reminder(previous)
def _find_activation_target(self, messages: list) -> tuple[int, HumanMessage, _ActivationResolution] | None:
if not messages:
return None
target_index = next((idx for idx in range(len(messages) - 1, -1, -1) if _is_user_activation_target(messages[idx])), None)
if target_index is None:
return None
target = messages[target_index]
if target is None:
return None
if self._has_existing_activation_for_target(messages, target_index, target):
return None
content = get_original_user_content_text(target.content, target.additional_kwargs)
resolution = self._resolve_activation(content)
if resolution is None:
return None
return target_index, target, resolution
@staticmethod
def _record_activation(request: ModelRequest, activation: _Activation, *, hook: str) -> None:
runtime = getattr(request, "runtime", None)
context = getattr(runtime, "context", None)
journal = context.get("__run_journal") if isinstance(context, dict) else None
if journal is None:
return
try:
journal.record_middleware(
"skill_activation",
name="SkillActivationMiddleware",
hook=hook,
action="activate",
changes={
"skill_name": activation.skill_name,
"category": activation.category,
"path": activation.container_file_path,
"content_hash": activation.content_hash,
},
)
except Exception:
logger.debug("Failed to record slash skill activation audit event", exc_info=True)
def _prepare_model_request(self, request: ModelRequest, *, hook: str) -> ModelRequest | AIMessage | None:
target_and_resolution = self._find_activation_target(list(request.messages))
if target_and_resolution is None:
return None
target_index, target, resolution = target_and_resolution
if resolution.failure_message:
return AIMessage(content=resolution.failure_message)
activation = resolution.activation
if activation is None:
return None
logger.info(
"SkillActivationMiddleware: activating slash skill %s category=%s path=%s hash=%s",
activation.skill_name,
activation.category,
activation.container_file_path,
activation.content_hash,
)
self._record_activation(request, activation, hook=hook)
activation_msg = self._make_activation_message(target, self._build_activation_reminder(activation))
messages = list(request.messages)
messages.insert(target_index, activation_msg)
return request.override(messages=messages)
@staticmethod
def _make_activation_message(target: HumanMessage, activation_content: str) -> HumanMessage:
stable_id = target.id or str(uuid.uuid4())
additional_kwargs = {
"hide_from_ui": True,
_SLASH_SKILL_ACTIVATION_KEY: True,
}
if target.id:
additional_kwargs[_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY] = target.id
return HumanMessage(
content=activation_content,
id=f"{stable_id}__slash_activation",
additional_kwargs=additional_kwargs,
)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse | AIMessage:
prepared = self._prepare_model_request(request, hook="wrap_model_call")
if prepared is None:
return handler(request)
if isinstance(prepared, AIMessage):
return prepared
return handler(prepared)
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse | AIMessage:
prepared = await asyncio.to_thread(self._prepare_model_request, request, hook="awrap_model_call")
if prepared is None:
return await handler(request)
if isinstance(prepared, AIMessage):
return prepared
return await handler(prepared)
@@ -9,9 +9,8 @@ from typing import Any, Protocol, override, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage, get_buffer_string
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langgraph.config import get_config
from langgraph.constants import TAG_NOSTREAM
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
@@ -117,74 +116,6 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
# The summary LLM call runs inside a LangGraph middleware hook, so its token
# stream would otherwise be captured by the messages-tuple stream callback and
# broadcast to the frontend as a phantom AI message. Tag a dedicated model copy
# with TAG_NOSTREAM so the streaming handler skips it.
# Keep self.model untagged so the parent's profile / ls_params inspection still works.
#
# Preserve any tags already bound on the model (e.g. "middleware:summarize" set in
# lead_agent/agent.py for RunJournal attribution): RunnableBinding.with_config does a
# shallow merge that would otherwise overwrite the existing tags list entirely.
existing_tags = list((getattr(self.model, "config", None) or {}).get("tags") or [])
merged_tags = [*existing_tags, TAG_NOSTREAM] if TAG_NOSTREAM not in existing_tags else existing_tags
self._summary_model = self.model.with_config(tags=merged_tags)
@override
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
return self._summarize_with(messages_to_summarize)
@override
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
return await self._asummarize_with(messages_to_summarize)
def _summarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Mirror the parent ``_create_summary`` but invoke the nostream-tagged model.
We do not swap ``self.model`` at the instance level: the agent/middleware is
cached and reused across concurrent runs, so a temporary swap would leak the
``RunnableBinding`` to other coroutines during ``await`` and break parent logic
that inspects the raw model (``profile`` / ``_get_ls_params``).
"""
if not messages_to_summarize:
return "No previous conversation history."
prompt = self._build_summary_prompt(messages_to_summarize)
if prompt is None:
return "Previous conversation was too long to summarize."
try:
response = self._summary_model.invoke(
prompt,
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
async def _asummarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Async counterpart of :meth:`_summarize_with` using the nostream model."""
if not messages_to_summarize:
return "No previous conversation history."
prompt = self._build_summary_prompt(messages_to_summarize)
if prompt is None:
return "Previous conversation was too long to summarize."
try:
response = await self._summary_model.ainvoke(
prompt,
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
def _build_summary_prompt(self, messages_to_summarize: list[AnyMessage]) -> str | None:
"""Build the summary prompt, returning ``None`` when trimming leaves nothing."""
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return None
# Format messages to avoid token inflation from metadata when str() is called on
# message objects.
formatted_messages = get_buffer_string(trimmed_messages)
return self.summary_prompt.format(messages=formatted_messages).rstrip()
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)
@@ -2,7 +2,7 @@
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, override
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
@@ -12,48 +12,10 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.config.app_config import AppConfig
from deerflow.subagents.status_contract import (
extract_subagent_status,
make_subagent_additional_kwargs,
)
if TYPE_CHECKING:
from deerflow.tools.builtins.tool_search import DeferredToolSetup
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
_TASK_TOOL_NAME = "task"
def _stamp_task_subagent_status(message: ToolMessage, *, tool_name: str, error: str | None = None) -> ToolMessage:
"""Centralised stamping of ``additional_kwargs.subagent_status``.
Bytedance/deer-flow issue #3146: the frontend now reads the subagent
status from a structured field instead of parsing the leading text of
the task tool's return string. That contract is enforced here, in the
one place every task tool result flows through, rather than at the 5
normal-return + 3 ``Error:`` pre-execution branches inside
``task_tool.py``. Centralisation prevents the "added a new return
path, forgot the stamp" drift mode.
For non-``task`` tools this is a no-op so other tools' additional_kwargs
conventions are untouched.
"""
if tool_name != _TASK_TOOL_NAME:
return message
content = message.content if isinstance(message.content, str) else ""
status = extract_subagent_status(content)
if status is None:
# Non-terminal streaming chunks or unrecognised shapes leave the
# field unset so the frontend can keep the card on its in-progress
# placeholder until a real terminal frame arrives.
return message
stamp = make_subagent_additional_kwargs(status, error=error)
existing = dict(message.additional_kwargs or {})
existing.update(stamp)
message.additional_kwargs = existing
return message
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
@@ -67,31 +29,12 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
detail = detail[:497] + "..."
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
message = ToolMessage(
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
# Stamp the structured subagent status on the wrapper too: the
# frontend would otherwise have to fall back to prefix-matching
# ``Error: Tool 'task' failed ...`` on the wire. The ``subagent_error``
# carries the same ``ExcClass: detail`` shape the wrapper string
# uses so debugging artifacts stay aligned.
structured_error = f"{exc.__class__.__name__}: {detail}"
return _stamp_task_subagent_status(message, tool_name=tool_name, error=structured_error)
@staticmethod
def _maybe_stamp(result: ToolMessage | Command, request: ToolCallRequest) -> ToolMessage | Command:
"""Apply the subagent stamp to successful task tool returns.
``Command`` results bypass the stamp they encode LangGraph
control flow rather than user-facing tool output.
"""
if not isinstance(result, ToolMessage):
return result
tool_name = str(request.tool_call.get("name") or "")
return _stamp_task_subagent_status(result, tool_name=tool_name)
@override
def wrap_tool_call(
@@ -100,14 +43,13 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
try:
result = handler(request)
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
return self._maybe_stamp(result, request)
@override
async def awrap_tool_call(
@@ -116,14 +58,13 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
try:
result = await handler(request)
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
return self._maybe_stamp(result, request)
def _build_runtime_middlewares(
@@ -202,7 +143,6 @@ def build_subagent_runtime_middlewares(
app_config: AppConfig | None = None,
model_name: str | None = None,
lazy_init: bool = True,
deferred_setup: "DeferredToolSetup | None" = None,
) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
if app_config is None:
@@ -226,16 +166,6 @@ def build_subagent_runtime_middlewares(
middlewares.append(ViewImageMiddleware())
# Hide deferred (MCP) tool schemas from the subagent's model binding until
# tool_search promotes them. This is the same wiring the lead agent gets. The deferred
# set + catalog hash come from the build-time setup (assembled after
# tool-policy filtering); promotion is read from graph state. Empty/None
# setup (deferral disabled or no MCP tool survived) is a pure no-op.
if deferred_setup is not None and deferred_setup.deferred_names:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware(deferred_setup.deferred_names, deferred_setup.catalog_hash))
# Same provider safety-termination guard the lead agent uses — subagents
# are equally exposed to truncated tool_calls returned with
# finish_reason=content_filter (and friends), and the bad call would then
@@ -11,11 +11,10 @@ from __future__ import annotations
import asyncio
import logging
import os
import shlex
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import replace as dc_replace
from typing import TYPE_CHECKING, Any, override
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
@@ -25,19 +24,9 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.config.tool_output_config import ToolOutputConfig
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
if TYPE_CHECKING:
from deerflow.sandbox.sandbox import Sandbox
logger = logging.getLogger(__name__)
# Virtual outputs root inside the sandbox. Host-mounted sandboxes map this to
# the thread outputs dir on the host; for non-mounted (remote) sandboxes the
# same path is written directly into the sandbox filesystem so the model's
# ``read_file`` tool can read it back (issue #3416).
_VIRTUAL_OUTPUTS_BASE = "/mnt/user-data/outputs"
def _default_config() -> ToolOutputConfig:
return ToolOutputConfig()
@@ -105,18 +94,6 @@ def _sanitize_tool_name(name: str) -> str:
return safe or "unknown"
def _build_externalized_filename(*, tool_name: str, tool_call_id: str) -> str:
"""Build the on-disk filename for an externalized tool output.
Shared by the host-disk and sandbox externalization paths so both
produce the identical naming scheme.
"""
safe_name = _sanitize_tool_name(tool_name)
ext = _EXT_MAP.get(tool_name, "txt")
short_id = uuid.uuid4().hex[:12]
return f"{safe_name}-{short_id}.{ext}"
def _externalize(
content: str,
*,
@@ -134,7 +111,10 @@ def _externalize(
except OSError:
return None
filename = _build_externalized_filename(tool_name=tool_name, tool_call_id=tool_call_id)
safe_name = _sanitize_tool_name(tool_name)
ext = _EXT_MAP.get(tool_name, "txt")
short_id = uuid.uuid4().hex[:12]
filename = f"{safe_name}-{short_id}.{ext}"
filepath = os.path.join(storage_dir, filename)
if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)):
@@ -146,56 +126,8 @@ def _externalize(
except OSError:
return None
return f"{_VIRTUAL_OUTPUTS_BASE}/{storage_subdir}/{filename}"
def _externalize_to_sandbox(
content: str,
*,
tool_name: str,
tool_call_id: str,
storage_subdir: str,
sandbox: Sandbox,
) -> str | None:
"""Write *content* into the sandbox filesystem and return the virtual path.
Used when the sandbox does not use thread-data mounts (e.g. a remote AIO
sandbox): the host-side :func:`_externalize` virtual path would not exist
inside the sandbox, so the model's ``read_file`` tool could not read it
back (issue #3416). Returns the same virtual-path contract on success, or
``None`` to signal the caller to fall back to inline truncation.
"""
if os.path.isabs(storage_subdir) or ".." in storage_subdir:
return None
filename = _build_externalized_filename(tool_name=tool_name, tool_call_id=tool_call_id)
virtual_dir = f"{_VIRTUAL_OUTPUTS_BASE}/{storage_subdir}"
virtual_path = f"{virtual_dir}/{filename}"
try:
# AIO sandbox write_file does NOT create parent directories, so create
# them explicitly before writing. execute_command returns its stdout
# verbatim (including an "Error: ..." string on failure) rather than
# raising, so we cannot rely on exception propagation here.
sandbox.execute_command(f"mkdir -p {shlex.quote(virtual_dir)}")
sandbox.write_file(virtual_path, content)
# Validate the file landed: execute_command may have silently failed
# to create the directory, and write_file backends differ. Refuse to
# hand the model an unreadable read_file path.
check = sandbox.execute_command(f"test -s {shlex.quote(virtual_path)} && echo OK || echo MISSING")
if not isinstance(check, str) or check.strip() != "OK":
logger.warning(
"Sandbox externalize validation failed: path=%s, check=%r",
virtual_path,
check,
)
return None
except Exception:
logger.exception(
"Failed to externalize %s output to sandbox (call_id=%s)",
tool_name,
tool_call_id,
)
return None
return virtual_path
virtual_base = "/mnt/user-data/outputs"
return f"{virtual_base}/{storage_subdir}/{filename}"
# ---------------------------------------------------------------------------
@@ -295,33 +227,6 @@ def _resolve_outputs_path(request: ToolCallRequest) -> str | None:
return outputs_path if isinstance(outputs_path, str) else None
def _resolve_sandbox(request: ToolCallRequest) -> Sandbox | None:
"""Resolve the active sandbox for the current tool call, or ``None``.
Reads the sandbox_id that ``SandboxMiddleware`` (and the sandbox tools
themselves) write into ``runtime.state["sandbox"]``. We intentionally do
NOT call ``provider.acquire`` here: acquiring a sandbox can trigger
blocking remote I/O, and this resolver runs on every tool call. Tools
that do not use a sandbox (``web_search``, MCP, ...) will return ``None``
here, which is fine -- the caller falls back to inline truncation.
"""
runtime = getattr(request, "runtime", None)
state = getattr(runtime, "state", None)
if not isinstance(state, dict):
return None
sandbox_state = state.get("sandbox")
if not isinstance(sandbox_state, dict):
return None
sandbox_id = sandbox_state.get("sandbox_id")
if not sandbox_id:
return None
try:
return get_sandbox_provider().get(sandbox_id)
except Exception:
logger.exception("Failed to look up sandbox %s for tool-output externalization", sandbox_id)
return None
def _budget_content(
content: str,
*,
@@ -329,7 +234,6 @@ def _budget_content(
tool_call_id: str,
outputs_path: str | None,
config: ToolOutputConfig,
sandbox: Sandbox | None = None,
) -> str | None:
"""Apply budget to *content*. Returns ``None`` if no change needed."""
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
@@ -338,50 +242,14 @@ def _budget_content(
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
return None
if threshold > 0 and len(content) > threshold:
virtual_path: str | None = None
# Decide persistence target based on what's available, without touching
# the sandbox provider unless a sandbox was actually resolved for this
# call. This keeps the legacy host-disk path provider-free, so callers
# without a configured sandbox (and CI environments without a
# config.yaml) continue to externalize to the host as before.
if sandbox is not None:
provider = None
try:
provider = get_sandbox_provider()
except Exception:
logger.exception("Failed to get sandbox provider for tool-output externalization; falling back to inline truncation")
if provider is not None and getattr(provider, "uses_thread_data_mounts", False):
# Host-mounted sandbox: host outputs path is bind-mounted into
# the sandbox at the same virtual path, so writing host-side is
# equivalent. Preserve the original behavior to avoid extra
# sandbox round-trips.
if outputs_path:
virtual_path = _externalize(
content,
tool_name=tool_name,
tool_call_id=tool_call_id,
outputs_path=outputs_path,
storage_subdir=config.storage_subdir,
)
else:
virtual_path = _externalize_to_sandbox(
content,
tool_name=tool_name,
tool_call_id=tool_call_id,
storage_subdir=config.storage_subdir,
sandbox=sandbox,
)
elif outputs_path:
# No sandbox in this call (legacy / non-sandbox tools): write to
# host outputs path directly, no provider needed.
virtual_path = _externalize(
content,
tool_name=tool_name,
tool_call_id=tool_call_id,
outputs_path=outputs_path,
storage_subdir=config.storage_subdir,
)
if threshold > 0 and len(content) > threshold and outputs_path:
virtual_path = _externalize(
content,
tool_name=tool_name,
tool_call_id=tool_call_id,
outputs_path=outputs_path,
storage_subdir=config.storage_subdir,
)
if virtual_path is not None:
logger.info(
"Externalized %s output (%d chars) to %s",
@@ -420,12 +288,7 @@ def _budget_content(
# ---------------------------------------------------------------------------
def _patch_tool_message(
msg: ToolMessage,
config: ToolOutputConfig,
outputs_path: str | None,
sandbox: Sandbox | None = None,
) -> ToolMessage:
def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage:
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
tool_name = msg.name or "unknown"
if tool_name in config.exempt_tools:
@@ -441,7 +304,6 @@ def _patch_tool_message(
tool_call_id=msg.tool_call_id or "",
outputs_path=outputs_path,
config=config,
sandbox=sandbox,
)
if replacement is None:
return msg
@@ -493,15 +355,10 @@ def _needs_budget(result: ToolMessage | Command, config: ToolOutputConfig) -> bo
return False
def _patch_result(
result: ToolMessage | Command,
config: ToolOutputConfig,
outputs_path: str | None,
sandbox: Sandbox | None = None,
) -> ToolMessage | Command:
def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage | Command:
"""Apply budget to a tool call result (ToolMessage or Command)."""
if isinstance(result, ToolMessage):
return _patch_tool_message(result, config, outputs_path, sandbox)
return _patch_tool_message(result, config, outputs_path)
update = getattr(result, "update", None)
if not isinstance(update, dict):
@@ -515,7 +372,7 @@ def _patch_result(
changed = False
for msg in messages:
if isinstance(msg, ToolMessage):
patched = _patch_tool_message(msg, config, outputs_path, sandbox)
patched = _patch_tool_message(msg, config, outputs_path)
if patched is not msg:
changed = True
new_messages.append(patched)
@@ -535,11 +392,6 @@ def _patch_model_messages(messages: list[Any], config: ToolOutputConfig) -> list
ToolMessage exceeds the budget the common case once every result has
already been budgeted at tool-call time, so a long history is not rebuilt
on every model call.
Historical messages do not get a ``sandbox`` argument: any oversized tool
message in history was already budgeted (and possibly externalized) at
tool-call time, so the only thing left for the history path to do is
inline fallback truncation, which needs no sandbox.
"""
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
return None
@@ -590,8 +442,7 @@ class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
if not _needs_budget(result, self._config):
return result
outputs_path = _resolve_outputs_path(request)
sandbox = _resolve_sandbox(request)
return _patch_result(result, self._config, outputs_path, sandbox)
return _patch_result(result, self._config, outputs_path)
@override
async def awrap_tool_call(
@@ -605,12 +456,7 @@ class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
if not _needs_budget(result, self._config):
return result
outputs_path = _resolve_outputs_path(request)
# _resolve_sandbox only touches runtime.state and the provider's
# in-memory sandbox registry, so it is safe to call on the event
# loop. The actual sandbox I/O (mkdir/write/test) happens inside
# _patch_result, which is offloaded to a worker thread below.
sandbox = _resolve_sandbox(request)
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path, sandbox)
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path)
# -- model call hooks (historical message truncation) ------------------
@@ -13,7 +13,6 @@ from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.file_conversion import extract_outline
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY, message_content_to_text
logger = logging.getLogger(__name__)
@@ -266,8 +265,6 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
# Extract original content - handle both string and list formats
original_content = last_message.content
additional_kwargs = dict(last_message.additional_kwargs or {})
additional_kwargs.setdefault(ORIGINAL_USER_CONTENT_KEY, message_content_to_text(original_content))
if isinstance(original_content, str):
# Simple case: string content, just prepend files message
updated_content = f"{files_message}\n\n{original_content}"
@@ -288,7 +285,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
content=updated_content,
id=last_message.id,
name=last_message.name,
additional_kwargs=additional_kwargs,
additional_kwargs=last_message.additional_kwargs,
)
messages[last_message_index] = updated_message
@@ -179,10 +179,8 @@ class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
# Create the image details message with text and image content
image_content = self._create_image_details_message(state)
# Create a new human message with mixed content (text + images). This is
# internal context for the model only, so hide it from the chat UI and IM
# channels (matches the other middleware-injected context messages).
human_msg = HumanMessage(content=image_content, additional_kwargs={"hide_from_ui": True})
# Create a new human message with mixed content (text + images)
human_msg = HumanMessage(content=image_content)
logger.debug("Injecting image details message with images before LLM call")
@@ -58,32 +58,6 @@ def merge_todos(existing: list | None, new: list | None) -> list | None:
return new
class PromotedTools(TypedDict):
catalog_hash: str
names: list[str]
def merge_promoted(existing: PromotedTools | None, new: PromotedTools | None) -> PromotedTools | None:
"""Reducer for deferred-tool promotions, scoped by catalog hash.
- new None/empty -> preserve existing (node didn't touch promotions).
- catalog_hash changed -> replace wholesale, dropping stale names (prevents a
persisted bare name from exposing a different tool after catalog drift).
- same catalog_hash -> union names, dedupe, preserve order.
"""
if not new:
return existing
if existing is None or existing.get("catalog_hash") != new["catalog_hash"]:
return {
"catalog_hash": new["catalog_hash"],
"names": list(dict.fromkeys(new["names"])),
}
return {
"catalog_hash": existing["catalog_hash"],
"names": list(dict.fromkeys(existing["names"] + new["names"])),
}
class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
@@ -92,4 +66,3 @@ class ThreadState(AgentState):
todos: Annotated[list | None, merge_todos]
uploaded_files: NotRequired[list[dict] | None]
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
promoted: Annotated[PromotedTools | None, merge_promoted]
+4 -16
View File
@@ -33,7 +33,7 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.agent import build_middlewares
from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN
@@ -43,7 +43,6 @@ from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.tools.builtins.tool_search import assemble_deferred_tools
from deerflow.tracing import build_tracing_callbacks, inject_langfuse_metadata
from deerflow.uploads.manager import (
claim_unique_filename,
@@ -238,30 +237,19 @@ class DeerFlowClient:
subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
tools = self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled)
final_tools, deferred_setup = assemble_deferred_tools(tools, enabled=self._app_config.tool_search.enabled)
kwargs: dict[str, Any] = {
# attach_tracing=False because ``stream()`` injects tracing
# callbacks at the graph invocation root so a single embedded run
# produces one trace with correct session_id / user_id propagation.
# Attaching them again on the model would emit duplicate spans.
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False),
"tools": final_tools,
"middleware": build_middlewares(
config,
model_name=model_name,
agent_name=self._agent_name,
available_skills=self._available_skills,
custom_middlewares=self._middlewares,
app_config=self._app_config,
deferred_setup=deferred_setup,
),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
available_skills=self._available_skills,
deferred_names=deferred_setup.deferred_names,
),
"state_schema": ThreadState,
}
@@ -1218,7 +1206,7 @@ class DeerFlowClient:
info: dict[str, Any] = {
"filename": dest_name,
"size": dest.stat().st_size,
"size": str(dest.stat().st_size),
"path": str(dest),
"virtual_path": upload_virtual_path(dest_name),
"artifact_url": upload_artifact_url(thread_id, dest_name),
@@ -39,63 +39,11 @@ class AioSandbox(Sandbox):
self._client = AioSandboxClient(base_url=base_url, timeout=600)
self._home_dir = home_dir
self._lock = threading.Lock()
self._closed = False
@property
def base_url(self) -> str:
return self._base_url
def close(self) -> None:
"""Best-effort close of the host-side HTTP client owned by this sandbox.
The agent_sandbox SDK is Fern-generated and exposes no ``close()`` /
``__exit__``, so we reach the socket-owning ``httpx.Client`` explicitly
through its attribute chain::
Sandbox._client_wrapper -> SyncClientWrapper
.httpx_client -> Fern HttpClient (a wrapper, NOT httpx.Client)
.httpx_client -> httpx.Client <- the real socket owner
Closing it releases pooled sockets so long-running provider lifecycles
do not accumulate unreclaimed host-side resources (#2872).
Resolution is most-specific-first with graceful degradation: if a future
SDK adds a top-level ``Sandbox.close()`` it is picked up automatically
without changing this code. Idempotent, thread-safe, and non-fatal:
failures during teardown are logged and swallowed so provider/backend
cleanup is never blocked.
"""
with self._lock:
if self._closed:
return
self._closed = True
client = self._client
# Drop the reference under the lock for use-after-close safety: any
# later command on this instance fails loudly instead of reusing a
# half-closed client.
self._client = None
if client is None:
return
# Walk from the real httpx.Client up to the top-level client, picking the
# first object that actually exposes close().
wrapper = getattr(client, "_client_wrapper", None)
fern_http = getattr(wrapper, "httpx_client", None)
real_httpx = getattr(fern_http, "httpx_client", None)
target = next(
(c for c in (real_httpx, fern_http, client) if c is not None and hasattr(c, "close")),
None,
)
if target is None:
logger.debug("AioSandbox %s: no closable client found, nothing to release", self.id)
return
try:
target.close()
except Exception as e:
logger.warning(f"Error closing AioSandbox client for {self.id}: {e}")
@property
def home_dir(self) -> str:
"""Get the home directory inside the sandbox."""
@@ -790,20 +790,14 @@ class AioSandboxProvider(SandboxProvider):
thread on its next turn without a cold-start. The container will only be
stopped when the replicas limit forces eviction or during shutdown.
The host-side HTTP client owned by the cached ``AioSandbox`` instance is
closed before the instance is dropped (#2872). The warm-pool entry only
stores ``SandboxInfo``, so a fresh ``AioSandbox`` (and a fresh client)
is constructed if the container is later reclaimed.
Args:
sandbox_id: The ID of the sandbox to release.
"""
info = None
sandbox = None
thread_ids_to_remove: list[str] = []
with self._lock:
sandbox = self._sandboxes.pop(sandbox_id, None)
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
@@ -813,15 +807,6 @@ class AioSandboxProvider(SandboxProvider):
if info and sandbox_id not in self._warm_pool:
self._warm_pool[sandbox_id] = (info, time.time())
if sandbox is not None:
# Defense-in-depth: close() already swallows its own errors; this
# guard only protects against a future close() that misbehaves, so
# host-side client cleanup can never block parking in the warm pool.
try:
sandbox.close()
except Exception as e:
logger.warning(f"Error closing sandbox {sandbox_id} during release: {e}")
logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")
def destroy(self, sandbox_id: str) -> None:
@@ -830,19 +815,14 @@ class AioSandboxProvider(SandboxProvider):
Unlike release(), this actually stops the container. Use this for
explicit cleanup, capacity-driven eviction, or shutdown.
The host-side HTTP client owned by the cached ``AioSandbox`` instance is
closed alongside backend/container destruction so no client/socket
resources leak (#2872).
Args:
sandbox_id: The ID of the sandbox to destroy.
"""
info = None
sandbox = None
thread_ids_to_remove: list[str] = []
with self._lock:
sandbox = self._sandboxes.pop(sandbox_id, None)
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
@@ -854,15 +834,6 @@ class AioSandboxProvider(SandboxProvider):
else:
self._warm_pool.pop(sandbox_id, None)
if sandbox is not None:
# Defense-in-depth: close() already swallows its own errors; this
# guard only protects against a future close() that misbehaves, so
# host-side client cleanup can never block container destruction.
try:
sandbox.close()
except Exception as e:
logger.warning(f"Error closing sandbox {sandbox_id} during destroy: {e}")
if info:
self._backend.destroy(info)
logger.info(f"Destroyed sandbox {sandbox_id}")
@@ -11,85 +11,12 @@ from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
DEFAULT_BACKEND = "auto"
DEFAULT_REGION = "wt-wt"
DEFAULT_SAFESEARCH = "moderate"
DEFAULT_WIKIPEDIA_REGION = "us-en"
WIKIPEDIA_BACKENDS = {"auto", "all", "wikipedia"}
WIKIPEDIA_LANGUAGE_ALIASES = {
"jp": "ja",
"kr": "ko",
"tzh": "zh",
"wt": "en",
}
def _normalize_backend(backend: str | list[str] | tuple[str, ...] | None) -> str:
if backend is None:
return DEFAULT_BACKEND
if isinstance(backend, (list, tuple)):
return ",".join(str(part).strip() for part in backend if str(part).strip()) or DEFAULT_BACKEND
return str(backend).strip() or DEFAULT_BACKEND
def _normalize_setting(value: str | None, default: str) -> str:
return str(value).strip() if value else default
def _backend_includes_wikipedia(backend: str | list[str] | tuple[str, ...] | None) -> bool:
backend = _normalize_backend(backend)
return any(part.strip().lower() in WIKIPEDIA_BACKENDS for part in backend.split(","))
def _contains_codepoint(query: str, ranges: tuple[tuple[int, int], ...]) -> bool:
return any(start <= ord(char) <= end for char in query for start, end in ranges)
def _infer_wikipedia_region(query: str) -> str:
"""Pick a valid Wikipedia language region when DDGS' worldwide region is used."""
if _contains_codepoint(query, ((0x3040, 0x30FF), (0x31F0, 0x31FF))):
return "jp-ja"
if _contains_codepoint(query, ((0xAC00, 0xD7AF), (0x1100, 0x11FF), (0x3130, 0x318F))):
return "kr-ko"
if _contains_codepoint(query, ((0x3400, 0x9FFF),)):
return "cn-zh"
if _contains_codepoint(query, ((0x0400, 0x04FF),)):
return "ru-ru"
if _contains_codepoint(query, ((0x0370, 0x03FF),)):
return "gr-el"
if _contains_codepoint(query, ((0x0590, 0x05FF),)):
return "il-he"
if _contains_codepoint(query, ((0x0600, 0x06FF),)):
return "xa-ar"
return DEFAULT_WIKIPEDIA_REGION
def _resolve_ddgs_region(query: str, region: str | None, backend: str | list[str] | tuple[str, ...] | None) -> str:
"""
DDGS' wikipedia engine treats the second part of region as a Wikipedia
subdomain. Its default worldwide region, wt-wt, becomes wt.wikipedia.org.
"""
normalized_region = _normalize_setting(region, DEFAULT_REGION).lower()
if not _backend_includes_wikipedia(backend):
return normalized_region
if normalized_region == DEFAULT_REGION:
return _infer_wikipedia_region(query)
if "-" not in normalized_region:
return DEFAULT_WIKIPEDIA_REGION
country, language = normalized_region.split("-", 1)
return f"{country}-{WIKIPEDIA_LANGUAGE_ALIASES.get(language, language)}"
def _search_text(
query: str,
max_results: int = 5,
region: str | None = DEFAULT_REGION,
safesearch: str | None = DEFAULT_SAFESEARCH,
backend: str | list[str] | tuple[str, ...] | None = DEFAULT_BACKEND,
region: str = "wt-wt",
safesearch: str = "moderate",
) -> list[dict]:
"""
Execute text search using DuckDuckGo.
@@ -99,7 +26,6 @@ def _search_text(
max_results: Maximum number of results
region: Search region
safesearch: Safe search level
backend: DDGS backend(s), e.g. "auto", "duckduckgo", or "duckduckgo,brave"
Returns:
List of search results
@@ -113,15 +39,11 @@ def _search_text(
ddgs = DDGS(timeout=30)
try:
backend = _normalize_backend(backend)
safesearch = _normalize_setting(safesearch, DEFAULT_SAFESEARCH)
effective_region = _resolve_ddgs_region(query, region, backend)
results = ddgs.text(
query,
region=effective_region,
region=region,
safesearch=safesearch,
max_results=max_results,
backend=backend,
)
return list(results) if results else []
@@ -142,23 +64,14 @@ def web_search_tool(
max_results: Maximum number of results to return. Default is 5.
"""
config = get_app_config().get_tool_config("web_search")
region = DEFAULT_REGION
safesearch = DEFAULT_SAFESEARCH
backend = DEFAULT_BACKEND
if config is not None:
# Override tool call defaults from config if set.
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
region = config.model_extra.get("region", region)
safesearch = config.model_extra.get("safesearch", safesearch)
backend = config.model_extra.get("backend", backend)
results = _search_text(
query=query,
max_results=max_results,
region=region,
safesearch=safesearch,
backend=backend,
)
if not results:
@@ -9,7 +9,7 @@ _api_key_warned = False
class JinaClient:
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10, proxy: str | None = None, trust_env: bool = True) -> str:
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
global _api_key_warned
headers = {
"Content-Type": "application/json",
@@ -23,10 +23,7 @@ class JinaClient:
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
data = {"url": url}
try:
client_kwargs: dict[str, object] = {"trust_env": trust_env}
if proxy:
client_kwargs["proxy"] = proxy
async with httpx.AsyncClient(**client_kwargs) as client:
async with httpx.AsyncClient() as client:
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
if response.status_code != 200:
@@ -9,38 +9,6 @@ from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor()
def _coerce_bool(value: object, default: bool) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return default
def _coerce_timeout(value: object, default: int) -> int:
if isinstance(value, bool):
return default
if isinstance(value, int):
return value
if isinstance(value, str):
try:
return int(value)
except ValueError:
return default
return default
def _coerce_proxy(value: object) -> str | None:
if not isinstance(value, str):
return None
proxy = value.strip()
return proxy or None
@tool("web_fetch", parse_docstring=True)
async def web_fetch_tool(url: str) -> str:
"""Fetch the contents of a web page at a given URL.
@@ -54,14 +22,10 @@ async def web_fetch_tool(url: str) -> str:
"""
jina_client = JinaClient()
timeout = 10
proxy = None
trust_env = True
config = get_app_config().get_tool_config("web_fetch")
if config is not None:
timeout = _coerce_timeout(config.model_extra.get("timeout"), timeout)
proxy = _coerce_proxy(config.model_extra.get("proxy"))
trust_env = _coerce_bool(config.model_extra.get("trust_env"), trust_env)
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout, proxy=proxy, trust_env=trust_env)
if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
@@ -7,7 +7,7 @@ from typing import Any, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
@@ -18,7 +18,6 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.reload_boundary import format_field_description
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.runtime_paths import existing_project_file
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
@@ -86,21 +85,10 @@ def apply_logging_level(name: str | None) -> None:
class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
log_level: str = Field(
default="info",
description=format_field_description(
"log_level",
field_doc="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected.",
),
)
log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected")
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
sandbox: SandboxConfig = Field(
description=format_field_description(
"sandbox",
field_doc="Sandbox provider configuration (local filesystem or Docker-based aio sandbox).",
),
)
sandbox: SandboxConfig = Field(description="Sandbox configuration")
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
@@ -119,49 +107,10 @@ class AppConfig(BaseModel):
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(
default_factory=DatabaseConfig,
description=format_field_description(
"database",
field_doc="Unified database backend for run/feedback metadata (memory, sqlite, or postgres).",
),
)
run_events: RunEventsConfig = Field(
default_factory=RunEventsConfig,
description=format_field_description(
"run_events",
field_doc="Run-event store backend (memory for dev, db for production queries, jsonl for lightweight single-node persistence).",
),
)
checkpointer: CheckpointerConfig | None = Field(
default=None,
description=format_field_description(
"checkpointer",
field_doc="LangGraph state-persistence checkpointer configuration.",
),
)
stream_bridge: StreamBridgeConfig | None = Field(
default=None,
description=format_field_description(
"stream_bridge",
field_doc="Stream bridge connecting agent workers to SSE endpoints.",
),
)
@field_validator("models", "tools", "tool_groups", mode="before")
@classmethod
def _coerce_null_list_sections(cls, value: Any) -> Any:
"""Treat a present-but-empty config section as an empty list.
Commenting out every entry under a top-level YAML key e.g. ``models:``
with only comments beneath it, exactly as shipped in
``config.example.yaml`` makes PyYAML parse the value as ``None``.
Without this, the documented ``cp config.example.yaml config.yaml``
first-run flow crashes with an opaque ``Input should be a valid list``
pydantic error. Coercing ``None`` to ``[]`` keeps that flow working and
matches the field's own ``default_factory=list``.
"""
return [] if value is None else value
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -224,11 +173,6 @@ class AppConfig(BaseModel):
config_data["extensions"] = extensions_config.model_dump()
result = cls.model_validate(config_data)
if not result.models:
logger.warning(
"No models are configured in %s. Add at least one entry under `models:` (see the commented examples in config.example.yaml) or run `make setup`.",
resolved_path,
)
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
cls._apply_singleton_configs(result, acp_agents)
return result
@@ -41,20 +41,6 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
_checkpointer_config = config
def ensure_config_loaded() -> None:
"""Lazily load app config when checkpointer config has not been initialized."""
from deerflow.config.app_config import _app_config, get_app_config
config = get_checkpointer_config()
if config is not None or _app_config is not None:
return
try:
get_app_config()
except FileNotFoundError:
pass
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
@@ -5,7 +5,7 @@ import os
from pathlib import Path
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.runtime_paths import existing_project_file
@@ -47,24 +47,6 @@ class McpServerConfig(BaseModel):
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
@model_validator(mode="before")
@classmethod
def _accept_transport_alias(cls, data: Any) -> Any:
"""Accept the MCP-spec ``transport`` field as an alias for ``type``.
The official MCP configuration schema uses ``transport`` to indicate
the transport mechanism (``stdio``/``sse``/``http``). Earlier versions
of this project only honored ``type``, which caused remote SSE/HTTP
servers configured with just ``transport`` to be incorrectly treated as
``stdio`` (the default). This validator normalizes the two so either
spelling works, with ``type`` taking precedence when both are provided.
"""
if isinstance(data, dict):
transport = data.get("transport")
if transport and not data.get("type"):
data = {**data, "type": transport}
return data
class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state."""
@@ -32,16 +32,6 @@ class ModelConfig(BaseModel):
description="Extra settings to be passed to the model when thinking is disabled",
)
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
stream_chunk_timeout: float | None = Field(
default=None,
description=(
"Maximum seconds to wait between successive streaming chunks before "
"langchain-openai raises StreamChunkTimeoutError. None means use the "
"factory default (240s for OpenAI-compatible clients). Tune higher for "
"reasoning models with long thinking pauses; lower for latency-sensitive "
"interactive endpoints. Has no effect on non-OpenAI-compatible providers."
),
)
thinking: dict | None = Field(
default_factory=lambda: None,
description=(
@@ -1,4 +1,3 @@
import hashlib
import os
import re
import shutil
@@ -11,8 +10,6 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
def _default_local_base_dir() -> Path:
@@ -34,23 +31,6 @@ def _validate_user_id(user_id: str) -> str:
return user_id
def make_safe_user_id(raw: str) -> str:
"""Normalize an external identity into the user-id charset (``[A-Za-z0-9_-]``).
IM channel ids (Feishu/Slack/Telegram) may contain characters that
:func:`_validate_user_id` rejects. Already-safe ids pass through unchanged;
lossy ones get a short digest suffix so two distinct inputs never share a
storage bucket.
"""
if not raw:
raise ValueError("user_id must be a non-empty string.")
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
if sanitized == raw:
return raw
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
return f"{sanitized}-{digest}"
def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style.
@@ -1,104 +0,0 @@
"""Single source of truth for the config hot-reload boundary.
Bytedance/deer-flow issue #3144: gateway request dependencies resolve
``AppConfig`` through ``get_app_config()`` on every request, so per-run
fields take effect on the next message without restarting the gateway.
The fields listed in this module are the **infrastructure** subset that
the gateway captures once at startup engines, singletons, IM clients,
the logging handler and that therefore require a process restart to
change at runtime.
The registry covers two kinds of entries:
- Top-level ``AppConfig`` fields (``database``, ``checkpointer``,
``run_events``, ``stream_bridge``, ``sandbox``, ``log_level``). For
these, :func:`format_field_description` produces the standardised
``"startup-only: ..."`` prefix that the matching Pydantic
``Field(description=...)`` carries, so the boundary surfaces in IDE
hover next to the field itself.
- Top-level ``config.yaml`` sections that are not part of the
``AppConfig`` schema (``channels``). These cannot be standardised at
the schema level, so the registry is their only canonical location.
Any future "needs restart" scanner operator tooling, lint hooks, doc
generators should drive off this registry rather than re-parsing
prose.
"""
from __future__ import annotations
from collections.abc import Iterator
#: The standardised prefix every restart-required field description starts
#: with. ``test_reload_boundary`` enforces both directions: registered
#: fields must use this prefix in the schema, and any schema field using
#: this prefix must be in the registry.
STARTUP_ONLY_PREFIX = "startup-only:"
#: Restart-required field paths mapped to the human-readable reason.
#:
#: The reason text is what surfaces in ``Field(description=...)``, so it
#: must explain *what* code captures the snapshot — not just that the
#: field is restart-required — so an operator changing the value knows
#: which subsystem to restart.
STARTUP_ONLY_FIELDS: dict[str, str] = {
"database": ("init_engine_from_config() runs once during langgraph_runtime() startup; the SQLAlchemy engine holds the connection pool and is not rebuilt on config.yaml edits."),
"checkpointer": ("make_checkpointer() binds the persistent checkpointer once at startup, including SQLite WAL / busy_timeout settings."),
"run_events": ("make_run_event_store() picks the memory- vs SQL-backed implementation at startup and is frozen onto app.state.run_events_config to stay paired with the underlying event store."),
"stream_bridge": ("make_stream_bridge() constructs the stream-bridge singleton once during startup."),
"sandbox": ("get_sandbox_provider() caches the provider singleton (``_default_sandbox_provider``); a different ``sandbox.use`` class path only takes effect on next process start."),
"log_level": (
"apply_logging_level() runs only during app.py startup; it sets the deerflow/app logger levels and may lower root handler thresholds so configured messages can propagate. A freshly reloaded AppConfig does not retrigger it."
),
# Not part of the AppConfig Pydantic schema — channel credentials are
# consumed directly by ``start_channel_service()`` once at lifespan
# startup and the live channel clients are not rebuilt on
# config.yaml edits.
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
}
def iter_startup_only_field_paths() -> Iterator[str]:
"""Yield every registered restart-required field path."""
return iter(STARTUP_ONLY_FIELDS)
def is_startup_only_field(field_path: str) -> bool:
"""Return ``True`` when *field_path* is registered as restart-required.
Accepts only top-level paths (``"database"``, ``"sandbox"`` etc.);
nested keys like ``"database.url"`` are not modelled here because the
boundary is per-section, not per-leaf.
"""
return field_path in STARTUP_ONLY_FIELDS
def format_field_description(field_path: str, *, field_doc: str | None = None) -> str:
"""Build the standardised description for a registered field.
Used inside ``AppConfig`` ``Field(description=...)`` so the hover
text in IDEs matches the registry and the drift tests can pin one
side against the other.
Args:
field_path: A registered top-level field path (e.g. ``"log_level"``).
field_doc: Optional human-facing description for the field itself
(allowed values, semantics, etc.). When supplied, it is
appended after the ``startup-only:`` marker block separated by
a blank line so IDE hover shows both the restart-required
reason *and* the field's normal documentation. Composition
keeps the marker as the leading token machine-readable tooling
pivots on while restoring the prose that ``Field(description=)``
used to carry before the registry took over.
Raises:
KeyError: when *field_path* is not registered. This is deliberate
silently returning a placeholder would let a typo bypass
the drift coverage.
"""
reason = STARTUP_ONLY_FIELDS[field_path]
header = f"{STARTUP_ONLY_PREFIX} {reason}"
if field_doc is None:
return header
return f"{header}\n\n{field_doc.strip()}"
@@ -1,10 +1,6 @@
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters."""
from .cache import (
get_cached_mcp_tools,
initialize_mcp_tools,
reset_mcp_tools_cache,
)
from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
from .client import build_server_params, build_servers_config
from .tools import get_mcp_tools
+2 -11
View File
@@ -143,20 +143,11 @@ def reset_mcp_tools_cache() -> None:
# Close persistent sessions they will be recreated by the next
# get_mcp_tools() call with the (possibly updated) connection config.
#
# close_all_sync() already picks the correct strategy per owning loop:
# * sessions owned by the *current* running loop are only *signalled*
# (their owner task runs __aexit__ once the loop regains control
# this is correct and leak-free, since the loop keeps the task alive),
# * sessions on other threads' loops are torn down deterministically,
# * idle/closed loops are handled or skipped.
# We deliberately do NOT try to synchronously wait for the current running
# loop to finish teardown here: that is a self-deadlock (the loop can only
# run the teardown after this synchronous call returns control to it).
try:
from deerflow.mcp.session_pool import get_session_pool
get_session_pool().close_all_sync()
pool = get_session_pool()
pool.close_all_sync()
except Exception:
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
@@ -8,27 +8,6 @@ This module provides a session pool that maintains persistent MCP sessions,
scoped by ``(server_name, scope_key)`` typically scope_key is the thread_id
so that consecutive tool calls share the same session and server-side state.
Sessions are evicted in LRU order when the pool reaches capacity.
Lifecycle model (owner task)
----------------------------
An MCP ``ClientSession`` is implemented on top of an ``anyio`` task group, and
anyio enforces that a cancel scope must be exited from the *same task* that
entered it. Calling ``cm.__aexit__`` from any task other than the one that ran
``cm.__aenter__`` raises::
RuntimeError: Attempted to exit cancel scope in a different task than it
was entered in
The sync-tool path (``make_sync_tool_wrapper``) drives each call through a fresh
``asyncio.run`` event loop, so a session entered while answering one call would
otherwise be exited while answering another from a different task and crash
(GitHub issue #3379).
To make this impossible, every pooled session is owned by a dedicated
``_run_session`` task. That task enters the context manager, hands the live
session back to the caller, and then *waits* on a close event. All shutdown
paths only ever **signal** that event; the owner task performs ``__aexit__``
itself, guaranteeing enter and exit always happen in the same task.
"""
from __future__ import annotations
@@ -48,81 +27,18 @@ class MCPSessionPool:
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
MAX_SESSIONS = 256
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session on a foreign loop
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
def __init__(self) -> None:
# Each entry: (session, owning_loop, owner_task, close_event).
self._entries: OrderedDict[
tuple[str, str],
tuple[
ClientSession,
asyncio.AbstractEventLoop,
asyncio.Task[Any],
asyncio.Event,
],
tuple[ClientSession, asyncio.AbstractEventLoop],
] = OrderedDict()
# In-flight creations, keyed by (server, scope). Lets concurrent callers
# on the same loop share a single creation instead of each spawning a
# duplicate session. Value: (loop, ready_future, owner_task, close_event).
self._inflight: dict[
tuple[str, str],
tuple[
asyncio.AbstractEventLoop,
asyncio.Future[ClientSession],
asyncio.Task[Any],
asyncio.Event,
],
] = {}
self._context_managers: dict[tuple[str, str], Any] = {}
# threading.Lock is not bound to any event loop, so it is safe to
# acquire from both async paths and sync/worker-thread paths.
self._lock = threading.Lock()
# ------------------------------------------------------------------
# Session owner task
# ------------------------------------------------------------------
async def _run_session(
self,
connection: dict[str, Any],
ready: asyncio.Future[ClientSession],
close_evt: asyncio.Event,
) -> None:
"""Own a single MCP session for its entire lifetime.
Enters the session context manager, initializes it, publishes the live
session via ``ready``, then blocks until ``close_evt`` is set. The
context manager is *always* exited from this task, satisfying anyio's
cancel-scope same-task requirement.
"""
from langchain_mcp_adapters.sessions import create_session
cm = create_session(connection)
try:
session = await cm.__aenter__()
except BaseException as e:
# Never entered the cancel scope, so there is nothing to exit.
if not ready.done():
ready.set_exception(e)
return
# The context manager is now entered. From here on __aexit__ MUST run in
# this task — on init failure, on cancellation, or on the close signal —
# to satisfy anyio's same-task cancel-scope requirement and to avoid
# leaking the session/subprocess.
try:
await session.initialize()
if not ready.done():
ready.set_result(session)
await close_evt.wait()
except BaseException as e:
if not ready.done():
ready.set_exception(e)
finally:
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session", exc_info=True)
async def get_session(
self,
server_name: str,
@@ -131,9 +47,9 @@ class MCPSessionPool:
) -> ClientSession:
"""Get or create a persistent MCP session.
If an existing session was created in a different (or closed) event
loop, it is evicted and replaced with a fresh one owned by a task on
the current loop.
If an existing session was created in a different event loop (e.g.
the sync-wrapper path), it is closed and replaced with a fresh one
in the current loop.
Args:
server_name: MCP server name.
@@ -147,118 +63,44 @@ class MCPSessionPool:
current_loop = asyncio.get_running_loop()
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
# Decide one of three outcomes atomically: return an existing session,
# join an in-flight creation, or become the creator for this key.
# Each item: (loop, owner_task, close_event, cancel). ``cancel`` is True
# for in-flight creations, whose owner may be blocked inside
# ``initialize()`` where close_evt cannot wake it — it must be cancelled.
evicted: list[tuple[asyncio.AbstractEventLoop, asyncio.Task[Any], asyncio.Event, bool]] = []
join: asyncio.Future[ClientSession] | None = None
ready: asyncio.Future[ClientSession] | None = None
close_evt: asyncio.Event | None = None
task: asyncio.Task[Any] | None = None
cms_to_close: list[tuple[tuple[str, str], Any]] = []
with self._lock:
if key in self._entries:
session, loop, ent_task, ent_close = self._entries[key]
if loop is current_loop and not loop.is_closed():
session, loop = self._entries[key]
if loop is current_loop:
self._entries.move_to_end(key)
return session
# Session belongs to a different/closed event loop evict it.
# Session belongs to a different event loop evict it.
cm = self._context_managers.pop(key, None)
self._entries.pop(key)
evicted.append((loop, ent_task, ent_close, False))
inflight = self._inflight.get(key)
if inflight is not None and inflight[0] is current_loop and not inflight[0].is_closed():
# Another caller on this loop is already creating the session;
# wait for the same result instead of building a duplicate.
join = inflight[1]
else:
if inflight is not None:
# Stale in-flight creation owned by a different/closed loop.
# Drop the record and tear its owner down; because that owner
# may be blocked inside initialize() (where close_evt cannot
# wake it), it must be cancelled. We then create a fresh
# session here.
self._inflight.pop(key)
evicted.append((inflight[0], inflight[2], inflight[3], True))
# Become the creator: publish an in-flight record before any
# await so concurrent callers join us instead of racing.
ready = current_loop.create_future()
close_evt = asyncio.Event()
task = current_loop.create_task(self._run_session(connection, ready, close_evt))
self._inflight[key] = (current_loop, ready, task, close_evt)
if cm is not None:
cms_to_close.append((key, cm))
# Evict LRU entries when at capacity.
while len(self._entries) >= self.MAX_SESSIONS:
oldest_key, (_, loop, ent_task, ent_close) = next(iter(self._entries.items()))
oldest_key = next(iter(self._entries))
cm = self._context_managers.pop(oldest_key, None)
self._entries.pop(oldest_key)
evicted.append((loop, ent_task, ent_close, False))
if cm is not None:
cms_to_close.append((oldest_key, cm))
# Phase 2: shut down evicted sessions/creations. Same-loop owners are
# awaited so they finish deterministically; foreign-loop owners are
# routed to their own loop. In every case the owner task — never this
# one — runs __aexit__. In-flight owners are cancelled (cancel=True) so a
# blocking initialize() cannot leave them hung.
for loop, ent_task, ent_close, cancel in evicted:
if loop is current_loop and not loop.is_closed():
await self._shutdown(ent_close, ent_task, cancel)
elif cancel:
await self._shutdown_entry(loop, ent_task, ent_close, cancel=True)
else:
self._signal_close(loop, ent_close)
# Phase 2b: a concurrent creation for this key is already in progress on
# this loop — share its result rather than create a duplicate session.
if join is not None:
return await asyncio.shield(join)
assert ready is not None and close_evt is not None and task is not None
# Phase 3: wait for our owner task to publish the initialized session.
try:
session = await asyncio.shield(ready)
except BaseException:
# Two distinct cases reach here:
#
# 1. The owner task failed (e.g. connect/initialize error) and
# reported it via ready.set_exception(). It is *already* in its
# finally block running cm.__aexit__ in its own task, so we must
# NOT cancel it — doing so would interrupt that cleanup. We only
# wait for it to finish unwinding.
# 2. This call itself was cancelled (CancelledError). Because of the
# shield, `ready` is still pending and the owner task is alive and
# blocked. We signal close and cancel it so it exits the cancel
# scope in its own task, then wait for it to finish.
#
# The session is never registered yet, so nobody else can close it;
# waiting here guarantees we never leak a session or owner task.
owner_already_failed = ready.done() and not ready.cancelled() and ready.exception() is not None
if not owner_already_failed:
close_evt.set()
task.cancel()
# Phase 2: async cleanup outside the lock so we never await while holding it.
for close_key, cm in cms_to_close:
try:
await asyncio.shield(task)
except BaseException:
logger.debug("Owner task ended during get_session unwind", exc_info=True)
with self._lock:
if self._inflight.get(key) == (current_loop, ready, task, close_evt):
self._inflight.pop(key)
raise
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
# Phase 4: promote the in-flight creation to a registered entry — but
# only if our in-flight record is still the live one. A concurrent
# close_* / close_all may have removed it while we were initializing; in
# that case we must NOT resurrect the session into _entries. Instead we
# own the teardown: signal our owner task and wait for it to run
# __aexit__ in its own task, then surface the cancellation.
from langchain_mcp_adapters.sessions import create_session
cm = create_session(connection)
session = await cm.__aenter__()
await session.initialize()
# Phase 3: register the new session under the lock.
with self._lock:
still_ours = self._inflight.get(key) == (current_loop, ready, task, close_evt)
if still_ours:
self._inflight.pop(key)
self._entries[key] = (session, current_loop, task, close_evt)
if not still_ours:
await self._shutdown(close_evt, task)
raise asyncio.CancelledError("MCP session pool was closed while the session was being created")
self._entries[key] = (session, current_loop)
self._context_managers[key] = cm
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
return session
@@ -266,169 +108,70 @@ class MCPSessionPool:
# Cleanup helpers
# ------------------------------------------------------------------
@staticmethod
def _signal_close(loop: asyncio.AbstractEventLoop, close_evt: asyncio.Event) -> None:
"""Ask an owner task to shut down without waiting.
``asyncio.Event.set`` is not thread-safe, so it is scheduled on the
owning loop. A closed loop means the owner task is already gone.
"""
if loop.is_closed():
return
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
"""Close a single context manager (must be called WITHOUT the lock)."""
try:
loop.call_soon_threadsafe(close_evt.set)
except RuntimeError:
# Loop was closed between the is_closed() check and now.
pass
async def _shutdown(
self,
close_evt: asyncio.Event,
task: asyncio.Task[Any],
cancel: bool = False,
) -> None:
"""Signal an owner task and wait for it to finish (runs on its loop).
``cancel=True`` is used for in-flight creations: the owner task may be
blocked inside ``initialize()`` where ``close_evt`` cannot wake it, so it
must be cancelled. Its ``finally`` block still runs ``__aexit__`` in its
own task, satisfying anyio's same-task cancel-scope requirement.
"""
close_evt.set()
if cancel:
task.cancel()
try:
await task
except (Exception, asyncio.CancelledError):
logger.debug("Owner task ended during shutdown", exc_info=True)
async def _shutdown_entry(
self,
loop: asyncio.AbstractEventLoop,
task: asyncio.Task[Any],
close_evt: asyncio.Event,
cancel: bool = False,
) -> None:
"""Shut down one entry, routing the close to its owning loop."""
if loop.is_closed():
return
current_loop = asyncio.get_running_loop()
if loop is current_loop:
await self._shutdown(close_evt, task, cancel)
elif loop.is_running():
future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop)
try:
await asyncio.wrap_future(future)
except Exception:
logger.warning("Error closing MCP session on owning loop", exc_info=True)
else:
# Owning loop exists but is neither the current loop nor running.
# We are inside an async context here, so run_until_complete() would
# raise "Cannot run the event loop while another loop is running";
# and the loop may belong to another thread, where driving it from
# here is unsafe. This branch is not expected in practice — a
# session's owning loop is either the long-lived gateway loop (which
# is running) or a short-lived asyncio.run loop (which is closed and
# caught above). Fall back to a best-effort thread-safe signal so the
# owner task tears down if/when its loop runs again.
logger.warning("Owning loop for MCP session is idle; signalling close best-effort. Session may leak until the loop runs again.")
self._signal_close(loop, close_evt)
if cancel:
try:
loop.call_soon_threadsafe(task.cancel)
except RuntimeError:
pass
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", key, exc_info=True)
async def close_scope(self, scope_key: str) -> None:
"""Close all sessions for a given scope (e.g. thread_id)."""
with self._lock:
keys = [k for k in self._entries if k[1] == scope_key]
entries = [(self._entries.pop(k)) for k in keys]
inflight_keys = [k for k in self._inflight if k[1] == scope_key]
inflight = [self._inflight.pop(k) for k in inflight_keys]
for _session, loop, task, close_evt in entries:
await self._shutdown_entry(loop, task, close_evt)
for loop, _ready, task, close_evt in inflight:
await self._shutdown_entry(loop, task, close_evt, cancel=True)
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_server(self, server_name: str) -> None:
"""Close all sessions for a given server."""
with self._lock:
keys = [k for k in self._entries if k[0] == server_name]
entries = [(self._entries.pop(k)) for k in keys]
inflight_keys = [k for k in self._inflight if k[0] == server_name]
inflight = [self._inflight.pop(k) for k in inflight_keys]
for _session, loop, task, close_evt in entries:
await self._shutdown_entry(loop, task, close_evt)
for loop, _ready, task, close_evt in inflight:
await self._shutdown_entry(loop, task, close_evt, cancel=True)
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_all(self) -> None:
"""Close every managed session."""
with self._lock:
entries = list(self._entries.values())
cms = list(self._context_managers.items())
self._context_managers.clear()
self._entries.clear()
inflight = list(self._inflight.values())
self._inflight.clear()
for _session, loop, task, close_evt in entries:
await self._shutdown_entry(loop, task, close_evt)
for loop, _ready, task, close_evt in inflight:
await self._shutdown_entry(loop, task, close_evt, cancel=True)
for key, cm in cms:
await self._close_cm(key, cm)
def close_all_sync(self) -> None:
"""Close all sessions on their owning event loops (synchronous).
"""Close all sessions using their owning event loops (synchronous).
Each session is closed by its owner task on the loop it was created in,
avoiding cross-loop and cross-task errors. Safe to call from any thread
without an active event loop.
Closing semantics differ by where the owning loop runs:
* Owning loop is idle, or running on another thread this call blocks
until teardown completes (or ``SESSION_CLOSE_TIMEOUT`` elapses).
* Owning loop is the one currently running on *this* thread we cannot
block on it without deadlocking, so teardown is only *signalled* here
and completes asynchronously once control returns to that loop. The
caller must therefore keep that loop running afterwards; if it stops
the loop immediately, the owner task's ``__aexit__`` may not run. When
a deterministic close is required from inside a running loop, ``await
close_all()`` instead.
Each session is closed on the loop it was created in, avoiding
cross-loop resource leaks. Safe to call from any thread without an
active event loop.
"""
with self._lock:
entries = list(self._entries.values())
entries = list(self._entries.items())
cms = dict(self._context_managers)
self._entries.clear()
inflight = list(self._inflight.values())
self._inflight.clear()
self._context_managers.clear()
# Entries are initialized (gentle close_evt path). In-flight creations
# may be blocked mid-init, so they are cancelled to unblock teardown.
owners = [(loop, task, close_evt, False) for _s, loop, task, close_evt in entries]
owners += [(loop, task, close_evt, True) for loop, _r, task, close_evt in inflight]
try:
current_running_loop = asyncio.get_running_loop()
except RuntimeError:
current_running_loop = None
for loop, task, close_evt, cancel in owners:
if loop.is_closed():
for key, (_, loop) in entries:
cm = cms.get(key)
if cm is None or loop.is_closed():
continue
try:
if loop is current_running_loop:
# We are executing inside this loop's thread, so synchronously
# waiting on run_coroutine_threadsafe(...).result() would
# deadlock until timeout. Signal the owner task directly and
# let it finish once this synchronous call returns control to
# the running loop.
close_evt.set()
if cancel:
task.cancel()
elif loop.is_running():
# Schedule the shutdown on the owning loop from this thread.
future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop)
if loop.is_running():
# Schedule on the owning loop from this (different) thread.
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
else:
loop.run_until_complete(self._shutdown(close_evt, task, cancel))
loop.run_until_complete(cm.__aexit__(None, None, None))
except Exception:
logger.debug("Error closing MCP session during sync close", exc_info=True)
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
# ------------------------------------------------------------------
+1 -10
View File
@@ -3,7 +3,6 @@
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any
from langchain_core.tools import BaseTool, StructuredTool
@@ -138,15 +137,7 @@ def _make_session_pool_tool(
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
async def base_handler(request: MCPToolCallRequest) -> Any:
# Preserve interceptor-injected headers for stdio MCP calls by
# forwarding them through MCP call meta.
call_kwargs: dict[str, Any] = {}
if request.headers:
if isinstance(request.headers, Mapping):
call_kwargs["meta"] = {"headers": dict(request.headers)}
else:
logger.warning("Ignoring MCP interceptor headers with unsupported type: %s", type(request.headers).__name__)
return await session.call_tool(request.name, request.args, **call_kwargs)
return await session.call_tool(request.name, request.args)
handler = base_handler
for interceptor in reversed(tool_interceptors):
@@ -47,38 +47,6 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True
# Default chunk-gap budget for OpenAI-compatible streaming responses.
#
# langchain-openai raises ``StreamChunkTimeoutError`` after this many seconds
# without receiving a chunk. Its own default is 60s, which is too aggressive for
# reasoning models (DeepSeek-R1, Doubao-thinking, GPT-5) whose first chunk can
# legitimately take 90~150s. We default to 240s so the streaming layer rarely
# trips on long thinking pauses; the LLMErrorHandlingMiddleware still retries
# (budget=2) if a real stall happens. Users can override per-model in config.yaml.
_DEFAULT_STREAM_CHUNK_TIMEOUT_SECONDS: float = 240.0
def _apply_stream_chunk_timeout_default(model_use_path: str, model_settings_from_config: dict) -> None:
"""Inject a generous ``stream_chunk_timeout`` for OpenAI-compatible clients.
The ``stream_chunk_timeout`` kwarg is specific to ``langchain_openai:ChatOpenAI``
and is rejected by other providers' constructors as an unexpected keyword
argument. Behaviour:
* OpenAI-compatible path: an explicit value in ``config.yaml`` is preserved.
An explicit ``null`` is dropped upstream by ``model_dump(exclude_none=True)``
and therefore treated as "unset", so the default is injected.
* Non-OpenAI path: drop the key so it is never forwarded to an incompatible
constructor (which would raise ``TypeError: unexpected keyword argument``).
"""
if model_use_path != "langchain_openai:ChatOpenAI":
model_settings_from_config.pop("stream_chunk_timeout", None)
return
if "stream_chunk_timeout" in model_settings_from_config:
return
model_settings_from_config["stream_chunk_timeout"] = _DEFAULT_STREAM_CHUNK_TIMEOUT_SECONDS
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, attach_tracing: bool = True, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
@@ -160,7 +128,6 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
model_settings_from_config.pop("reasoning_effort", None)
_enable_stream_usage_by_default(model_config.use, model_settings_from_config)
_apply_stream_chunk_timeout_default(model_config.use, model_settings_from_config)
# For Codex Responses API models: map thinking mode to reasoning_effort
from deerflow.models.openai_codex_provider import CodexChatModel
@@ -114,27 +114,8 @@ class PatchedChatMiniMax(ChatOpenAI):
}
else:
payload["extra_body"] = {"reasoning_split": True}
self._strip_user_message_names(payload)
return payload
@staticmethod
def _strip_user_message_names(payload: dict) -> None:
"""Drop the per-message ``name`` field from user-role messages.
DeerFlow middlewares tag user messages with internal provenance names
(``user-input``, ``summary``, ``loop_warning``, ...). ``langchain_openai``
serializes those into the OpenAI-compatible request, but MiniMax requires
every user-role ``name`` to be identical and otherwise rejects the request
with ``invalid params, user name must be consistent (2013)``. MiniMax does
not use the per-message author name, so strip it.
"""
messages = payload.get("messages")
if not isinstance(messages, list):
return
for message in messages:
if isinstance(message, dict) and message.get("role") == "user":
message.pop("name", None)
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
@@ -1,175 +0,0 @@
"""Patched ChatOpenAI adapter for StepFun reasoning models.
StepFun returns ``reasoning`` (or ``reasoning_content`` with deepseek-style) in
both streaming deltas and non-streaming responses. Standard ``ChatOpenAI``
ignores these non-standard fields, so reasoning content is silently dropped.
This adapter captures reasoning from all response paths and replays it on
historical assistant messages for multi-turn tool-call conversations.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import (
restore_assistant_payloads,
restore_reasoning_content,
)
_MISSING = object()
def _extract_reasoning(value: Any) -> str | object:
"""Return reasoning content from a dict/Pydantic object.
StepFun may return reasoning via ``reasoning`` (default) or
``reasoning_content`` (deepseek-style). Check both fields.
"""
if isinstance(value, Mapping):
# Check reasoning_content first (deepseek-style), then reasoning (default)
for field in ("reasoning_content", "reasoning"):
if field in value and value[field] is not None:
return value[field]
return _MISSING
# Pydantic / SDK object attributes
for field in ("reasoning_content", "reasoning"):
attr = getattr(value, field, _MISSING)
if attr is not _MISSING and attr is not None:
return attr
# Some SDK versions store extra fields in model_extra
model_extra = getattr(value, "model_extra", None)
if isinstance(model_extra, Mapping):
for field in ("reasoning_content", "reasoning"):
if field in model_extra and model_extra[field] is not None:
return model_extra[field]
return _MISSING
def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk:
"""Return a copy of *message* with reasoning_content stored in additional_kwargs."""
additional_kwargs = dict(message.additional_kwargs)
if additional_kwargs.get("reasoning_content") != reasoning:
additional_kwargs["reasoning_content"] = reasoning
return message.model_copy(update={"additional_kwargs": additional_kwargs})
def _get_typed_choice_message(response: Any, index: int) -> Any:
"""Extract the SDK-typed choice message at *index*, if available."""
choices = getattr(response, "choices", None)
if choices is None:
return None
try:
return choices[index].message
except (AttributeError, IndexError, TypeError):
return None
class PatchedChatStepFun(ChatOpenAI):
"""ChatOpenAI with full reasoning support for StepFun models.
Captures ``reasoning`` / ``reasoning_content`` from both streaming and
non-streaming responses and replays it on historical assistant messages in
multi-turn tool-call conversations.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_secrets(self) -> dict[str, str]:
return {"api_key": "STEPFUN_API_KEY", "openai_api_key": "STEPFUN_API_KEY"}
# --- Request payload replay ---
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
"""Restore ``reasoning_content`` on historical assistant messages."""
original_messages = self._convert_input(input_).to_messages()
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
restore_assistant_payloads(
payload.get("messages", []),
original_messages,
restore_reasoning_content,
)
return payload
# --- Streaming reasoning capture ---
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: type,
base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
"""Capture ``reasoning`` / ``reasoning_content`` from streaming deltas."""
generation_chunk = super()._convert_chunk_to_generation_chunk(
chunk,
default_chunk_class,
base_generation_info,
)
if generation_chunk is None:
return None
choices = chunk.get("choices", [])
if choices:
delta = choices[0].get("delta") or {}
reasoning = _extract_reasoning(delta)
if reasoning is not _MISSING and isinstance(generation_chunk.message, AIMessageChunk):
generation_chunk = ChatGenerationChunk(
message=_with_reasoning_content(generation_chunk.message, reasoning),
generation_info=generation_chunk.generation_info,
)
return generation_chunk
# --- Non-streaming reasoning capture ---
def _create_chat_result(
self,
response: dict | Any,
generation_info: dict | None = None,
) -> ChatResult:
"""Extract ``reasoning`` / ``reasoning_content`` from non-streaming responses."""
result = super()._create_chat_result(response, generation_info)
response_dict = response if isinstance(response, dict) else response.model_dump()
choices = response_dict.get("choices", [])
patched_generations: list[ChatGeneration] | None = None
for index, generation in enumerate(result.generations):
choice = choices[index] if index < len(choices) else {}
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
reasoning = _extract_reasoning(choice_message)
if reasoning is _MISSING and not isinstance(response, dict):
reasoning = _extract_reasoning(_get_typed_choice_message(response, index))
message = generation.message
if reasoning is not _MISSING and isinstance(message, AIMessage):
if patched_generations is None:
patched_generations = list(result.generations)
patched_generations[index] = ChatGeneration(
message=_with_reasoning_content(message, reasoning),
generation_info=generation.generation_info,
)
return ChatResult(
generations=patched_generations or result.generations,
llm_output=result.llm_output,
)
@@ -47,41 +47,6 @@ def _prepare_database_sqlite_checkpointer_path(db_config) -> str:
return conn_str
def _build_postgres_pool(conn_string: str):
"""Build an AsyncConnectionPool with TCP keepalive and connection checking."""
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
return AsyncConnectionPool(
conn_string,
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 6,
},
check=AsyncConnectionPool.check_connection,
)
def _ensure_postgres_imports():
"""Import and return (AsyncPostgresSaver, AsyncConnectionPool), raising ImportError on failure."""
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
try:
from psycopg_pool import AsyncConnectionPool
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
return AsyncPostgresSaver, AsyncConnectionPool
# ---------------------------------------------------------------------------
# Async factory
# ---------------------------------------------------------------------------
@@ -109,13 +74,15 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
AsyncPostgresSaver, _ = _ensure_postgres_imports()
pool = _build_postgres_pool(config.connection_string)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
await saver.setup()
yield saver
return
@@ -150,13 +117,15 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
return
if db_config.backend == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
AsyncPostgresSaver, _ = _ensure_postgres_imports()
pool = _build_postgres_pool(db_config.postgres_url)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
await saver.setup()
yield saver
return
@@ -21,13 +21,12 @@ from __future__ import annotations
import contextlib
import logging
import threading
from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig, ensure_config_loaded
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -101,7 +100,6 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
_checkpointer_lock = threading.Lock()
def get_checkpointer() -> Checkpointer:
@@ -118,29 +116,34 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
# Config loading can reset both persistence singletons. Keep it outside
# this provider lock to avoid cross-provider lock-order inversion.
ensure_config_loaded()
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
with _checkpointer_lock:
if _checkpointer is not None:
return _checkpointer
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
checkpointer_ctx = _sync_checkpointer_cm(config)
checkpointer = checkpointer_ctx.__enter__()
_checkpointer_ctx = checkpointer_ctx
_checkpointer = checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
@@ -152,14 +155,13 @@ def reset_checkpointer() -> None:
Useful in tests or after a configuration change.
"""
global _checkpointer, _checkpointer_ctx
with _checkpointer_lock:
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
# ---------------------------------------------------------------------------
@@ -86,8 +86,6 @@ class RunJournal(BaseCallbackHandler):
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
self._msg_count = 0
self._had_llm_error_fallback = False
self._llm_error_fallback_message: str | None = None
# Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
@@ -258,18 +256,6 @@ class RunJournal(BaseCallbackHandler):
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
additional_kwargs = getattr(message, "additional_kwargs", None) or {}
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
self._had_llm_error_fallback = True
detail = additional_kwargs.get("error_detail")
reason = additional_kwargs.get("error_reason")
fallback_text = self._message_text(message).strip()
if isinstance(detail, str) and detail.strip():
self._llm_error_fallback_message = detail.strip()
elif isinstance(reason, str) and reason.strip():
self._llm_error_fallback_message = reason.strip()
elif fallback_text:
self._llm_error_fallback_message = fallback_text[:2000]
# Resolve call index
call_index = self._llm_call_index
@@ -583,11 +569,3 @@ class RunJournal(BaseCallbackHandler):
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
}
@property
def had_llm_error_fallback(self) -> bool:
return self._had_llm_error_fallback
@property
def llm_error_fallback_message(self) -> str | None:
return self._llm_error_fallback_message
@@ -1,16 +1,39 @@
"""Run lifecycle management for LangGraph Platform API compatibility."""
from .domain import (
AssistantId,
CancelAction,
DisconnectMode,
EventSeq,
InvalidRunTransition,
MultitaskStrategy,
Run,
RunId,
RunScope,
RunStatus,
ThreadId,
UserId,
)
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import RunContext, run_agent
__all__ = [
"AssistantId",
"CancelAction",
"ConflictError",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskStrategy",
"Run",
"RunContext",
"RunId",
"RunManager",
"RunRecord",
"RunScope",
"RunStatus",
"ThreadId",
"UnsupportedStrategyError",
"UserId",
"run_agent",
]
@@ -0,0 +1,20 @@
"""Application-layer DTOs and services for run runtime use cases."""
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle, StoredRunEvent
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
from .services import RunsApplicationService
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"GetRunQuery",
"JoinRunStreamCommand",
"ListRunMessagesQuery",
"ListRunsQuery",
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"RunsApplicationService",
"StoredRunEvent",
]
@@ -0,0 +1,46 @@
"""Application command DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
from ..domain import AssistantId, CancelAction, DisconnectMode, MultitaskStrategy, RunId, RunScope, ThreadId
@dataclass(frozen=True)
class CreateRunCommand:
thread_id: ThreadId
assistant_id: AssistantId | None = None
input: dict[str, Any] | None = None
command: dict[str, Any] | None = None
metadata: dict[str, Any] = field(default_factory=dict)
config: dict[str, Any] = field(default_factory=dict)
context: dict[str, Any] = field(default_factory=dict)
scope: RunScope = RunScope.stateful
on_disconnect: DisconnectMode = DisconnectMode.cancel
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
stream_mode: list[str] | str | None = None
stream_subgraphs: bool = False
interrupt_before: list[str] | Literal["*"] | None = None
interrupt_after: list[str] | Literal["*"] | None = None
@dataclass(frozen=True)
class CancelRunCommand:
run_id: RunId
action: CancelAction = CancelAction.interrupt
wait: bool = False
@dataclass(frozen=True)
class JoinRunStreamCommand:
run_id: RunId
last_event_id: str | None = None
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"JoinRunStreamCommand",
]
@@ -0,0 +1,76 @@
"""Application output DTOs for run use cases."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
from ..domain import AssistantId, EventSeq, Run, RunId, RunStatus, ThreadId
@dataclass(frozen=True)
class RunSnapshot:
run_id: RunId
thread_id: ThreadId
assistant_id: AssistantId | None = None
status: RunStatus = RunStatus.pending
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
error: str | None = None
model_name: str | None = None
@classmethod
def from_run(cls, run: Run) -> RunSnapshot:
return cls(
run_id=run.run_id,
thread_id=run.thread_id,
assistant_id=run.assistant_id,
status=run.status,
metadata=dict(run.metadata),
kwargs=dict(run.kwargs),
created_at=run.created_at,
updated_at=run.updated_at,
error=run.error,
model_name=run.model_name,
)
@dataclass(frozen=True)
class RunMessageView:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class StoredRunEvent:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
category: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class RunStreamHandle:
run_id: RunId
thread_id: ThreadId
events: AsyncIterator[Any]
__all__ = [
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"StoredRunEvent",
]
@@ -0,0 +1,37 @@
"""Application query DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..domain import RunId, ThreadId, UserId
@dataclass(frozen=True)
class GetRunQuery:
run_id: RunId
thread_id: ThreadId | None = None
user_id: UserId | None = None
@dataclass(frozen=True)
class ListRunsQuery:
thread_id: ThreadId
user_id: UserId | None = None
limit: int = 100
@dataclass(frozen=True)
class ListRunMessagesQuery:
thread_id: ThreadId
run_id: RunId
limit: int = 50
before_seq: int | None = None
after_seq: int | None = None
__all__ = [
"GetRunQuery",
"ListRunMessagesQuery",
"ListRunsQuery",
]
@@ -0,0 +1,74 @@
"""Application service skeleton for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..execution import RunExecutionScheduler, RunSupervisor
from ..repositories import RunEventLog, RunRepository
from ..streams import RunStreamBroker
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
@dataclass
class RunsApplicationService:
"""Use-case orchestration boundary for run runtime operations.
PR1 only introduces the boundary and dependency shape. Existing Gateway
handlers continue to call the legacy service functions until later PRs move
behavior into this class.
"""
run_repository: RunRepository
run_event_log: RunEventLog
stream_broker: RunStreamBroker
scheduler: RunExecutionScheduler
supervisor: RunSupervisor
async def create_background(self, command: CreateRunCommand) -> RunSnapshot:
# PR1 defines the application boundary; later PRs move Gateway runtime
# behavior behind this method.
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_stream(self, command: CreateRunCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_wait(self, command: CreateRunCommand) -> RunSnapshot:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def join_stream(self, command: JoinRunStreamCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def cancel(self, command: CancelRunCommand) -> bool:
return await self.supervisor.cancel(command.run_id, action=command.action)
async def get_run(self, query: GetRunQuery) -> RunSnapshot | None:
run = await self.run_repository.get(query.run_id, user_id=query.user_id)
if run is None:
return None
if query.thread_id is not None and run.thread_id != query.thread_id:
return None
return RunSnapshot.from_run(run)
async def list_runs(self, query: ListRunsQuery) -> list[RunSnapshot]:
return await self.run_repository.list_by_thread(
query.thread_id,
user_id=query.user_id,
limit=query.limit,
)
async def list_run_messages(self, query: ListRunMessagesQuery) -> list[RunMessageView]:
return await self.run_event_log.list_messages_by_run(
query.thread_id,
query.run_id,
limit=query.limit,
before_seq=query.before_seq,
after_seq=query.after_seq,
)
__all__ = [
"RunsApplicationService",
]
@@ -0,0 +1,33 @@
"""Run runtime domain model."""
from .errors import InvalidRunTransition, RunDomainError
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, UserId
from .model import Run
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
__all__ = [
"AssistantId",
"CancelAction",
"CancelPolicy",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskDecision",
"MultitaskPolicy",
"MultitaskStrategy",
"Run",
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunDomainError",
"RunEvent",
"RunFailed",
"RunId",
"RunScope",
"RunStarted",
"RunStatus",
"ThreadId",
"UserId",
]
@@ -0,0 +1,24 @@
"""Domain-level errors for run lifecycle operations."""
from __future__ import annotations
from .value_objects import RunStatus
class RunDomainError(Exception):
"""Base class for run runtime domain errors."""
class InvalidRunTransition(RunDomainError):
"""Raised when a run status transition violates lifecycle rules."""
def __init__(self, current: RunStatus, target: RunStatus) -> None:
super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
self.current = current
self.target = target
__all__ = [
"InvalidRunTransition",
"RunDomainError",
]
@@ -0,0 +1,64 @@
"""Domain events emitted by the run aggregate."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .identifiers import AssistantId, RunId, ThreadId
from .value_objects import CancelAction, RunStatus
@dataclass(frozen=True)
class RunCreated:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
assistant_id: AssistantId | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class RunStarted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunCompleted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunFailed:
run_id: RunId
thread_id: ThreadId
status: RunStatus
occurred_at: str = field(default_factory=now_iso)
error: str | None = None
@dataclass(frozen=True)
class RunCancelled:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
action: CancelAction = CancelAction.interrupt
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
__all__ = [
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunEvent",
"RunFailed",
"RunStarted",
]
@@ -0,0 +1,27 @@
"""Lightweight identifiers for the run runtime domain."""
from __future__ import annotations
from typing import NewType
RunId = NewType("RunId", str)
ThreadId = NewType("ThreadId", str)
AssistantId = NewType("AssistantId", str)
UserId = NewType("UserId", str)
def require_non_empty(value: str, *, field_name: str) -> str:
"""Return a stripped identifier value, rejecting empty identifiers."""
normalized = value.strip()
if not normalized:
raise ValueError(f"{field_name} must not be empty")
return normalized
__all__ = [
"AssistantId",
"RunId",
"ThreadId",
"UserId",
"require_non_empty",
]
@@ -0,0 +1,193 @@
"""Run aggregate root and lifecycle invariants."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .errors import InvalidRunTransition
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
# Keep lifecycle transitions explicit so later application code cannot invent
# ad hoc status moves outside the aggregate.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
RunStatus.pending: frozenset(
{
RunStatus.running,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.running: frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.success: frozenset(),
RunStatus.error: frozenset(),
RunStatus.timeout: frozenset(),
RunStatus.interrupted: frozenset(),
}
@dataclass
class Run:
"""Run aggregate root.
The aggregate owns lifecycle invariants only. Infrastructure concerns such
as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
this model.
"""
run_id: RunId
thread_id: ThreadId
status: RunStatus
assistant_id: AssistantId | None = None
scope: RunScope = RunScope.stateful
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = field(default_factory=now_iso)
updated_at: str = field(default_factory=now_iso)
error: str | None = None
model_name: str | None = None
_pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)
def __post_init__(self) -> None:
self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
if self.assistant_id is not None:
self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))
@classmethod
def create(
cls,
*,
run_id: RunId,
thread_id: ThreadId,
assistant_id: AssistantId | None = None,
scope: RunScope = RunScope.stateful,
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
metadata: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
model_name: str | None = None,
created_at: str | None = None,
) -> Run:
timestamp = created_at or now_iso()
run = cls(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
scope=scope,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=timestamp,
updated_at=timestamp,
model_name=model_name,
)
run._record_event(
RunCreated(
run_id=run.run_id,
thread_id=run.thread_id,
occurred_at=timestamp,
assistant_id=run.assistant_id,
metadata=dict(run.metadata),
)
)
return run
@property
def is_terminal(self) -> bool:
return not _ALLOWED_TRANSITIONS[self.status]
def pull_events(self) -> tuple[RunEvent, ...]:
# Domain events are drained by the application layer after the aggregate
# has accepted a state change.
events = tuple(self._pending_events)
self._pending_events.clear()
return events
def mark_started(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.running, at=at)
def mark_completed(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.success, at=at)
def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.error, error=error, at=at)
def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.timeout, error=error, at=at)
def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
self._transition_to(RunStatus.interrupted, action=action, at=at)
def _transition_to(
self,
target: RunStatus,
*,
error: str | None = None,
action: CancelAction = CancelAction.interrupt,
at: str | None = None,
) -> None:
if target == self.status:
return
if target not in _ALLOWED_TRANSITIONS[self.status]:
raise InvalidRunTransition(self.status, target)
timestamp = at or now_iso()
self.status = target
self.updated_at = timestamp
if error is not None:
self.error = error
self._record_event(self._event_for_transition(target, timestamp, error=error, action=action))
def _event_for_transition(
self,
target: RunStatus,
occurred_at: str,
*,
error: str | None,
action: CancelAction,
) -> RunEvent:
# Keep event construction next to the transition rules so a new status
# cannot be added without an explicit durable event shape.
if target == RunStatus.running:
return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target == RunStatus.success:
return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target in (RunStatus.error, RunStatus.timeout):
return RunFailed(
run_id=self.run_id,
thread_id=self.thread_id,
status=target,
occurred_at=occurred_at,
error=error,
)
if target == RunStatus.interrupted:
return RunCancelled(
run_id=self.run_id,
thread_id=self.thread_id,
occurred_at=occurred_at,
action=action,
)
raise InvalidRunTransition(self.status, target)
def _record_event(self, event: RunEvent) -> None:
self._pending_events.append(event)
__all__ = [
"Run",
"RunStatus",
]
@@ -0,0 +1,50 @@
"""Domain policies for run concurrency and cancellation."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from enum import StrEnum
from .model import Run
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
class MultitaskDecision(StrEnum):
"""Application-level decision produced by a multitask policy."""
allow = "allow"
reject = "reject"
cancel_existing = "cancel_existing"
enqueue = "enqueue"
@dataclass(frozen=True)
class MultitaskPolicy:
strategy: MultitaskStrategy = MultitaskStrategy.reject
def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
inflight = [run for run in active_runs if run.status in (RunStatus.pending, RunStatus.running)]
if not inflight:
return MultitaskDecision.allow
if self.strategy == MultitaskStrategy.reject:
return MultitaskDecision.reject
if self.strategy in (MultitaskStrategy.interrupt, MultitaskStrategy.rollback):
return MultitaskDecision.cancel_existing
return MultitaskDecision.enqueue
@dataclass(frozen=True)
class CancelPolicy:
action: CancelAction = CancelAction.interrupt
@property
def rolls_back_checkpoint(self) -> bool:
return self.action == CancelAction.rollback
__all__ = [
"CancelPolicy",
"MultitaskDecision",
"MultitaskPolicy",
]
@@ -0,0 +1,88 @@
"""Domain value objects for run lifecycle semantics."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
class RunScope(StrEnum):
"""Conversation scope for a run."""
stateful = "stateful"
stateless = "stateless"
temporary_thread = "temporary_thread"
class MultitaskStrategy(StrEnum):
"""Concurrency strategy for a new run on a thread."""
reject = "reject"
interrupt = "interrupt"
rollback = "rollback"
enqueue = "enqueue"
class CancelAction(StrEnum):
"""Cancellation action requested by an API or supervisor."""
interrupt = "interrupt"
rollback = "rollback"
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
)
def is_terminal_status(status: RunStatus) -> bool:
return status in TERMINAL_RUN_STATUSES
@dataclass(frozen=True, order=True)
class EventSeq:
"""Thread-local event sequence number."""
value: int
def __post_init__(self) -> None:
if self.value < 0:
raise ValueError("EventSeq must be non-negative")
def next(self) -> EventSeq:
return EventSeq(self.value + 1)
__all__ = [
"CancelAction",
"DisconnectMode",
"EventSeq",
"MultitaskStrategy",
"RunScope",
"RunStatus",
"TERMINAL_RUN_STATUSES",
"is_terminal_status",
]
@@ -0,0 +1,12 @@
"""Execution contracts for run lifecycle orchestration."""
from .executor import RunExecutor
from .scheduler import RunExecutionHandle, RunExecutionScheduler
from .supervisor import RunSupervisor
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
"RunExecutor",
"RunSupervisor",
]
@@ -0,0 +1,19 @@
"""Run executor contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import Run
class RunExecutor(Protocol):
"""Executes one run against the underlying agent or graph runtime."""
async def execute(self, run: Run) -> None:
pass
__all__ = [
"RunExecutor",
]
@@ -0,0 +1,26 @@
"""Run execution scheduler contract."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunExecutionHandle:
run_id: RunId
class RunExecutionScheduler(Protocol):
"""Starts background execution for an accepted run."""
async def start(self, run_id: RunId) -> RunExecutionHandle:
pass
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
]
@@ -0,0 +1,19 @@
"""Run execution supervision contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import CancelAction, RunId
class RunSupervisor(Protocol):
"""Controls lifecycle operations for already scheduled runs."""
async def cancel(self, run_id: RunId, *, action: CancelAction = CancelAction.interrupt) -> bool:
pass
__all__ = [
"RunSupervisor",
]
@@ -645,98 +645,6 @@ class RunManager:
self._runs.pop(run_id, None)
logger.debug("Run record %s cleaned up", run_id)
async def shutdown(self, *, timeout: float = 5.0) -> None:
"""Cancel and bounded-await all in-flight runs on process shutdown.
Chat runs execute in fire-and-forget background ``asyncio`` tasks that
write checkpoints through a shared checkpointer. On shutdown the
checkpointer's resources (e.g. the postgres connection pool owned by the
gateway's ``AsyncExitStack``) are torn down; if a run task is still
mid-graph at that point, langgraph's
``AsyncPregelLoop._checkpointer_put_after_previous`` runs its
``finally: await checkpointer.aput(...)`` against the closed pool. Because
that put runs in a langgraph-internal task (not on ``run_agent``'s call
stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the
worker and surfaces as an unhandled exception during ``asyncio.run()``
shutdown (bytedance/deer-flow issue #3373).
Draining in-flight runs *before* the checkpointer is closed lets each
run that settles within ``timeout`` flush its final checkpoint while
resources are still open. Only runs that do **not** settle on their own
are marked ``interrupted`` a run that completes (e.g. ``success``)
during the drain keeps its real terminal status instead of being
blanket-overwritten. The whole drain, including the trailing status
persistence, is bounded by ``timeout`` so a run stuck in cleanup (or a
slow store under DB pressure) cannot hang worker shutdown the
precondition for the signal-reentrancy deadlock guarded by
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active
after ``timeout`` are logged and may still race teardown.
"""
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
async with self._lock:
inflight = [record for record in self._runs.values() if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()]
for record in inflight:
record.abort_action = "interrupt"
record.abort_event.set()
record.task.cancel() # type: ignore[union-attr] # filtered above
# Status is decided AFTER the drain (below), not here: a run that
# completes on its own during the drain must keep its real status.
if not inflight:
return
tasks = [record.task for record in inflight]
_, pending = await asyncio.wait(tasks, timeout=timeout)
# Only mark/persist ``interrupted`` for runs that did not settle on their
# own (still pending after the timeout, or ended cancelled). A run that
# finished normally during the drain keeps the status it set for itself.
to_persist: list[RunRecord] = []
async with self._lock:
for record in inflight:
task = record.task
if task not in pending and not task.cancelled():
# Completed on its own — retrieve any surfaced exception so it
# is not reported as "never retrieved", and keep its status.
task.exception() # type: ignore[union-attr] # done & not cancelled
continue
if record.status in (RunStatus.pending, RunStatus.running):
record.status = RunStatus.interrupted
record.updated_at = _now_iso()
to_persist.append(record)
# Bound the trailing status persistence within the remaining budget so a
# slow store (``_call_store_with_retry`` can back off under DB pressure)
# cannot push shutdown past ``timeout``.
if to_persist:
remaining = deadline - loop.time()
if remaining <= 0:
logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist))
else:
try:
results = await asyncio.wait_for(
asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True),
timeout=remaining,
)
except TimeoutError:
logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist))
else:
# ``_persist_status`` is best-effort: it catches and logs its
# own failures, returning ``False``. Inspect the aggregate so a
# partial failure is surfaced at shutdown level (with the
# run_id) instead of being silently swallowed by the gather.
for record, result in zip(to_persist, results):
if isinstance(result, Exception):
logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result)
elif result is False:
logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id)
if pending:
logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending))
logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout)
class ConflictError(Exception):
"""Raised when multitask_strategy=reject and thread has inflight runs."""
@@ -0,0 +1,9 @@
"""Repository contracts for the run runtime application layer."""
from .run_event_log import RunEventLog
from .run_repository import RunRepository
__all__ = [
"RunEventLog",
"RunRepository",
]

Some files were not shown because too many files have changed in this diff Show More