Compare commits

..

2 Commits

Author SHA1 Message Date
rayhpeng 9ed83c84dc fix(runtime): use pass for protocol stubs 2026-06-01 15:31:46 +08:00
rayhpeng 30bb2d5149 refactor(runtime): add run DDD boundary skeleton 2026-06-01 09:22:32 +08:00
136 changed files with 2846 additions and 8035 deletions
-159
View File
@@ -1,159 +0,0 @@
name: 🐛 Bug report
description: Report something that isn't working so maintainers can reproduce and fix it.
title: "[bug] "
labels: ["bug"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to file a bug. A clear, reproducible report is the
single biggest factor in how fast it gets fixed.
Please fill in every required field — especially **reproduction steps** and **logs**.
- type: checkboxes
id: preflight
attributes:
label: Before you start
options:
- label: I searched [existing issues](https://github.com/bytedance/deer-flow/issues?q=is%3Aissue) and this is not a duplicate.
required: true
- label: I can reproduce this on the latest `main`.
required: false
- type: input
id: summary
attributes:
label: Problem summary
description: One sentence describing the bug.
placeholder: e.g. make dev fails to start the gateway service
validations:
required: true
- type: dropdown
id: areas
attributes:
label: Affected area(s)
description: Which part of DeerFlow does this touch? Select all that apply.
multiple: true
options:
- Frontend (UI / Next.js)
- Backend API (gateway / endpoints / SSE)
- Agents / LangGraph (graph, prompts, langgraph.json)
- Sandbox / Docker
- Skills
- MCP
- Config / setup (make, config.yaml, env)
- Docs
- Not sure
validations:
required: true
- type: textarea
id: actual
attributes:
label: What happened?
description: The actual behavior. Include the key error lines verbatim.
placeholder: When I do X, I expected Y but I got Z.
validations:
required: true
- type: textarea
id: expected
attributes:
label: Expected behavior
placeholder: What did you expect to happen instead?
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: Exact commands and sequence. Minimal steps that reliably reproduce the problem.
placeholder: |
1. make check
2. make install
3. make dev
4. ...
validations:
required: true
- type: textarea
id: logs
attributes:
label: Relevant logs
description: Paste key lines from logs (for example `logs/gateway.log`, `logs/frontend.log`). Redact secrets.
render: shell
validations:
required: true
- type: dropdown
id: run_mode
attributes:
label: How are you running DeerFlow?
options:
- Local (make dev)
- Docker (make docker-start)
- CI
- Other
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
options:
- macOS
- Linux
- Windows
- Other
validations:
required: true
- type: input
id: platform_details
attributes:
label: Platform details
description: Architecture and shell, if relevant.
placeholder: e.g. arm64, zsh
- type: input
id: python_version
attributes:
label: Python version
placeholder: e.g. Python 3.12.9
- type: input
id: node_version
attributes:
label: Node.js version
placeholder: e.g. v22.11.0
- type: input
id: pnpm_version
attributes:
label: pnpm version
placeholder: e.g. 10.26.2
- type: input
id: uv_version
attributes:
label: uv version
placeholder: e.g. 0.7.20
- type: textarea
id: git_info
attributes:
label: Git state
description: Output of `git branch --show-current` and the latest commit SHA.
placeholder: |
branch: feature/my-branch
commit: abcdef1
- type: textarea
id: additional
attributes:
label: Additional context
description: Screenshots, related issues, config snippets (redacted), or anything else that helps triage.
-11
View File
@@ -1,11 +0,0 @@
blank_issues_enabled: false
contact_links:
- name: 💬 Questions & usage help
url: https://github.com/bytedance/deer-flow/discussions/categories/q-a
about: "How do I use X? Why does Y behave like that? Ask in Discussions — it gets answered faster and stays searchable."
- name: 💡 Ideas & proposals
url: https://github.com/bytedance/deer-flow/discussions/categories/ideas
about: Have a half-formed idea? Float it in Discussions before opening a formal feature request.
- name: 🔒 Report a security vulnerability
url: https://github.com/bytedance/deer-flow/security/policy
about: Do not open a public issue for security problems. Follow the security policy instead.
@@ -1,67 +0,0 @@
name: 💡 Feature request
description: Propose a new capability or an improvement to an existing one.
title: "[feat] "
labels: ["enhancement"]
body:
- type: markdown
attributes:
value: |
Thanks for the suggestion. For non-trivial features, please open a
[Discussion](https://github.com/bytedance/deer-flow/discussions/categories/ideas)
first to align on scope before writing code.
- type: checkboxes
id: preflight
attributes:
label: Before you start
options:
- label: I searched [existing issues](https://github.com/bytedance/deer-flow/issues?q=is%3Aissue) and this is not a duplicate.
required: true
- type: textarea
id: problem
attributes:
label: Problem / motivation
description: What problem does this solve? What is painful today, or what does it unblock?
placeholder: "I'm always frustrated when ..."
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: Describe the change from a user's / caller's perspective.
validations:
required: true
- type: dropdown
id: areas
attributes:
label: Affected area(s)
description: Which part of DeerFlow would this touch? Select all that apply.
multiple: true
options:
- Frontend (UI / Next.js)
- Backend API (gateway / endpoints / SSE)
- Agents / LangGraph (graph, prompts, langgraph.json)
- Sandbox / Docker
- Skills
- MCP
- Config / setup
- Docs
- Not sure
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives considered
description: Other approaches you weighed and why you discarded them.
- type: textarea
id: additional
attributes:
label: Additional context
description: Mockups, links, related issues, or anything else that helps.
@@ -0,0 +1,128 @@
name: Runtime Information
description: Report runtime/environment details to help reproduce an issue.
title: "[runtime] "
labels:
- needs-triage
body:
- type: markdown
attributes:
value: |
Thanks for sharing runtime details.
Complete this form so maintainers can quickly reproduce and diagnose the problem.
- type: input
id: summary
attributes:
label: Problem summary
description: Short summary of the issue.
placeholder: e.g. make dev fails to start gateway service
validations:
required: true
- type: textarea
id: expected
attributes:
label: Expected behavior
placeholder: What did you expect to happen?
validations:
required: true
- type: textarea
id: actual
attributes:
label: Actual behavior
placeholder: What happened instead? Include key error lines.
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
options:
- macOS
- Linux
- Windows
- Other
validations:
required: true
- type: input
id: platform_details
attributes:
label: Platform details
description: Add architecture and shell if relevant.
placeholder: e.g. arm64, zsh
- type: input
id: python_version
attributes:
label: Python version
placeholder: e.g. Python 3.12.9
- type: input
id: node_version
attributes:
label: Node.js version
placeholder: e.g. v23.11.0
- type: input
id: pnpm_version
attributes:
label: pnpm version
placeholder: e.g. 10.26.2
- type: input
id: uv_version
attributes:
label: uv version
placeholder: e.g. 0.7.20
- type: dropdown
id: run_mode
attributes:
label: How are you running DeerFlow?
options:
- Local (make dev)
- Docker (make docker-dev)
- CI
- Other
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Reproduction steps
description: Provide exact commands and sequence.
placeholder: |
1. make check
2. make install
3. make dev
4. ...
validations:
required: true
- type: textarea
id: logs
attributes:
label: Relevant logs
description: Paste key lines from logs (for example logs/gateway.log, logs/frontend.log).
render: shell
validations:
required: true
- type: textarea
id: git_info
attributes:
label: Git state
description: Share output of git branch and latest commit SHA.
placeholder: |
branch: feature/my-branch
commit: abcdef1
- type: textarea
id: additional
attributes:
label: Additional context
description: Add anything else that might help triage.
-72
View File
@@ -1,72 +0,0 @@
# Path-based PR auto-labeling config for actions/labeler@v5.
# Each key is a label (must exist — see .github/labels.yml); the globs decide
# when it is applied. A PR can match several areas, which is expected.
"area:frontend":
- changed-files:
- any-glob-to-any-file:
- "frontend/**"
"area:backend":
- changed-files:
- any-glob-to-any-file:
- "backend/app/**"
- "backend/packages/harness/deerflow/runtime/**"
- "backend/packages/harness/deerflow/persistence/**"
- "backend/packages/harness/deerflow/config/**"
- "backend/packages/harness/deerflow/tools/**"
- "backend/packages/harness/deerflow/guardrails/**"
- "backend/packages/harness/deerflow/tracing/**"
- "backend/packages/harness/deerflow/models/**"
- "backend/packages/harness/deerflow/utils/**"
- "backend/packages/harness/deerflow/uploads/**"
"area:agents":
- changed-files:
- any-glob-to-any-file:
- "backend/packages/harness/deerflow/agents/**"
- "backend/packages/harness/deerflow/subagents/**"
- "backend/packages/harness/deerflow/reflection/**"
- "backend/langgraph.json"
- "backend/**/prompts/**"
"area:sandbox":
- changed-files:
- any-glob-to-any-file:
- "docker/**"
- "backend/packages/harness/deerflow/sandbox/**"
- "backend/Dockerfile"
- "frontend/Dockerfile"
"area:skills":
- changed-files:
- any-glob-to-any-file:
- "skills/**"
- "backend/packages/harness/deerflow/skills/**"
- "frontend/src/core/skills/**"
"area:mcp":
- changed-files:
- any-glob-to-any-file:
- "backend/packages/harness/deerflow/mcp/**"
- "frontend/src/core/mcp/**"
"area:ci":
- changed-files:
- any-glob-to-any-file:
- ".github/**"
- "scripts/**"
"area:docs":
- changed-files:
- any-glob-to-any-file:
- "docs/**"
- "**/*.md"
"area:deps":
- changed-files:
- any-glob-to-any-file:
- "backend/pyproject.toml"
- "backend/uv.lock"
- "frontend/package.json"
- "frontend/pnpm-lock.yaml"
-119
View File
@@ -1,119 +0,0 @@
# Declarative label source of truth for DeerFlow.
#
# This file is the single source of truth for repository labels used by the
# auto-labeling workflows (.github/workflows/pr-labeler.yml, pr-triage.yml,
# issue-triage.yml). Auto-labelers can only apply labels that already exist,
# so every label referenced by a workflow MUST be declared here.
#
# Apply with: uv run --with pyyaml python scripts/sync_labels.py [--repo OWNER/NAME]
# CI keeps it in sync via .github/workflows/label-sync.yml (runs on changes here).
#
# Sync is additive/update-only: it creates or updates the labels listed below
# and never deletes labels that are not listed.
#
# Color = 6-digit hex without the leading '#'.
labels:
# ── Type ─────────────────────────────────────────────────────────────────
# Mostly GitHub defaults; declared here so colors/descriptions stay stable
# and so issue templates can rely on them existing.
- name: bug
color: d73a4a
description: Something isn't working
- name: enhancement
color: a2eeef
description: New feature or request
- name: documentation
color: 0075ca
description: Improvements or additions to documentation
- name: question
color: d876e3
description: Further information is requested
# ── Area (auto, by changed paths — see .github/labeler.yml) ───────────────
# Mirrors the "Surface area" section of the pull request template.
- name: "area:frontend"
color: c5def5
description: Next.js frontend under frontend/
- name: "area:backend"
color: c5def5
description: Gateway / runtime / core backend under backend/
- name: "area:agents"
color: c5def5
description: Agents, subagents, graph wiring, prompts, langgraph.json
- name: "area:sandbox"
color: c5def5
description: Sandboxed execution and docker/
- name: "area:skills"
color: c5def5
description: Skills under skills/ or the skills harness
- name: "area:mcp"
color: c5def5
description: Model Context Protocol integration
- name: "area:ci"
color: c5def5
description: GitHub Actions, CI config, repo tooling
- name: "area:docs"
color: c5def5
description: Documentation and Markdown only
- name: "area:deps"
color: c5def5
description: Dependency manifests / lockfiles
# ── Size (auto, by additions + deletions — see pr-triage.yml) ─────────────
- name: "size/XS"
color: "009900"
description: PR changes < 20 lines
- name: "size/S"
color: 77bb00
description: PR changes 20-100 lines
- name: "size/M"
color: eebb00
description: PR changes 100-300 lines
- name: "size/L"
color: ee9900
description: PR changes 300-700 lines
- name: "size/XL"
color: ee5500
description: PR changes 700+ lines
# ── Risk (auto, by changed paths — see pr-triage.yml) ─────────────────────
- name: "risk:low"
color: 0e8a16
description: "Low risk: docs / i18n / assets only"
- name: "risk:medium"
color: fbca04
description: "Medium risk: regular code changes"
- name: "risk:high"
color: b60205
description: "High risk: backend API, agents, sandbox, auth, deps, CI"
# ── Priority (manual) ─────────────────────────────────────────────────────
- name: P0
color: b60205
description: Critical priority
- name: P1
color: d93f0b
description: Major priority
- name: P2
color: e99695
description: Normal priority
# ── Status (auto + manual) ────────────────────────────────────────────────
- name: needs-triage
color: fef2c0
description: Awaiting maintainer triage
- name: needs-validation
color: d4c5f9
description: Touches front/back contract surface; needs real-path validation
- name: skip-validation
color: cccccc
description: "Maintainer override: do not auto-add needs-validation on this PR"
- name: reviewing
color: 5319e7
description: A maintainer is reviewing this PR
# ── Contributor ───────────────────────────────────────────────────────────
- name: first-time-contributor
color: c2e0c6
description: First contribution to this repository — be welcoming
-14
View File
@@ -59,17 +59,3 @@ Fixes #
Frontend: cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test Frontend: cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test
Frontend E2E (if you touched frontend/): cd frontend && make test-e2e --> Frontend E2E (if you touched frontend/): cd frontend && make test-e2e -->
## AI assistance
<!-- DeerFlow is an AI project — most PRs here use AI coding tools, and that's
welcome. Disclosing it just helps reviewers calibrate how closely to read the
diff. Please fill all three; don't delete the section. -->
**Tool(s) used:** <!-- e.g. Claude Code, Cursor, GitHub Copilot, Codex, Windsurf, or "none" -->
**How you used it:** <!-- e.g. "generated the module from a spec", "autocomplete only",
"AI wrote tests, I wrote the impl". A prompt or conversation link is great too. -->
- [ ] I've read and understand every line of this change and take responsibility for it — it's not unreviewed AI output.
-44
View File
@@ -1,44 +0,0 @@
name: Issue Triage
# Ensures every newly opened issue carries `needs-triage`, even blank or
# API-created ones that bypass the issue templates. Creates the label if it is
# somehow missing, so the workflow is self-healing.
on:
issues:
types: [opened]
permissions:
issues: write
jobs:
needs-triage:
runs-on: ubuntu-latest
steps:
- name: Add needs-triage label
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const issue_number = context.payload.issue.number;
const current = (context.payload.issue.labels || []).map(l => l.name);
if (current.includes('needs-triage')) {
core.info('Issue already has needs-triage; nothing to do.');
return;
}
// Self-heal: create the label if it does not exist yet.
try {
await github.rest.issues.createLabel({
owner, repo, name: 'needs-triage', color: 'fef2c0',
description: 'Awaiting maintainer triage',
});
} catch (e) {
if (e.status !== 422) throw e; // 422 = already exists
}
await github.rest.issues.addLabels({
owner, repo, issue_number, labels: ['needs-triage'],
});
core.info(`Added needs-triage to #${issue_number}.`);
-38
View File
@@ -1,38 +0,0 @@
name: Label Sync
# Keeps repository labels in sync with the declarative source of truth
# (.github/labels.yml). Runs whenever that file changes on main, and can be
# triggered manually. Additive/update-only — never deletes labels.
on:
push:
branches: [main]
paths:
- ".github/labels.yml"
- "scripts/sync_labels.py"
- ".github/workflows/label-sync.yml"
workflow_dispatch:
permissions:
contents: read
issues: write
concurrency:
group: label-sync
cancel-in-progress: false
jobs:
sync:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Sync labels
run: uv run --with pyyaml python scripts/sync_labels.py
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
-28
View File
@@ -1,28 +0,0 @@
name: PR Labeler
# Applies area:* labels based on which files a PR changes (see .github/labeler.yml).
# Uses pull_request_target so it also works on fork PRs. SAFE: actions/labeler
# only reads the changed-file list via the API — it never checks out or runs PR code.
on:
pull_request_target:
types: [opened, synchronize, reopened, ready_for_review]
permissions:
contents: read
pull-requests: write
concurrency:
group: pr-labeler-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
label:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- name: Apply area labels
uses: actions/labeler@v5
with:
configuration-path: .github/labeler.yml
sync-labels: true
-164
View File
@@ -1,164 +0,0 @@
name: PR Triage
# Two responsibilities, both pure-metadata (no PR code is checked out or run):
# 1. On open/sync: apply size/* + risk:* labels, and needs-validation when the
# PR touches the front/back contract surface (backend API, SSE, agents, or
# the frontend streaming client). A `skip-validation` label opts out.
# 2. On maintainer review: apply the `reviewing` label.
#
# All labels are managed within their own namespace — labels outside size/*,
# risk:*, needs-validation and reviewing are never touched here.
on:
pull_request_target:
types: [opened, synchronize, reopened, ready_for_review]
pull_request_review:
types: [submitted]
permissions:
contents: read
pull-requests: write
concurrency:
group: pr-triage-${{ github.event.pull_request.number }}
cancel-in-progress: false
jobs:
size-and-risk:
if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- name: Label size, risk and validation need
uses: actions/github-script@v7
with:
script: |
const pr = context.payload.pull_request;
const { owner, repo } = context.repo;
const prNumber = pr.number;
// ---- size, from additions + deletions ----
const churn = (pr.additions || 0) + (pr.deletions || 0);
const sizeLabel =
churn < 20 ? 'size/XS' :
churn < 100 ? 'size/S' :
churn < 300 ? 'size/M' :
churn < 700 ? 'size/L' : 'size/XL';
// ---- changed paths ----
const files = await github.paginate(github.rest.pulls.listFiles, {
owner, repo, pull_number: prNumber, per_page: 100,
});
const paths = files.map(f => f.filename);
const matches = (re) => paths.some(p => re.test(p));
const docsOnly = paths.length > 0 && paths.every(p =>
/\.(md|mdx|txt)$/i.test(p) || p.startsWith('docs/') ||
/\.(png|jpe?g|gif|svg|webp|ico)$/i.test(p));
const highRisk = matches(
/^backend\/app\/gateway\//) || matches(
/^backend\/packages\/harness\/deerflow\/(agents|subagents|sandbox)\//) || matches(
/(^|\/)langgraph\.json$/) || matches(
/(^|\/)(auth|authz|security)/i) || matches(
/(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/) || matches(
/^docker\//) || matches(
/^\.github\/workflows\//);
const riskLabel = docsOnly ? 'risk:low' : (highRisk ? 'risk:high' : 'risk:medium');
// needs-validation: front/back contract surface
const contractSurface =
matches(/^backend\/app\/gateway\//) ||
matches(/^backend\/packages\/harness\/deerflow\/(agents|subagents)\//) ||
matches(/(^|\/)langgraph\.json$/) ||
matches(/^frontend\/src\/core\/(api|threads|messages)\//);
const current = (pr.labels || []).map(l => l.name);
const hasSkip = current.includes('skip-validation');
const desired = [sizeLabel, riskLabel];
if (contractSurface && !hasSkip) desired.push('needs-validation');
const managed = (name) =>
name.startsWith('size/') || name.startsWith('risk:') || name === 'needs-validation';
const toRemove = current.filter(l => managed(l) && !desired.includes(l));
const toAdd = desired.filter(l => !current.includes(l));
for (const name of toRemove) {
try {
await github.rest.issues.removeLabel({ owner, repo, issue_number: prNumber, name });
} catch (e) {
if (e.status !== 404) throw e;
}
}
if (toAdd.length) {
await github.rest.issues.addLabels({ owner, repo, issue_number: prNumber, labels: toAdd });
}
core.info(`size=${sizeLabel} risk=${riskLabel} churn=${churn} ` +
`validation=${desired.includes('needs-validation')} ` +
`(+${toAdd.join(',') || '-'} / -${toRemove.join(',') || '-'})`);
first-time:
if: github.event_name == 'pull_request_target' && github.event.action == 'opened'
runs-on: ubuntu-latest
steps:
- name: Label first-time contributors
uses: actions/github-script@v7
with:
script: |
const pr = context.payload.pull_request;
const { owner, repo } = context.repo;
const assoc = pr.author_association;
const isBot = pr.user.type === 'Bot';
core.info(`author=${pr.user.login} association=${assoc} bot=${isBot}`);
// FIRST_TIME_CONTRIBUTOR = no prior merged commit to this repo;
// FIRST_TIMER = no prior commit anywhere on GitHub. Either counts.
if (isBot || !['FIRST_TIME_CONTRIBUTOR', 'FIRST_TIMER'].includes(assoc)) {
core.info('Not a first-time contributor; skipping.');
return;
}
await github.rest.issues.addLabels({
owner, repo, issue_number: pr.number, labels: ['first-time-contributor'],
});
core.info(`Added first-time-contributor to #${pr.number}.`);
reviewing:
if: github.event_name == 'pull_request_review'
runs-on: ubuntu-latest
steps:
- name: Add reviewing label for maintainer reviews
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const prNumber = context.payload.pull_request.number;
const reviewer = context.payload.review.user.login;
const { data: perm } = await github.rest.repos.getCollaboratorPermissionLevel({
owner, repo, username: reviewer,
});
if (!['admin', 'write', 'maintain'].includes(perm.permission)) {
core.info(`Reviewer ${reviewer} (${perm.permission}) is not a maintainer; skipping.`);
return;
}
const { data: labels } = await github.rest.issues.listLabelsOnIssue({
owner, repo, issue_number: prNumber,
});
if (labels.some(l => l.name === 'reviewing')) {
core.info('Already labeled reviewing; skipping.');
return;
}
try {
await github.rest.issues.addLabels({
owner, repo, issue_number: prNumber, labels: ['reviewing'],
});
core.info(`Added "reviewing" (reviewer ${reviewer}).`);
} catch (e) {
// 403 is expected for review events on some fork PR contexts.
if (e.status === 403) core.info('No permission to label (expected on some fork PRs).');
else throw e;
}
-15
View File
@@ -287,21 +287,6 @@ Nginx (port 2026) ← Unified entry point
git push origin feature/your-feature-name git push origin feature/your-feature-name
``` ```
## AI assistance disclosure
DeerFlow is an AI project and we welcome AI-assisted contributions. To help
reviewers calibrate how closely to read a change, **every pull request must
complete the "AI assistance" section of the
[PR template](.github/pull_request_template.md)**:
- which tool(s) you used (or `none`),
- how you used them, and
- a confirmation that a human has read, understands, and takes responsibility
for the change.
Please don't delete the section. PRs that ignore it may be asked to fill it in
before review.
## Testing ## Testing
```bash ```bash
+31 -1
View File
@@ -89,7 +89,36 @@ install:
# Pre-pull sandbox Docker image (optional but recommended) # Pre-pull sandbox Docker image (optional but recommended)
setup-sandbox: setup-sandbox:
@$(RUN_WITH_GIT_BASH) ./scripts/setup-sandbox.sh @echo "=========================================="
@echo " Pre-pulling Sandbox Container Image"
@echo "=========================================="
@echo ""
@IMAGE=$$(grep -A 20 "# sandbox:" config.yaml 2>/dev/null | grep "image:" | awk '{print $$2}' | head -1); \
if [ -z "$$IMAGE" ]; then \
IMAGE="enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"; \
echo "Using default image: $$IMAGE"; \
else \
echo "Using configured image: $$IMAGE"; \
fi; \
echo ""; \
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
echo "Detected Apple Container on macOS, pulling image..."; \
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
fi; \
if command -v docker >/dev/null 2>&1; then \
echo "Pulling image using Docker..."; \
if docker pull "$$IMAGE"; then \
echo ""; \
echo "✓ Sandbox image pulled successfully"; \
else \
echo ""; \
echo "⚠ Failed to pull sandbox image (this is OK for local sandbox mode)"; \
fi; \
else \
echo "✗ Neither Docker nor Apple Container is available"; \
echo " Please install Docker: https://docs.docker.com/get-docker/"; \
exit 1; \
fi
# Start all services in development mode (with hot-reloading) # Start all services in development mode (with hot-reloading)
dev: dev:
@@ -119,6 +148,7 @@ stop:
clean: stop clean: stop
@echo "Cleaning up..." @echo "Cleaning up..."
@-rm -rf backend/.deer-flow 2>/dev/null || true @-rm -rf backend/.deer-flow 2>/dev/null || true
@-rm -rf backend/.langgraph_api 2>/dev/null || true
@-rm -rf logs/*.log 2>/dev/null || true @-rm -rf logs/*.log 2>/dev/null || true
@echo "✓ Cleanup complete" @echo "✓ Cleanup complete"
+11 -3
View File
@@ -208,7 +208,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`) 15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`) 16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer 17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last) 18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
@@ -223,9 +223,17 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart. **Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`. **Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`. | Field | Why a restart is required |
|---|---|
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
Configuration priority: Configuration priority:
1. Explicit `config_path` argument 1. Explicit `config_path` argument
+3 -3
View File
@@ -64,7 +64,7 @@ FROM builder AS dev
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket) # Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
EXPOSE 8001 EXPOSE 8001 2024
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"] CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
@@ -94,8 +94,8 @@ WORKDIR /app
# Copy backend with pre-built virtualenv from builder # Copy backend with pre-built virtualenv from builder
COPY --from=builder /app/backend ./backend COPY --from=builder /app/backend ./backend
# Expose Gateway API port. # Expose ports (gateway: 8001, langgraph: 2024)
EXPOSE 8001 EXPOSE 8001 2024
# Default command (can be overridden in docker-compose) # Default command (can be overridden in docker-compose)
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run --no-sync uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"] CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run --no-sync uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
+9 -188
View File
@@ -7,26 +7,16 @@ import json
import logging import logging
import re import re
import threading import threading
import time
from typing import Any, Literal from typing import Any, Literal
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import ( from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import get_sandbox_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60
def _is_feishu_command(text: str) -> bool: def _is_feishu_command(text: str) -> bool:
@@ -66,7 +56,6 @@ class FeishuChannel(Channel):
self._background_tasks: set[asyncio.Task] = set() self._background_tasks: set[asyncio.Task] = set()
self._running_card_ids: dict[str, str] = {} self._running_card_ids: dict[str, str] = {}
self._running_card_tasks: dict[str, asyncio.Task] = {} self._running_card_tasks: dict[str, asyncio.Task] = {}
self._pending_clarifications: dict[tuple[str, str], list[dict[str, Any]]] = {}
self._CreateFileRequest = None self._CreateFileRequest = None
self._CreateFileRequestBody = None self._CreateFileRequestBody = None
self._CreateImageRequest = None self._CreateImageRequest = None
@@ -74,16 +63,6 @@ class FeishuChannel(Channel):
self._GetMessageResourceRequest = None self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock() self._thread_lock = threading.Lock()
@staticmethod
def _non_empty_str(value: Any) -> str | None:
if isinstance(value, str) and value.strip():
return value.strip()
return None
@staticmethod
def _pending_key(chat_id: str, user_id: str) -> tuple[str, str]:
return (chat_id, user_id)
@property @property
def supports_streaming(self) -> bool: def supports_streaming(self) -> bool:
return True return True
@@ -552,25 +531,18 @@ class FeishuChannel(Channel):
"[Feishu] failed to patch running card %s, falling back to final reply", "[Feishu] failed to patch running card %s, falling back to final reply",
running_card_id, running_card_id,
) )
fallback_card_id = await self._reply_card(source_message_id, msg.text) await self._reply_card(source_message_id, msg.text)
self._remember_thread_mapping(msg, source_message_id, fallback_card_id)
self._remember_pending_clarification(msg, fallback_card_id)
else: else:
self._remember_thread_mapping(msg, source_message_id, running_card_id)
self._remember_pending_clarification(msg, running_card_id)
logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id) logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id)
elif msg.is_final: elif msg.is_final:
final_card_id = await self._reply_card(source_message_id, msg.text) await self._reply_card(source_message_id, msg.text)
self._remember_thread_mapping(msg, source_message_id, final_card_id)
self._remember_pending_clarification(msg, final_card_id)
elif awaited_running_card_task: elif awaited_running_card_task:
logger.warning( logger.warning(
"[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation", "[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation",
source_message_id, source_message_id,
) )
else: else:
created_card_id = await self._ensure_running_card(source_message_id, msg.text) await self._ensure_running_card(source_message_id, msg.text)
self._remember_thread_mapping(msg, source_message_id, created_card_id)
if msg.is_final: if msg.is_final:
self._running_card_ids.pop(source_message_id, None) self._running_card_ids.pop(source_message_id, None)
@@ -581,129 +553,6 @@ class FeishuChannel(Channel):
# -- internal ---------------------------------------------------------- # -- internal ----------------------------------------------------------
def _remember_thread_mapping(self, msg: OutboundMessage, *topic_ids: str | None) -> None:
store = self.config.get("channel_store")
if store is None or not msg.thread_id:
return
metadata_topic_ids = [
msg.metadata.get("message_id"),
msg.metadata.get("root_id"),
msg.metadata.get("parent_id"),
msg.metadata.get("thread_id"),
msg.metadata.get("topic_id"),
]
user_id = ""
raw_user_id = msg.metadata.get("user_id")
if isinstance(raw_user_id, str):
user_id = raw_user_id
seen: set[str] = set()
for topic_id in [*topic_ids, *metadata_topic_ids]:
topic_id = self._non_empty_str(topic_id)
if not topic_id or topic_id in seen:
continue
seen.add(topic_id)
try:
store.set_thread_id(
self.name,
msg.chat_id,
msg.thread_id,
topic_id=topic_id,
user_id=user_id,
)
except Exception:
logger.exception("[Feishu] failed to remember thread mapping for topic_id=%s", topic_id)
def _remember_pending_clarification(self, msg: OutboundMessage, card_message_id: str | None) -> None:
if not msg.is_final or msg.metadata.get(PENDING_CLARIFICATION_METADATA_KEY) is not True:
return
user_id = self._non_empty_str(msg.metadata.get("user_id"))
topic_id = self._non_empty_str(msg.metadata.get("topic_id"))
source_message_id = self._non_empty_str(msg.thread_ts) or self._non_empty_str(msg.metadata.get("message_id"))
if not (user_id and topic_id and msg.thread_id and source_message_id and card_message_id):
return
key = self._pending_key(msg.chat_id, user_id)
pending = {
"thread_id": msg.thread_id,
"topic_id": topic_id,
"source_message_id": source_message_id,
"card_message_id": card_message_id,
"created_at": time.time(),
}
with self._thread_lock:
# Plain-message clarification continuity is a short-lived in-memory
# hint; explicit Feishu replies are still covered by persisted
# message-id mappings.
self._pending_clarifications.setdefault(key, []).append(pending)
logger.info(
"[Feishu] pending clarification remembered: chat_id=%s user_id=%s topic_id=%s thread_id=%s",
msg.chat_id,
user_id,
topic_id,
msg.thread_id,
)
def _consume_pending_clarification(self, chat_id: str, user_id: str) -> dict[str, Any] | None:
key = self._pending_key(chat_id, user_id)
with self._thread_lock:
pending_items = self._pending_clarifications.get(key)
if not pending_items:
return None
now = time.time()
while pending_items:
pending = pending_items.pop(0)
created_at = pending.get("created_at")
if isinstance(created_at, (int, float)) and now - created_at <= PENDING_CLARIFICATION_TTL_SECONDS:
if pending_items:
self._pending_clarifications[key] = pending_items
else:
self._pending_clarifications.pop(key, None)
return pending
logger.info("[Feishu] pending clarification expired: chat_id=%s user_id=%s", chat_id, user_id)
self._pending_clarifications.pop(key, None)
return None
def _ensure_pending_thread_mapping(self, chat_id: str, user_id: str, pending: dict[str, Any]) -> None:
store = self.config.get("channel_store")
topic_id = self._non_empty_str(pending.get("topic_id"))
thread_id = self._non_empty_str(pending.get("thread_id"))
if store is None or not topic_id or not thread_id:
return
try:
store.set_thread_id(self.name, chat_id, thread_id, topic_id=topic_id, user_id=user_id)
except Exception:
logger.exception("[Feishu] failed to restore pending clarification mapping for topic_id=%s", topic_id)
def _resolve_topic_id(
self,
chat_id: str,
msg_id: str,
*,
root_id: str | None,
parent_id: str | None,
thread_id: str | None,
) -> tuple[str, bool]:
store = self.config.get("channel_store")
candidates = [root_id, parent_id, thread_id]
if store is not None:
for candidate in candidates:
candidate = self._non_empty_str(candidate)
if not candidate:
continue
try:
if store.get_thread_id(self.name, chat_id, topic_id=candidate):
return candidate, True
except Exception:
logger.exception("[Feishu] failed to resolve stored topic mapping for topic_id=%s", candidate)
return root_id or msg_id, False
@staticmethod @staticmethod
def _log_future_error(fut, name: str, msg_id: str) -> None: def _log_future_error(fut, name: str, msg_id: str) -> None:
"""Callback for run_coroutine_threadsafe futures to surface errors.""" """Callback for run_coroutine_threadsafe futures to surface errors."""
@@ -744,9 +593,7 @@ class FeishuChannel(Channel):
# root_id is set when the message is a reply within a Feishu thread. # root_id is set when the message is a reply within a Feishu thread.
# Use it as topic_id so all replies share the same DeerFlow thread. # Use it as topic_id so all replies share the same DeerFlow thread.
root_id = self._non_empty_str(getattr(message, "root_id", None)) root_id = getattr(message, "root_id", None) or None
parent_id = self._non_empty_str(getattr(message, "parent_id", None))
feishu_thread_id = self._non_empty_str(getattr(message, "thread_id", None))
# Parse message content # Parse message content
content = json.loads(message.content) content = json.loads(message.content)
@@ -807,12 +654,10 @@ class FeishuChannel(Channel):
text = text.strip() text = text.strip()
logger.info( logger.info(
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r", "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, sender=%s, text=%r",
chat_id, chat_id,
msg_id, msg_id,
root_id, root_id,
parent_id,
feishu_thread_id,
sender_id, sender_id,
text[:100] if text else "", text[:100] if text else "",
) )
@@ -828,24 +673,8 @@ class FeishuChannel(Channel):
else: else:
msg_type = InboundMessageType.CHAT msg_type = InboundMessageType.CHAT
# Prefer any platform message id that already maps to a DeerFlow # topic_id: use root_id for replies (same topic), msg_id for new messages (new topic)
# thread. This keeps replies to bot clarification cards in the topic_id = root_id or msg_id
# original conversation even when Feishu reports the card as root.
topic_id, resolved_from_stored_mapping = self._resolve_topic_id(
chat_id,
msg_id,
root_id=root_id,
parent_id=parent_id,
thread_id=feishu_thread_id,
)
resolved_from_pending = False
if msg_type == InboundMessageType.CHAT and not resolved_from_stored_mapping:
pending = self._consume_pending_clarification(chat_id, sender_id)
pending_topic_id = self._non_empty_str(pending.get("topic_id")) if pending else None
if pending_topic_id:
topic_id = pending_topic_id
self._ensure_pending_thread_mapping(chat_id, sender_id, pending)
resolved_from_pending = True
inbound = self._make_inbound( inbound = self._make_inbound(
chat_id=chat_id, chat_id=chat_id,
@@ -854,15 +683,7 @@ class FeishuChannel(Channel):
msg_type=msg_type, msg_type=msg_type,
thread_ts=msg_id, thread_ts=msg_id,
files=files_list, files=files_list,
metadata={ metadata={"message_id": msg_id, "root_id": root_id},
"message_id": msg_id,
"root_id": root_id,
"parent_id": parent_id,
"thread_id": feishu_thread_id,
"topic_id": topic_id,
"user_id": sender_id,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY: resolved_from_pending,
},
) )
inbound.topic_id = topic_id inbound.topic_id = topic_id
+5 -71
View File
@@ -15,18 +15,10 @@ import httpx
from langgraph_sdk.errors import ConflictError from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import ( from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
PENDING_CLARIFICATION_METADATA_KEY,
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
from app.channels.store import ChannelStore from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.config.paths import make_safe_user_id
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -210,54 +202,6 @@ def _extract_response_text(result: dict | list) -> str:
return "" return ""
def _messages_from_result(result: dict | list) -> list[Any]:
if isinstance(result, list):
return result
if isinstance(result, dict):
messages = result.get("messages", [])
if isinstance(messages, list):
return messages
return []
def _current_turn_messages(result: dict | list) -> list[dict[str, Any]]:
messages = _messages_from_result(result)
current_turn: list[dict[str, Any]] = []
for msg in reversed(messages):
if not isinstance(msg, dict):
continue
if msg.get("type") == "human":
break
current_turn.append(msg)
current_turn.reverse()
return current_turn
def _has_current_turn_clarification(result: dict | list) -> bool:
"""Return True only when the current turn's final result is clarification."""
for msg in reversed(_current_turn_messages(result)):
msg_type = msg.get("type")
if msg_type == "tool":
return msg.get("name") == "ask_clarification"
if msg_type == "ai":
content = msg.get("content")
if isinstance(content, str):
if content:
return False
elif content:
return False
if msg.get("tool_calls"):
return False
return False
def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification: bool = False) -> dict[str, Any]:
metadata = _slim_metadata(base_metadata)
if pending_clarification:
metadata[PENDING_CLARIFICATION_METADATA_KEY] = True
return metadata
def _extract_text_content(content: Any) -> str: def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field.""" """Extract text from a streaming payload content field."""
if isinstance(content, str): if isinstance(content, str):
@@ -671,20 +615,12 @@ class ChannelManager:
configurable["checkpoint_ns"] = "" configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id configurable["thread_id"] = thread_id
# ``user_id`` drives user-scoped filesystem buckets that only accept
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
# under ``channel_user_id`` for platform-facing lookups.
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id
run_context = _merge_dicts( run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT, DEFAULT_RUN_CONTEXT,
self._default_session.get("context"), self._default_session.get("context"),
channel_layer.get("context"), channel_layer.get("context"),
user_layer.get("context"), user_layer.get("context"),
run_context_identity, {"thread_id": thread_id},
) )
# Custom agents are implemented as lead_agent + agent_name context. # Custom agents are implemented as lead_agent + agent_name context.
@@ -870,7 +806,6 @@ class ChannelManager:
raise raise
response_text = _extract_response_text(result) response_text = _extract_response_text(result)
pending_clarification = _has_current_turn_clarification(result)
artifacts = _extract_artifacts(result) artifacts = _extract_artifacts(result)
logger.info( logger.info(
@@ -896,7 +831,7 @@ class ChannelManager:
artifacts=artifacts, artifacts=artifacts,
attachments=attachments, attachments=attachments,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), metadata=_slim_metadata(msg.metadata),
) )
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id) logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
await self.bus.publish_outbound(outbound) await self.bus.publish_outbound(outbound)
@@ -958,7 +893,7 @@ class ChannelManager:
text=latest_text, text=latest_text,
is_final=False, is_final=False,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
metadata=_response_metadata(msg.metadata), metadata=_slim_metadata(msg.metadata),
) )
) )
last_published_text = latest_text last_published_text = latest_text
@@ -972,7 +907,6 @@ class ChannelManager:
finally: finally:
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]} result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
response_text = _extract_response_text(result) response_text = _extract_response_text(result)
pending_clarification = _has_current_turn_clarification(result)
artifacts = _extract_artifacts(result) artifacts = _extract_artifacts(result)
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts) response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
@@ -1004,7 +938,7 @@ class ChannelManager:
attachments=attachments, attachments=attachments,
is_final=True, is_final=True,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), metadata=_slim_metadata(msg.metadata),
) )
) )
-3
View File
@@ -13,9 +13,6 @@ from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PENDING_CLARIFICATION_METADATA_KEY = "pending_clarification"
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY = "resolved_from_pending_clarification"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Message types # Message types
-56
View File
@@ -17,7 +17,6 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
@@ -34,43 +33,6 @@ from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Upper bound (seconds) for draining in-flight runs during shutdown, before the
# AsyncExitStack tears down the checkpointer (and its connection pool). Kept
# local to avoid an app -> deps -> app import cycle. This is a *separate* budget
# from ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s,
# which bounds channel-service stop): the two govern independent teardown steps
# and may diverge, but both count toward the lifespan shutdown window — revisit
# them together if their sum must stay within the server's graceful-shutdown
# timeout.
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
async def _drain_inflight_runs(run_manager: RunManager) -> None:
"""Drain in-flight runs before the checkpointer is torn down (issue #3373).
Shields the (internally-bounded) drain so that even if the lifespan
coroutine is itself cancelled mid-shutdown — a second SIGINT or the server's
graceful-shutdown timeout, i.e. the same signal storm behind #3373 — the
checkpointer pool is not closed while run tasks are still writing
checkpoints. On such a cancellation we let the already-running drain finish
(it is bounded by ``RunManager.shutdown``'s own timeout) and then propagate
the cancellation.
"""
drain = asyncio.create_task(run_manager.shutdown(timeout=_RUN_DRAIN_TIMEOUT_SECONDS))
try:
await asyncio.shield(drain)
except asyncio.CancelledError:
# Re-shield so this second wait does not abandon the in-flight drain;
# it is bounded, so this cannot hang. Then re-raise to honour shutdown.
try:
await asyncio.shield(drain)
except Exception:
logger.exception("In-flight run drain failed after shutdown cancellation")
raise
except Exception:
logger.exception("Failed to drain in-flight runs during shutdown")
if TYPE_CHECKING: if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
@@ -119,16 +81,6 @@ def get_config() -> AppConfig:
split-brain where the worker / lead-agent thread saw a stale startup split-brain where the worker / lead-agent thread saw a stale startup
snapshot. snapshot.
Hot-reload boundary: fields backed by startup-time singletons
(engines, sandbox provider, IM channels, logging handler) require a
process restart to change at runtime. The authoritative list lives in
:mod:`deerflow.config.reload_boundary` and is mirrored by the
standardised ``"startup-only:"`` prefix on the matching
``Field(description=...)`` in :class:`AppConfig` — IDE hover on those
fields will surface the boundary inline. See
``backend/CLAUDE.md`` "Config Hot-Reload Boundary" for the operator
summary.
Any failure to materialise the config (missing file, permission denied, Any failure to materialise the config (missing file, permission denied,
YAML parse error, validation error) is reported as 503 — semantically YAML parse error, validation error) is reported as 503 — semantically
"the gateway cannot serve requests without a usable configuration" — and "the gateway cannot serve requests without a usable configuration" — and
@@ -225,14 +177,6 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
try: try:
yield yield
finally: finally:
# Drain in-flight run tasks BEFORE the AsyncExitStack tears down the
# checkpointer (and its connection pool). A run still mid-graph would
# otherwise leak into asyncio.run() shutdown, where langgraph's
# _checkpointer_put_after_previous aput races the closed pool and
# raises PoolClosed (issue #3373).
run_manager = getattr(app.state, "run_manager", None)
if run_manager is not None:
await _drain_inflight_runs(run_manager)
await close_engine() await close_engine()
+1 -2
View File
@@ -10,7 +10,6 @@ from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
INTERNAL_SYSTEM_ROLE = "internal"
def _load_internal_auth_token() -> str: def _load_internal_auth_token() -> str:
@@ -35,4 +34,4 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
def get_internal_user(): def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls.""" """Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
-15
View File
@@ -1,15 +0,0 @@
"""Shared pagination helpers for gateway routers."""
from __future__ import annotations
def trim_run_message_page(rows: list[dict], *, limit: int, after_seq: int | None) -> tuple[list[dict], bool]:
"""Trim a ``limit + 1`` run-message page while preserving page boundaries."""
has_more = len(rows) > limit
if not has_more:
return rows, False
if after_seq is not None:
return rows[:limit], True
return rows[-limit:], True
+2 -2
View File
@@ -15,7 +15,6 @@ from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.pagination import trim_run_message_page
from app.gateway.routers.thread_runs import RunCreateRequest from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
@@ -130,7 +129,8 @@ async def run_messages(
before_seq=before_seq, before_seq=before_seq,
after_seq=after_seq, after_seq=after_seq,
) )
data, has_more = trim_run_message_page(rows, limit=limit, after_seq=after_seq) has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more} return {"data": data, "has_more": has_more}
+2 -2
View File
@@ -21,7 +21,6 @@ from pydantic import BaseModel, Field
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.pagination import trim_run_message_page
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
@@ -403,7 +402,8 @@ async def list_run_messages(
before_seq=before_seq, before_seq=before_seq,
after_seq=after_seq, after_seq=after_seq,
) )
data, has_more = trim_run_message_page(rows, limit=limit, after_seq=after_seq) has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more} return {"data": data, "has_more": has_more}
+5 -29
View File
@@ -39,39 +39,15 @@ DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024 DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
class UploadedFileInfo(BaseModel):
"""Uploaded file metadata exposed by upload and list APIs."""
filename: str
size: int
path: str
virtual_path: str
artifact_url: str
extension: str | None = None
modified: float | None = None
original_filename: str | None = None
markdown_file: str | None = None
markdown_path: str | None = None
markdown_virtual_path: str | None = None
markdown_artifact_url: str | None = None
class UploadResponse(BaseModel): class UploadResponse(BaseModel):
"""Response model for file upload.""" """Response model for file upload."""
success: bool success: bool
files: list[UploadedFileInfo] files: list[dict[str, str]]
message: str message: str
skipped_files: list[str] = Field(default_factory=list) skipped_files: list[str] = Field(default_factory=list)
class UploadListResponse(BaseModel):
"""Response model for uploaded file listing."""
files: list[UploadedFileInfo]
count: int
class UploadLimits(BaseModel): class UploadLimits(BaseModel):
"""Application-level upload limits exposed to clients.""" """Application-level upload limits exposed to clients."""
@@ -280,7 +256,7 @@ async def upload_files(
file_info = { file_info = {
"filename": safe_filename, "filename": safe_filename,
"size": file_size, "size": str(file_size),
"path": str(sandbox_uploads / safe_filename), "path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path, "virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename), "artifact_url": upload_artifact_url(thread_id, safe_filename),
@@ -357,9 +333,9 @@ async def get_upload_limits(
return _get_upload_limits(config) return _get_upload_limits(config)
@router.get("/list", response_model=UploadListResponse) @router.get("/list", response_model=dict)
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> UploadListResponse: async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory.""" """List all files in a thread's uploads directory."""
try: try:
uploads_dir = get_uploads_dir(thread_id) uploads_dir = get_uploads_dir(thread_id)
@@ -373,7 +349,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> UploadListRes
for f in result["files"]: for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"]) f["path"] = str(sandbox_uploads / f["filename"])
return UploadListResponse(**result) return result
@router.delete("/{filename}") @router.delete("/{filename}")
+1 -14
View File
@@ -19,7 +19,6 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.runtime import ( from deerflow.runtime import (
@@ -141,14 +140,7 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']`` """Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
and ``config['context']`` so they are visible to legacy configurable readers and and ``config['context']`` so they are visible to legacy configurable readers and
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool — to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
see issue #2677). see issue #2677)."""
``user_id`` is intentionally propagated into ``config['context']`` in addition to
the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in
``body.context`` keep it on ``ToolRuntime.context``. It is merged with
``setdefault`` so a server-authenticated id stamped by
:func:`inject_authenticated_user_context` always wins over the client-supplied one.
"""
if not context: if not context:
return return
configurable = config.setdefault("configurable", {}) configurable = config.setdefault("configurable", {})
@@ -159,8 +151,6 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
configurable.setdefault(key, context[key]) configurable.setdefault(key, context[key])
if isinstance(runtime_context, dict): if isinstance(runtime_context, dict):
runtime_context.setdefault(key, context[key]) runtime_context.setdefault(key, context[key])
if "user_id" in context and isinstance(runtime_context, dict):
runtime_context.setdefault("user_id", context["user_id"])
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None: def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
@@ -176,9 +166,6 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request)
if user_id is None: if user_id is None:
return return
if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
return
runtime_context = config.setdefault("context", {}) runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict): if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id) runtime_context["user_id"] = str(user_id)
+2 -2
View File
@@ -29,7 +29,7 @@ All other test plan sections were executed against either:
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container | | TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` | | TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` |
| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | | TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Docker deploy uses Gateway embedded runtime | `./scripts/deploy.sh` produces a Gateway + frontend + nginx topology (no `langgraph` container); same auth flow as local `make dev` | needs `docker compose up` | | TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
## Coverage already provided by non-Docker tests ## Coverage already provided by non-Docker tests
@@ -43,7 +43,7 @@ the test cases that ran on sg_dev or local:
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs | | TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies | | TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies |
| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | | TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (Gateway embedded runtime container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (Gateway auth flow on sg_dev) — same Gateway code, container is just a packaging change | | TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
## Reproduction steps when Docker becomes available ## Reproduction steps when Docker becomes available
+2 -2
View File
@@ -124,8 +124,8 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
## 兼容性 ## 兼容性
- **本地开发**`make dev`):Gateway embedded runtime 完全兼容;无 admin 时访问 `/setup` 初始化 - **标准模式**`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化
- **Gateway embedded runtime**:标准脚本、Docker dev 和生产部署均通过 Gateway 提供认证与 LangGraph-compatible API - **Gateway 模式**`make dev-pro`):完全兼容
- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载 - **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶 - **IM 渠道**Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响 - **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
+6 -16
View File
@@ -95,30 +95,20 @@ models:
thinking: thinking:
type: enabled type: enabled
- name: minimax-m3 - name: minimax-m2.5
display_name: MiniMax M3 display_name: MiniMax M2.5
use: langchain_openai:ChatOpenAI use: langchain_openai:ChatOpenAI
model: MiniMax-M3 model: MiniMax-M2.5
api_key: $MINIMAX_API_KEY api_key: $MINIMAX_API_KEY
base_url: https://api.minimax.io/v1 base_url: https://api.minimax.io/v1
max_tokens: 4096 max_tokens: 4096
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0] temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
supports_vision: true supports_vision: true
- name: minimax-m2.7 - name: minimax-m2.5-highspeed
display_name: MiniMax M2.7 display_name: MiniMax M2.5 Highspeed
use: langchain_openai:ChatOpenAI use: langchain_openai:ChatOpenAI
model: MiniMax-M2.7 model: MiniMax-M2.5-highspeed
api_key: $MINIMAX_API_KEY
base_url: https://api.minimax.io/v1
max_tokens: 4096
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
supports_vision: true
- name: minimax-m2.7-highspeed
display_name: MiniMax M2.7 Highspeed
use: langchain_openai:ChatOpenAI
model: MiniMax-M2.7-highspeed
api_key: $MINIMAX_API_KEY api_key: $MINIMAX_API_KEY
base_url: https://api.minimax.io/v1 base_url: https://api.minimax.io/v1
max_tokens: 4096 max_tokens: 4096
-1
View File
@@ -19,7 +19,6 @@ This directory contains detailed documentation for the DeerFlow backend.
| [STREAMING.md](STREAMING.md) | Token-level streaming design: Gateway vs DeerFlowClient paths, `stream_mode` semantics, per-id dedup | | [STREAMING.md](STREAMING.md) | Token-level streaming design: Gateway vs DeerFlowClient paths, `stream_mode` semantics, per-id dedup |
| [FILE_UPLOAD.md](FILE_UPLOAD.md) | File upload functionality | | [FILE_UPLOAD.md](FILE_UPLOAD.md) | File upload functionality |
| [PATH_EXAMPLES.md](PATH_EXAMPLES.md) | Path types and usage examples | | [PATH_EXAMPLES.md](PATH_EXAMPLES.md) | Path types and usage examples |
| [SANDBOX_MEMORY_PROFILING.md](SANDBOX_MEMORY_PROFILING.md) | Sandbox memory baseline and runtime comparison guide |
| [summarization.md](summarization.md) | Context summarization feature | | [summarization.md](summarization.md) | Context summarization feature |
| [plan_mode_usage.md](plan_mode_usage.md) | Plan mode with TodoList | | [plan_mode_usage.md](plan_mode_usage.md) | Plan mode with TodoList |
| [AUTO_TITLE_GENERATION.md](AUTO_TITLE_GENERATION.md) | Automatic title generation | | [AUTO_TITLE_GENERATION.md](AUTO_TITLE_GENERATION.md) | Automatic title generation |
-81
View File
@@ -1,81 +0,0 @@
# Sandbox Memory Profiling
This guide records a repeatable baseline before changing the sandbox runtime.
Issue #3213 reports per-sandbox memory near 1 GiB in Kubernetes. Before adding
or recommending a new provider, capture the current AIO sandbox baseline and
compare candidates with the same DeerFlow workload.
## What to Measure
Measure at least these samples:
1. Empty sandbox after it becomes ready.
2. After a simple bash command.
3. After a Python task that imports common packages.
4. After a Node task when Node-based workloads are expected.
5. After generating files under `/mnt/user-data/outputs`.
6. After release and warm reuse.
7. At the target concurrency level, for example 10, 50, or 100 sandboxes.
`kubectl top` reports Kubernetes/container working set memory. Treat it as a
capacity signal, not exclusive RSS/PSS. Pod-level memory includes every
container in the Pod and may include cache charged to the cgroup. If a result
looks surprising, inspect the sandbox processes and cgroup metrics on the node
before drawing conclusions.
## Capture a Snapshot
Run this from the repository root:
```bash
python scripts/sandbox_memory_profile.py \
--namespace deer-flow \
--selector app=deer-flow-sandbox \
--sample empty \
--include-processes \
--format markdown
```
Use a descriptive `--sample` value for each phase:
```bash
python scripts/sandbox_memory_profile.py --sample after-bash --format json
python scripts/sandbox_memory_profile.py --sample after-python --format json
python scripts/sandbox_memory_profile.py --sample after-artifact --format json
```
`--include-processes` runs `kubectl exec ... ps` in each sandbox Pod and adds
the highest-RSS processes to the report. This helps distinguish Pod-level cgroup
memory from process RSS. The two numbers will not match exactly because cgroup
memory can include cache and other kernel-accounted memory.
Save the raw JSON when comparing backends so totals, pod names, images,
requests, limits, and timestamps can be audited later.
## Candidate Runtime Matrix
For AIO, CubeSandbox, OpenSandbox, gVisor, Kata, or another candidate, compare
the same workload and record:
| Area | Required Evidence |
| --- | --- |
| Capacity | Pod or instance count, total memory, average memory, max memory |
| Startup | Ready latency at 1, 10, 50, and 100 concurrent sandboxes |
| Commands | Bash output, timeout behavior, failure shape |
| Files | `read_file`, `write_file`, binary `update_file`, `list_dir`, `glob`, `grep` |
| Uploads | Files uploaded by the gateway are visible inside the sandbox |
| Artifacts | Files written to `/mnt/user-data/outputs` are readable by the backend artifact API |
| Paths | `/mnt/user-data/workspace`, `/mnt/user-data/uploads`, `/mnt/user-data/outputs`, `/mnt/acp-workspace`, and skills paths keep their expected semantics |
| Isolation | Different users and threads cannot read each other's data |
| Cleanup | Release, idle timeout, process restart, and orphan cleanup free resources |
| Operations | Deployment prerequisites, privileged components, networking, storage, and upgrade path |
## PR Guidance
Do not claim that a new provider fixes high-concurrency memory usage until the
same DeerFlow workload has been measured on both the current AIO sandbox and the
candidate backend.
For an experimental provider PR, prefer `Related to #3213` unless the PR also
includes reproducible DeerFlow workload data that demonstrates the target memory
reduction and preserves uploads, outputs, artifacts, and isolation behavior.
@@ -18,10 +18,7 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
``create_chat_model`` call must add to this list and pass the flag. ``create_chat_model`` call must add to this list and pass the flag.
""" """
from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
@@ -48,11 +45,6 @@ from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.tracing import build_tracing_callbacks from deerflow.tracing import build_tracing_callbacks
if TYPE_CHECKING:
from langchain.tools import BaseTool
from deerflow.tools.builtins.tool_search import DeferredToolSetup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -278,7 +270,6 @@ def _build_middlewares(
custom_middlewares: list[AgentMiddleware] | None = None, custom_middlewares: list[AgentMiddleware] | None = None,
*, *,
app_config: AppConfig | None = None, app_config: AppConfig | None = None,
deferred_setup=None,
): ):
"""Build middleware chain based on runtime configuration. """Build middleware chain based on runtime configuration.
@@ -327,13 +318,11 @@ def _build_middlewares(
if model_config is not None and model_config.supports_vision: if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware()) middlewares.append(ViewImageMiddleware())
# Hide deferred tool schemas from model binding until tool_search promotes them. # Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
# The deferred set + catalog hash come from the build-time setup (assembled if resolved_app_config.tool_search.enabled:
# after tool-policy filtering); promotion is read from graph state.
if deferred_setup is not None and deferred_setup.deferred_names:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware(deferred_setup.deferred_names, deferred_setup.catalog_hash)) middlewares.append(DeferredToolFilterMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls # Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = cfg.get("subagent_enabled", False) subagent_enabled = cfg.get("subagent_enabled", False)
@@ -364,26 +353,6 @@ def _build_middlewares(
return middlewares return middlewares
def _assemble_deferred(filtered_tools: list[BaseTool], *, enabled: bool) -> tuple[list[BaseTool], DeferredToolSetup]:
"""Build the final tool list + deferred setup from a policy-filtered list.
Call AFTER tool-policy filtering so the deferred catalog never exposes a
tool the agent is not allowed to use. Fail-closed: if tool_search is enabled
and MCP tools survived filtering but no deferred set was recovered, raise
rather than silently binding their full schemas to the model.
"""
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.mcp_metadata import is_mcp_tool
deferred_setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
if enabled and not deferred_setup.deferred_names and any(is_mcp_tool(t) for t in filtered_tools):
raise RuntimeError("tool_search enabled and MCP tools survived policy filtering, but no deferred set was recovered — refusing to bind MCP schemas (fail-closed).")
final_tools = list(filtered_tools)
if deferred_setup.tool_search_tool:
final_tools.append(deferred_setup.tool_search_tool)
return final_tools, deferred_setup
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None: def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
if is_bootstrap: if is_bootstrap:
return {"bootstrap"} return {"bootstrap"}
@@ -491,19 +460,16 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
if is_bootstrap: if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow # Special bootstrap agent with minimal prompt for initial custom agent creation flow
raw_tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent] tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
filtered = filter_tools_by_skill_allowed_tools(raw_tools, skills_for_tool_policy)
final_tools, setup = _assemble_deferred(filtered, enabled=resolved_app_config.tool_search.enabled)
return create_agent( return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False), model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False),
tools=final_tools, tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config, deferred_setup=setup), middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template( system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents, max_concurrent_subagents=max_concurrent_subagents,
available_skills=set(["bootstrap"]), available_skills=set(["bootstrap"]),
app_config=resolved_app_config, app_config=resolved_app_config,
deferred_names=setup.deferred_names,
), ),
state_schema=ThreadState, state_schema=ThreadState,
) )
@@ -512,20 +478,17 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# The default agent (no agent_name) does not see this tool. # The default agent (no agent_name) does not see this tool.
extra_tools = [update_agent] if agent_name else [] extra_tools = [update_agent] if agent_name else []
# Default lead agent (unchanged behavior) # Default lead agent (unchanged behavior)
raw_tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config) tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
filtered = filter_tools_by_skill_allowed_tools(raw_tools + extra_tools, skills_for_tool_policy)
final_tools, setup = _assemble_deferred(filtered, enabled=resolved_app_config.tool_search.enabled)
return create_agent( return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False), model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False),
tools=final_tools, tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config, deferred_setup=setup), middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template( system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents, max_concurrent_subagents=max_concurrent_subagents,
agent_name=agent_name, agent_name=agent_name,
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
app_config=resolved_app_config, app_config=resolved_app_config,
deferred_names=setup.deferred_names,
), ),
state_schema=ThreadState, state_schema=ThreadState,
) )
@@ -542,14 +542,6 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks. {subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
- Progressive Loading: Load resources incrementally as referenced in skills - Progressive Loading: Load resources incrementally as referenced in skills
- Output Files: Final deliverables must be in `/mnt/user-data/outputs` - Output Files: Final deliverables must be in `/mnt/user-data/outputs`
- File Editing Workflow: When revising an existing file, prefer
`str_replace` over `write_file` — it sends only the diff and avoids
re-emitting the whole file (mirrors Claude Code's Edit and Codex's
apply_patch). When writing long new content from scratch, split it
into sections: the first `write_file` call creates the file, then use
`write_file` with append=True to extend it section by section. This
keeps each tool call small and avoids mid-stream chunk-gap timeouts
on oversized single-shot writes. (See issue #3189.)
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary - Clarity: Be direct and helpful, avoid unnecessary meta-commentary
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `![Image Description](image_path)\n\n` or "```mermaid" to display images in response or Markdown files - Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `![Image Description](image_path)\n\n` or "```mermaid" to display images in response or Markdown files
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance - Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
@@ -686,23 +678,39 @@ SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace a
Rules: Rules:
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits. - Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
- Only pass the fields that should change. Omit the others to preserve them. - Only pass the fields that should change. Omit the others to preserve them.
- Never pass literal strings like `"null"`, `"none"`, or `"undefined"` for unchanged fields.
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist. - Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn. - After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
</self_update> </self_update>
""" """
def get_deferred_tools_prompt_section(*, deferred_names: frozenset[str] = frozenset()) -> str: def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> from an explicit deferred-name set. """Generate <available-deferred-tools> block for the system prompt.
Lists only names so the agent knows what exists and can use tool_search to Lists only deferred tool names so the agent knows what exists
load them. Returns empty string when there are no deferred tools. The set is and can use tool_search to load them.
computed at agent build time (after tool-policy filtering) and passed in. Returns empty string when tool_search is disabled or no tools are deferred.
""" """
if not deferred_names: from deerflow.tools.builtins.tool_search import get_deferred_registry
if app_config is None:
try:
from deerflow.config import get_app_config
config = get_app_config()
except Exception:
return ""
else:
config = app_config
if not config.tool_search.enabled:
return "" return ""
names = "\n".join(sorted(deferred_names))
registry = get_deferred_registry()
if not registry:
return ""
names = "\n".join(e.name for e in registry.entries)
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>" return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
@@ -764,7 +772,6 @@ def apply_prompt_template(
agent_name: str | None = None, agent_name: str | None = None,
available_skills: set[str] | None = None, available_skills: set[str] | None = None,
app_config: AppConfig | None = None, app_config: AppConfig | None = None,
deferred_names: frozenset[str] = frozenset(),
) -> str: ) -> str:
# Include subagent section only if enabled (from runtime parameter) # Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents n = max_concurrent_subagents
@@ -792,7 +799,7 @@ def apply_prompt_template(
skills_section = get_skills_prompt_section(available_skills, app_config=app_config) skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
# Get deferred tools section (tool_search) # Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section(deferred_names=deferred_names) deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
# Build ACP agent section only if ACP agents are configured # Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section(app_config=app_config) acp_section = _build_acp_section(app_config=app_config)
@@ -1,15 +1,12 @@
"""Middleware to filter deferred tool schemas from model binding. """Middleware to filter deferred tool schemas from model binding.
When tool_search is enabled, MCP tools are still passed to ToolNode for When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
execution, but their schemas must NOT be sent to the LLM via bind_tools until and passed to ToolNode for execution, but their schemas should NOT be sent to the
the model has discovered them via tool_search. This middleware removes the LLM via bind_tools (that's the whole point of deferral — saving context tokens).
still-deferred tools from request.tools before model binding, and blocks tool
calls to tools that have not been promoted yet.
The deferred name set and the catalog hash are injected at construction time This middleware intercepts wrap_model_call and removes deferred tools from
(no ContextVar). Promotion state is read from graph state (``state["promoted"]``), request.tools so that model.bind_tools only receives active tool schemas.
scoped by catalog hash so a stale persisted promotion cannot expose a renamed The agent discovers deferred tools at runtime via the tool_search tool.
or drifted tool.
""" """
import logging import logging
@@ -27,49 +24,47 @@ logger = logging.getLogger(__name__)
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]): class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
"""Hide deferred tool schemas from the bound model until promoted. """Remove deferred tools from request.tools before model binding.
ToolNode still holds all tools (including deferred) for execution routing, ToolNode still holds all tools (including deferred) for execution routing,
but the LLM only sees active tool schemas plus tools that have already been but the LLM only sees active tool schemas — deferred tools are discoverable
promoted (recorded in ``state["promoted"]`` under the current catalog hash). via tool_search at runtime.
""" """
def __init__(self, deferred_names: frozenset[str], catalog_hash: str | None):
super().__init__()
self._deferred = deferred_names
self._catalog_hash = catalog_hash
def _promoted(self, state) -> set[str]:
promoted = (state or {}).get("promoted")
if promoted and promoted.get("catalog_hash") == self._catalog_hash:
return set(promoted.get("names") or [])
return set()
def _hidden(self, state) -> set[str]:
return set(self._deferred) - self._promoted(state)
def _filter_tools(self, request: ModelRequest) -> ModelRequest: def _filter_tools(self, request: ModelRequest) -> ModelRequest:
if not self._deferred: from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return request return request
hide = self._hidden(request.state)
if not hide: deferred_names = registry.deferred_names
return request active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
active = [t for t in request.tools if getattr(t, "name", None) not in hide]
if len(active) < len(request.tools): if len(active_tools) < len(request.tools):
logger.debug("Filtered %d deferred tool schema(s) from model binding", len(request.tools) - len(active)) logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
return request.override(tools=active)
return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None: def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
if not self._deferred: from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None return None
name = str(request.tool_call.get("name") or "")
if not name or name not in self._hidden(request.state): tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id") tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage( return ToolMessage(
content=(f"Error: Tool '{name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."), content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
name=name, name=tool_name,
status="error", status="error",
) )
@@ -62,41 +62,6 @@ _AUTH_PATTERNS = (
"未授权", "未授权",
) )
# Per-exception retry budget overrides.
#
# Some transient errors are retriable in principle but expensive to retry at
# the default budget. StreamChunkTimeoutError in particular fires after the
# upstream provider has already stalled for `stream_chunk_timeout` seconds
# (typically 120-240s); a full 3-attempt loop can therefore stack 6-12 minutes
# of dead air before surfacing the failure to the user. We keep exactly one
# retry (cheap reconnect that catches genuine transient TCP blips) and then
# fail fast — the same buffered payload is overwhelmingly likely to fail
# again at the upstream provider for the same reason.
#
# Keys are exception class *names* (not classes) so we don't introduce
# import-time coupling on optional dependencies like langchain-openai. The
# value is the absolute max attempt count, NOT additional retries — so a
# value of 2 means "1 first attempt + 1 retry" (the CR-requested
# "keep one retry" behavior).
_RETRY_BUDGET_OVERRIDES: dict[str, int] = {
"StreamChunkTimeoutError": 2,
}
# Exception class names that indicate the upstream stream-chunk watchdog
# fired because the model stalled mid-flight. These deserve a more specific
# user-facing message than the generic "temporarily unavailable" copy,
# because the typical root cause is a long tool-call serialization stalling
# the upstream stream — and the most actionable advice we can give the user
# is "ask for a shorter / split output" rather than "wait and retry".
# Generic connection drops (httpx RemoteProtocolError / ReadError) are
# intentionally excluded: they routinely fire on transient network blips
# with normal payloads, where the "split the work" guidance is misleading.
_STREAM_DROP_EXCEPTIONS: frozenset[str] = frozenset(
{
"StreamChunkTimeoutError",
}
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages.""" """Retry transient LLM errors and surface graceful assistant messages."""
@@ -118,18 +83,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
self._circuit_state = "closed" self._circuit_state = "closed"
self._circuit_probe_in_flight = False self._circuit_probe_in_flight = False
def _max_attempts_for(self, exc: BaseException) -> int:
"""Return the effective max attempt count for this exception.
Falls back to `self.retry_max_attempts` unless the exception class name
appears in the per-exception override table.
"""
override = _RETRY_BUDGET_OVERRIDES.get(type(exc).__name__)
if override is None:
return self.retry_max_attempts
return min(override, self.retry_max_attempts)
def _check_circuit(self) -> bool: def _check_circuit(self) -> bool:
"""Returns True if circuit is OPEN (fast fail), False otherwise.""" """Returns True if circuit is OPEN (fast fail), False otherwise."""
with self._circuit_lock: with self._circuit_lock:
@@ -200,7 +153,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"InternalServerError", "InternalServerError",
"ReadError", # httpx.ReadError: connection dropped mid-stream "ReadError", # httpx.ReadError: connection dropped mid-stream
"RemoteProtocolError", # httpx: server closed connection unexpectedly "RemoteProtocolError", # httpx: server closed connection unexpectedly
"StreamChunkTimeoutError", # langchain-openai: chunk gap exceeded stream_chunk_timeout
}: }:
return True, "transient" return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES: if status_code in _RETRIABLE_STATUS_CODES:
@@ -225,24 +177,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_circuit_breaker_message(self) -> str: def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again." return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_error_fallback_message(
self,
content: str,
*,
error_type: str,
reason: str,
detail: str,
) -> AIMessage:
return AIMessage(
content=content,
additional_kwargs={
"deerflow_error_fallback": True,
"error_type": error_type,
"error_reason": reason,
"error_detail": detail,
},
)
def _build_user_message(self, exc: BaseException, reason: str) -> str: def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc) detail = _extract_error_detail(exc)
if reason == "quota": if reason == "quota":
@@ -250,31 +184,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
if reason == "auth": if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again." return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}: if reason in {"busy", "transient"}:
# Stream-drop failures (chunk-gap timeout, peer-closed connection,
# raw read error) almost always point at a single oversized
# tool-call payload — the model spent so long serializing JSON
# arguments that the upstream provider buffered and the stream
# gap exceeded `stream_chunk_timeout`. Surfacing this distinct
# cause lets the user split or shorten their next request
# instead of helplessly retrying the same prompt.
if type(exc).__name__ in _STREAM_DROP_EXCEPTIONS:
return (
"The model's streaming response was interrupted before it could "
"finish. This usually happens when a single response or tool call "
"is very large — please ask the assistant to split the work into "
"smaller steps, or shorten the requested output, and try again."
)
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation." return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}" return f"LLM request failed: {detail}"
def _build_user_fallback_message(self, exc: BaseException, reason: str) -> AIMessage:
return self._build_error_fallback_message(
self._build_user_message(exc, reason),
error_type=type(exc).__name__,
reason=reason,
detail=_extract_error_detail(exc),
)
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None: def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try: try:
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
@@ -300,12 +212,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], ModelResponse], handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult: ) -> ModelCallResult:
if self._check_circuit(): if self._check_circuit():
return self._build_error_fallback_message( return AIMessage(content=self._build_circuit_breaker_message())
self._build_circuit_breaker_message(),
error_type="CircuitBreakerOpen",
reason="circuit_open",
detail="LLM circuit breaker is open",
)
attempt = 1 attempt = 1
while True: while True:
@@ -321,8 +228,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
raise raise
except Exception as exc: except Exception as exc:
retriable, reason = self._classify_error(exc) retriable, reason = self._classify_error(exc)
max_attempts = self._max_attempts_for(exc) if retriable and attempt < self.retry_max_attempts:
if retriable and attempt < max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc) wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning( logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s", "Transient LLM error on attempt %d/%d; retrying in %dms: %s",
@@ -343,7 +249,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
) )
if retriable: if retriable:
self._record_failure() self._record_failure()
return self._build_user_fallback_message(exc, reason) return AIMessage(content=self._build_user_message(exc, reason))
@override @override
async def awrap_model_call( async def awrap_model_call(
@@ -352,12 +258,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult: ) -> ModelCallResult:
if self._check_circuit(): if self._check_circuit():
return self._build_error_fallback_message( return AIMessage(content=self._build_circuit_breaker_message())
self._build_circuit_breaker_message(),
error_type="CircuitBreakerOpen",
reason="circuit_open",
detail="LLM circuit breaker is open",
)
attempt = 1 attempt = 1
while True: while True:
@@ -373,8 +274,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
raise raise
except Exception as exc: except Exception as exc:
retriable, reason = self._classify_error(exc) retriable, reason = self._classify_error(exc)
max_attempts = self._max_attempts_for(exc) if retriable and attempt < self.retry_max_attempts:
if retriable and attempt < max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc) wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning( logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s", "Transient LLM error on attempt %d/%d; retrying in %dms: %s",
@@ -395,7 +295,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
) )
if retriable: if retriable:
self._record_failure() self._record_failure()
return self._build_user_fallback_message(exc, reason) return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool: def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
@@ -9,9 +9,8 @@ from typing import Any, Protocol, override, runtime_checkable
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage, get_buffer_string from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langgraph.config import get_config from langgraph.config import get_config
from langgraph.constants import TAG_NOSTREAM
from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
@@ -117,74 +116,6 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count) self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens) self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill) self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
# The summary LLM call runs inside a LangGraph middleware hook, so its token
# stream would otherwise be captured by the messages-tuple stream callback and
# broadcast to the frontend as a phantom AI message. Tag a dedicated model copy
# with TAG_NOSTREAM so the streaming handler skips it.
# Keep self.model untagged so the parent's profile / ls_params inspection still works.
#
# Preserve any tags already bound on the model (e.g. "middleware:summarize" set in
# lead_agent/agent.py for RunJournal attribution): RunnableBinding.with_config does a
# shallow merge that would otherwise overwrite the existing tags list entirely.
existing_tags = list((getattr(self.model, "config", None) or {}).get("tags") or [])
merged_tags = [*existing_tags, TAG_NOSTREAM] if TAG_NOSTREAM not in existing_tags else existing_tags
self._summary_model = self.model.with_config(tags=merged_tags)
@override
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
return self._summarize_with(messages_to_summarize)
@override
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
return await self._asummarize_with(messages_to_summarize)
def _summarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Mirror the parent ``_create_summary`` but invoke the nostream-tagged model.
We do not swap ``self.model`` at the instance level: the agent/middleware is
cached and reused across concurrent runs, so a temporary swap would leak the
``RunnableBinding`` to other coroutines during ``await`` and break parent logic
that inspects the raw model (``profile`` / ``_get_ls_params``).
"""
if not messages_to_summarize:
return "No previous conversation history."
prompt = self._build_summary_prompt(messages_to_summarize)
if prompt is None:
return "Previous conversation was too long to summarize."
try:
response = self._summary_model.invoke(
prompt,
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
async def _asummarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Async counterpart of :meth:`_summarize_with` using the nostream model."""
if not messages_to_summarize:
return "No previous conversation history."
prompt = self._build_summary_prompt(messages_to_summarize)
if prompt is None:
return "Previous conversation was too long to summarize."
try:
response = await self._summary_model.ainvoke(
prompt,
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
def _build_summary_prompt(self, messages_to_summarize: list[AnyMessage]) -> str | None:
"""Build the summary prompt, returning ``None`` when trimming leaves nothing."""
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return None
# Format messages to avoid token inflation from metadata when str() is called on
# message objects.
formatted_messages = get_buffer_string(trimmed_messages)
return self.summary_prompt.format(messages=formatted_messages).rstrip()
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime) return self._maybe_summarize(state, runtime)
@@ -58,32 +58,6 @@ def merge_todos(existing: list | None, new: list | None) -> list | None:
return new return new
class PromotedTools(TypedDict):
catalog_hash: str
names: list[str]
def merge_promoted(existing: PromotedTools | None, new: PromotedTools | None) -> PromotedTools | None:
"""Reducer for deferred-tool promotions, scoped by catalog hash.
- new None/empty -> preserve existing (node didn't touch promotions).
- catalog_hash changed -> replace wholesale, dropping stale names (prevents a
persisted bare name from exposing a different tool after catalog drift).
- same catalog_hash -> union names, dedupe, preserve order.
"""
if not new:
return existing
if existing is None or existing.get("catalog_hash") != new["catalog_hash"]:
return {
"catalog_hash": new["catalog_hash"],
"names": list(dict.fromkeys(new["names"])),
}
return {
"catalog_hash": existing["catalog_hash"],
"names": list(dict.fromkeys(existing["names"] + new["names"])),
}
class ThreadState(AgentState): class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None] sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None] thread_data: NotRequired[ThreadDataState | None]
@@ -92,4 +66,3 @@ class ThreadState(AgentState):
todos: Annotated[list | None, merge_todos] todos: Annotated[list | None, merge_todos]
uploaded_files: NotRequired[list[dict] | None] uploaded_files: NotRequired[list[dict] | None]
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type} viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
promoted: Annotated[PromotedTools | None, merge_promoted]
+4 -7
View File
@@ -33,7 +33,7 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.agent import _assemble_deferred, _build_middlewares from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
@@ -237,22 +237,19 @@ class DeerFlowClient:
subagent_enabled = cfg.get("subagent_enabled", False) subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
tools = self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled)
final_tools, deferred_setup = _assemble_deferred(tools, enabled=self._app_config.tool_search.enabled)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
# attach_tracing=False because ``stream()`` injects tracing # attach_tracing=False because ``stream()`` injects tracing
# callbacks at the graph invocation root so a single embedded run # callbacks at the graph invocation root so a single embedded run
# produces one trace with correct session_id / user_id propagation. # produces one trace with correct session_id / user_id propagation.
# Attaching them again on the model would emit duplicate spans. # Attaching them again on the model would emit duplicate spans.
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False), "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False),
"tools": final_tools, "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares, deferred_setup=deferred_setup), "middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"system_prompt": apply_prompt_template( "system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled, subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents, max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name, agent_name=self._agent_name,
available_skills=self._available_skills, available_skills=self._available_skills,
deferred_names=deferred_setup.deferred_names,
), ),
"state_schema": ThreadState, "state_schema": ThreadState,
} }
@@ -1209,7 +1206,7 @@ class DeerFlowClient:
info: dict[str, Any] = { info: dict[str, Any] = {
"filename": dest_name, "filename": dest_name,
"size": dest.stat().st_size, "size": str(dest.stat().st_size),
"path": str(dest), "path": str(dest),
"virtual_path": upload_virtual_path(dest_name), "virtual_path": upload_virtual_path(dest_name),
"artifact_url": upload_artifact_url(thread_id, dest_name), "artifact_url": upload_artifact_url(thread_id, dest_name),
@@ -39,63 +39,11 @@ class AioSandbox(Sandbox):
self._client = AioSandboxClient(base_url=base_url, timeout=600) self._client = AioSandboxClient(base_url=base_url, timeout=600)
self._home_dir = home_dir self._home_dir = home_dir
self._lock = threading.Lock() self._lock = threading.Lock()
self._closed = False
@property @property
def base_url(self) -> str: def base_url(self) -> str:
return self._base_url return self._base_url
def close(self) -> None:
"""Best-effort close of the host-side HTTP client owned by this sandbox.
The agent_sandbox SDK is Fern-generated and exposes no ``close()`` /
``__exit__``, so we reach the socket-owning ``httpx.Client`` explicitly
through its attribute chain::
Sandbox._client_wrapper -> SyncClientWrapper
.httpx_client -> Fern HttpClient (a wrapper, NOT httpx.Client)
.httpx_client -> httpx.Client <- the real socket owner
Closing it releases pooled sockets so long-running provider lifecycles
do not accumulate unreclaimed host-side resources (#2872).
Resolution is most-specific-first with graceful degradation: if a future
SDK adds a top-level ``Sandbox.close()`` it is picked up automatically
without changing this code. Idempotent, thread-safe, and non-fatal:
failures during teardown are logged and swallowed so provider/backend
cleanup is never blocked.
"""
with self._lock:
if self._closed:
return
self._closed = True
client = self._client
# Drop the reference under the lock for use-after-close safety: any
# later command on this instance fails loudly instead of reusing a
# half-closed client.
self._client = None
if client is None:
return
# Walk from the real httpx.Client up to the top-level client, picking the
# first object that actually exposes close().
wrapper = getattr(client, "_client_wrapper", None)
fern_http = getattr(wrapper, "httpx_client", None)
real_httpx = getattr(fern_http, "httpx_client", None)
target = next(
(c for c in (real_httpx, fern_http, client) if c is not None and hasattr(c, "close")),
None,
)
if target is None:
logger.debug("AioSandbox %s: no closable client found, nothing to release", self.id)
return
try:
target.close()
except Exception as e:
logger.warning(f"Error closing AioSandbox client for {self.id}: {e}")
@property @property
def home_dir(self) -> str: def home_dir(self) -> str:
"""Get the home directory inside the sandbox.""" """Get the home directory inside the sandbox."""
@@ -790,20 +790,14 @@ class AioSandboxProvider(SandboxProvider):
thread on its next turn without a cold-start. The container will only be thread on its next turn without a cold-start. The container will only be
stopped when the replicas limit forces eviction or during shutdown. stopped when the replicas limit forces eviction or during shutdown.
The host-side HTTP client owned by the cached ``AioSandbox`` instance is
closed before the instance is dropped (#2872). The warm-pool entry only
stores ``SandboxInfo``, so a fresh ``AioSandbox`` (and a fresh client)
is constructed if the container is later reclaimed.
Args: Args:
sandbox_id: The ID of the sandbox to release. sandbox_id: The ID of the sandbox to release.
""" """
info = None info = None
sandbox = None
thread_ids_to_remove: list[str] = [] thread_ids_to_remove: list[str] = []
with self._lock: with self._lock:
sandbox = self._sandboxes.pop(sandbox_id, None) self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None) info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id] thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove: for tid in thread_ids_to_remove:
@@ -813,15 +807,6 @@ class AioSandboxProvider(SandboxProvider):
if info and sandbox_id not in self._warm_pool: if info and sandbox_id not in self._warm_pool:
self._warm_pool[sandbox_id] = (info, time.time()) self._warm_pool[sandbox_id] = (info, time.time())
if sandbox is not None:
# Defense-in-depth: close() already swallows its own errors; this
# guard only protects against a future close() that misbehaves, so
# host-side client cleanup can never block parking in the warm pool.
try:
sandbox.close()
except Exception as e:
logger.warning(f"Error closing sandbox {sandbox_id} during release: {e}")
logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)") logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")
def destroy(self, sandbox_id: str) -> None: def destroy(self, sandbox_id: str) -> None:
@@ -830,19 +815,14 @@ class AioSandboxProvider(SandboxProvider):
Unlike release(), this actually stops the container. Use this for Unlike release(), this actually stops the container. Use this for
explicit cleanup, capacity-driven eviction, or shutdown. explicit cleanup, capacity-driven eviction, or shutdown.
The host-side HTTP client owned by the cached ``AioSandbox`` instance is
closed alongside backend/container destruction so no client/socket
resources leak (#2872).
Args: Args:
sandbox_id: The ID of the sandbox to destroy. sandbox_id: The ID of the sandbox to destroy.
""" """
info = None info = None
sandbox = None
thread_ids_to_remove: list[str] = [] thread_ids_to_remove: list[str] = []
with self._lock: with self._lock:
sandbox = self._sandboxes.pop(sandbox_id, None) self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None) info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id] thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove: for tid in thread_ids_to_remove:
@@ -854,15 +834,6 @@ class AioSandboxProvider(SandboxProvider):
else: else:
self._warm_pool.pop(sandbox_id, None) self._warm_pool.pop(sandbox_id, None)
if sandbox is not None:
# Defense-in-depth: close() already swallows its own errors; this
# guard only protects against a future close() that misbehaves, so
# host-side client cleanup can never block container destruction.
try:
sandbox.close()
except Exception as e:
logger.warning(f"Error closing sandbox {sandbox_id} during destroy: {e}")
if info: if info:
self._backend.destroy(info) self._backend.destroy(info)
logger.info(f"Destroyed sandbox {sandbox_id}") logger.info(f"Destroyed sandbox {sandbox_id}")
@@ -18,7 +18,6 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_
from deerflow.config.loop_detection_config import LoopDetectionConfig from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.reload_boundary import format_field_description
from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.runtime_paths import existing_project_file from deerflow.config.runtime_paths import existing_project_file
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
@@ -86,21 +85,10 @@ def apply_logging_level(name: str | None) -> None:
class AppConfig(BaseModel): class AppConfig(BaseModel):
"""Config for the DeerFlow application""" """Config for the DeerFlow application"""
log_level: str = Field( log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected")
default="info",
description=format_field_description(
"log_level",
field_doc="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected.",
),
)
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration") token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
models: list[ModelConfig] = Field(default_factory=list, description="Available models") models: list[ModelConfig] = Field(default_factory=list, description="Available models")
sandbox: SandboxConfig = Field( sandbox: SandboxConfig = Field(description="Sandbox configuration")
description=format_field_description(
"sandbox",
field_doc="Sandbox provider configuration (local filesystem or Docker-based aio sandbox).",
),
)
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools") tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups") tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration") skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
@@ -119,34 +107,10 @@ class AppConfig(BaseModel):
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field( database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
default_factory=DatabaseConfig, run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
description=format_field_description( checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
"database", stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
field_doc="Unified database backend for run/feedback metadata (memory, sqlite, or postgres).",
),
)
run_events: RunEventsConfig = Field(
default_factory=RunEventsConfig,
description=format_field_description(
"run_events",
field_doc="Run-event store backend (memory for dev, db for production queries, jsonl for lightweight single-node persistence).",
),
)
checkpointer: CheckpointerConfig | None = Field(
default=None,
description=format_field_description(
"checkpointer",
field_doc="LangGraph state-persistence checkpointer configuration.",
),
)
stream_bridge: StreamBridgeConfig | None = Field(
default=None,
description=format_field_description(
"stream_bridge",
field_doc="Stream bridge connecting agent workers to SSE endpoints.",
),
)
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path: def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -5,7 +5,7 @@ import os
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.runtime_paths import existing_project_file from deerflow.config.runtime_paths import existing_project_file
@@ -47,24 +47,6 @@ class McpServerConfig(BaseModel):
description: str = Field(default="", description="Human-readable description of what this MCP server provides") description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@model_validator(mode="before")
@classmethod
def _accept_transport_alias(cls, data: Any) -> Any:
"""Accept the MCP-spec ``transport`` field as an alias for ``type``.
The official MCP configuration schema uses ``transport`` to indicate
the transport mechanism (``stdio``/``sse``/``http``). Earlier versions
of this project only honored ``type``, which caused remote SSE/HTTP
servers configured with just ``transport`` to be incorrectly treated as
``stdio`` (the default). This validator normalizes the two so either
spelling works, with ``type`` taking precedence when both are provided.
"""
if isinstance(data, dict):
transport = data.get("transport")
if transport and not data.get("type"):
data = {**data, "type": transport}
return data
class SkillStateConfig(BaseModel): class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state.""" """Configuration for a single skill's state."""
@@ -32,16 +32,6 @@ class ModelConfig(BaseModel):
description="Extra settings to be passed to the model when thinking is disabled", description="Extra settings to be passed to the model when thinking is disabled",
) )
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs") supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
stream_chunk_timeout: float | None = Field(
default=None,
description=(
"Maximum seconds to wait between successive streaming chunks before "
"langchain-openai raises StreamChunkTimeoutError. None means use the "
"factory default (240s for OpenAI-compatible clients). Tune higher for "
"reasoning models with long thinking pauses; lower for latency-sensitive "
"interactive endpoints. Has no effect on non-OpenAI-compatible providers."
),
)
thinking: dict | None = Field( thinking: dict | None = Field(
default_factory=lambda: None, default_factory=lambda: None,
description=( description=(
@@ -1,4 +1,3 @@
import hashlib
import os import os
import re import re
import shutil import shutil
@@ -11,8 +10,6 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
def _default_local_base_dir() -> Path: def _default_local_base_dir() -> Path:
@@ -34,23 +31,6 @@ def _validate_user_id(user_id: str) -> str:
return user_id return user_id
def make_safe_user_id(raw: str) -> str:
"""Normalize an external identity into the user-id charset (``[A-Za-z0-9_-]``).
IM channel ids (Feishu/Slack/Telegram) may contain characters that
:func:`_validate_user_id` rejects. Already-safe ids pass through unchanged;
lossy ones get a short digest suffix so two distinct inputs never share a
storage bucket.
"""
if not raw:
raise ValueError("user_id must be a non-empty string.")
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
if sanitized == raw:
return raw
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
return f"{sanitized}-{digest}"
def _join_host_path(base: str, *parts: str) -> str: def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style. """Join host filesystem path segments while preserving native style.
@@ -1,104 +0,0 @@
"""Single source of truth for the config hot-reload boundary.
Bytedance/deer-flow issue #3144: gateway request dependencies resolve
``AppConfig`` through ``get_app_config()`` on every request, so per-run
fields take effect on the next message without restarting the gateway.
The fields listed in this module are the **infrastructure** subset that
the gateway captures once at startup — engines, singletons, IM clients,
the logging handler — and that therefore require a process restart to
change at runtime.
The registry covers two kinds of entries:
- Top-level ``AppConfig`` fields (``database``, ``checkpointer``,
``run_events``, ``stream_bridge``, ``sandbox``, ``log_level``). For
these, :func:`format_field_description` produces the standardised
``"startup-only: ..."`` prefix that the matching Pydantic
``Field(description=...)`` carries, so the boundary surfaces in IDE
hover next to the field itself.
- Top-level ``config.yaml`` sections that are not part of the
``AppConfig`` schema (``channels``). These cannot be standardised at
the schema level, so the registry is their only canonical location.
Any future "needs restart" scanner — operator tooling, lint hooks, doc
generators — should drive off this registry rather than re-parsing
prose.
"""
from __future__ import annotations
from collections.abc import Iterator
#: The standardised prefix every restart-required field description starts
#: with. ``test_reload_boundary`` enforces both directions: registered
#: fields must use this prefix in the schema, and any schema field using
#: this prefix must be in the registry.
STARTUP_ONLY_PREFIX = "startup-only:"
#: Restart-required field paths mapped to the human-readable reason.
#:
#: The reason text is what surfaces in ``Field(description=...)``, so it
#: must explain *what* code captures the snapshot — not just that the
#: field is restart-required — so an operator changing the value knows
#: which subsystem to restart.
STARTUP_ONLY_FIELDS: dict[str, str] = {
"database": ("init_engine_from_config() runs once during langgraph_runtime() startup; the SQLAlchemy engine holds the connection pool and is not rebuilt on config.yaml edits."),
"checkpointer": ("make_checkpointer() binds the persistent checkpointer once at startup, including SQLite WAL / busy_timeout settings."),
"run_events": ("make_run_event_store() picks the memory- vs SQL-backed implementation at startup and is frozen onto app.state.run_events_config to stay paired with the underlying event store."),
"stream_bridge": ("make_stream_bridge() constructs the stream-bridge singleton once during startup."),
"sandbox": ("get_sandbox_provider() caches the provider singleton (``_default_sandbox_provider``); a different ``sandbox.use`` class path only takes effect on next process start."),
"log_level": (
"apply_logging_level() runs only during app.py startup; it sets the deerflow/app logger levels and may lower root handler thresholds so configured messages can propagate. A freshly reloaded AppConfig does not retrigger it."
),
# Not part of the AppConfig Pydantic schema — channel credentials are
# consumed directly by ``start_channel_service()`` once at lifespan
# startup and the live channel clients are not rebuilt on
# config.yaml edits.
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
}
def iter_startup_only_field_paths() -> Iterator[str]:
"""Yield every registered restart-required field path."""
return iter(STARTUP_ONLY_FIELDS)
def is_startup_only_field(field_path: str) -> bool:
"""Return ``True`` when *field_path* is registered as restart-required.
Accepts only top-level paths (``"database"``, ``"sandbox"`` etc.);
nested keys like ``"database.url"`` are not modelled here because the
boundary is per-section, not per-leaf.
"""
return field_path in STARTUP_ONLY_FIELDS
def format_field_description(field_path: str, *, field_doc: str | None = None) -> str:
"""Build the standardised description for a registered field.
Used inside ``AppConfig`` ``Field(description=...)`` so the hover
text in IDEs matches the registry and the drift tests can pin one
side against the other.
Args:
field_path: A registered top-level field path (e.g. ``"log_level"``).
field_doc: Optional human-facing description for the field itself
(allowed values, semantics, etc.). When supplied, it is
appended after the ``startup-only:`` marker block separated by
a blank line so IDE hover shows both the restart-required
reason *and* the field's normal documentation. Composition
keeps the marker as the leading token machine-readable tooling
pivots on while restoring the prose that ``Field(description=)``
used to carry before the registry took over.
Raises:
KeyError: when *field_path* is not registered. This is deliberate
— silently returning a placeholder would let a typo bypass
the drift coverage.
"""
reason = STARTUP_ONLY_FIELDS[field_path]
header = f"{STARTUP_ONLY_PREFIX} {reason}"
if field_doc is None:
return header
return f"{header}\n\n{field_doc.strip()}"
@@ -1,10 +1,6 @@
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters.""" """MCP (Model Context Protocol) integration using langchain-mcp-adapters."""
from .cache import ( from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
get_cached_mcp_tools,
initialize_mcp_tools,
reset_mcp_tools_cache,
)
from .client import build_server_params, build_servers_config from .client import build_server_params, build_servers_config
from .tools import get_mcp_tools from .tools import get_mcp_tools
+2 -11
View File
@@ -143,20 +143,11 @@ def reset_mcp_tools_cache() -> None:
# Close persistent sessions they will be recreated by the next # Close persistent sessions they will be recreated by the next
# get_mcp_tools() call with the (possibly updated) connection config. # get_mcp_tools() call with the (possibly updated) connection config.
#
# close_all_sync() already picks the correct strategy per owning loop:
# * sessions owned by the *current* running loop are only *signalled*
# (their owner task runs __aexit__ once the loop regains control
# this is correct and leak-free, since the loop keeps the task alive),
# * sessions on other threads' loops are torn down deterministically,
# * idle/closed loops are handled or skipped.
# We deliberately do NOT try to synchronously wait for the current running
# loop to finish teardown here: that is a self-deadlock (the loop can only
# run the teardown after this synchronous call returns control to it).
try: try:
from deerflow.mcp.session_pool import get_session_pool from deerflow.mcp.session_pool import get_session_pool
get_session_pool().close_all_sync() pool = get_session_pool()
pool.close_all_sync()
except Exception: except Exception:
logger.debug("Could not close MCP session pool on cache reset", exc_info=True) logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
@@ -8,27 +8,6 @@ This module provides a session pool that maintains persistent MCP sessions,
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id — scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
so that consecutive tool calls share the same session and server-side state. so that consecutive tool calls share the same session and server-side state.
Sessions are evicted in LRU order when the pool reaches capacity. Sessions are evicted in LRU order when the pool reaches capacity.
Lifecycle model (owner task)
----------------------------
An MCP ``ClientSession`` is implemented on top of an ``anyio`` task group, and
anyio enforces that a cancel scope must be exited from the *same task* that
entered it. Calling ``cm.__aexit__`` from any task other than the one that ran
``cm.__aenter__`` raises::
RuntimeError: Attempted to exit cancel scope in a different task than it
was entered in
The sync-tool path (``make_sync_tool_wrapper``) drives each call through a fresh
``asyncio.run`` event loop, so a session entered while answering one call would
otherwise be exited while answering another — from a different task — and crash
(GitHub issue #3379).
To make this impossible, every pooled session is owned by a dedicated
``_run_session`` task. That task enters the context manager, hands the live
session back to the caller, and then *waits* on a close event. All shutdown
paths only ever **signal** that event; the owner task performs ``__aexit__``
itself, guaranteeing enter and exit always happen in the same task.
""" """
from __future__ import annotations from __future__ import annotations
@@ -48,81 +27,18 @@ class MCPSessionPool:
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.""" """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
MAX_SESSIONS = 256 MAX_SESSIONS = 256
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session on a foreign loop SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
def __init__(self) -> None: def __init__(self) -> None:
# Each entry: (session, owning_loop, owner_task, close_event).
self._entries: OrderedDict[ self._entries: OrderedDict[
tuple[str, str], tuple[str, str],
tuple[ tuple[ClientSession, asyncio.AbstractEventLoop],
ClientSession,
asyncio.AbstractEventLoop,
asyncio.Task[Any],
asyncio.Event,
],
] = OrderedDict() ] = OrderedDict()
# In-flight creations, keyed by (server, scope). Lets concurrent callers self._context_managers: dict[tuple[str, str], Any] = {}
# on the same loop share a single creation instead of each spawning a
# duplicate session. Value: (loop, ready_future, owner_task, close_event).
self._inflight: dict[
tuple[str, str],
tuple[
asyncio.AbstractEventLoop,
asyncio.Future[ClientSession],
asyncio.Task[Any],
asyncio.Event,
],
] = {}
# threading.Lock is not bound to any event loop, so it is safe to # threading.Lock is not bound to any event loop, so it is safe to
# acquire from both async paths and sync/worker-thread paths. # acquire from both async paths and sync/worker-thread paths.
self._lock = threading.Lock() self._lock = threading.Lock()
# ------------------------------------------------------------------
# Session owner task
# ------------------------------------------------------------------
async def _run_session(
self,
connection: dict[str, Any],
ready: asyncio.Future[ClientSession],
close_evt: asyncio.Event,
) -> None:
"""Own a single MCP session for its entire lifetime.
Enters the session context manager, initializes it, publishes the live
session via ``ready``, then blocks until ``close_evt`` is set. The
context manager is *always* exited from this task, satisfying anyio's
cancel-scope same-task requirement.
"""
from langchain_mcp_adapters.sessions import create_session
cm = create_session(connection)
try:
session = await cm.__aenter__()
except BaseException as e:
# Never entered the cancel scope, so there is nothing to exit.
if not ready.done():
ready.set_exception(e)
return
# The context manager is now entered. From here on __aexit__ MUST run in
# this task — on init failure, on cancellation, or on the close signal —
# to satisfy anyio's same-task cancel-scope requirement and to avoid
# leaking the session/subprocess.
try:
await session.initialize()
if not ready.done():
ready.set_result(session)
await close_evt.wait()
except BaseException as e:
if not ready.done():
ready.set_exception(e)
finally:
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session", exc_info=True)
async def get_session( async def get_session(
self, self,
server_name: str, server_name: str,
@@ -131,9 +47,9 @@ class MCPSessionPool:
) -> ClientSession: ) -> ClientSession:
"""Get or create a persistent MCP session. """Get or create a persistent MCP session.
If an existing session was created in a different (or closed) event If an existing session was created in a different event loop (e.g.
loop, it is evicted and replaced with a fresh one owned by a task on the sync-wrapper path), it is closed and replaced with a fresh one
the current loop. in the current loop.
Args: Args:
server_name: MCP server name. server_name: MCP server name.
@@ -147,118 +63,44 @@ class MCPSessionPool:
current_loop = asyncio.get_running_loop() current_loop = asyncio.get_running_loop()
# Phase 1: inspect/mutate the registry under the thread lock (no awaits). # Phase 1: inspect/mutate the registry under the thread lock (no awaits).
# Decide one of three outcomes atomically: return an existing session, cms_to_close: list[tuple[tuple[str, str], Any]] = []
# join an in-flight creation, or become the creator for this key.
# Each item: (loop, owner_task, close_event, cancel). ``cancel`` is True
# for in-flight creations, whose owner may be blocked inside
# ``initialize()`` where close_evt cannot wake it — it must be cancelled.
evicted: list[tuple[asyncio.AbstractEventLoop, asyncio.Task[Any], asyncio.Event, bool]] = []
join: asyncio.Future[ClientSession] | None = None
ready: asyncio.Future[ClientSession] | None = None
close_evt: asyncio.Event | None = None
task: asyncio.Task[Any] | None = None
with self._lock: with self._lock:
if key in self._entries: if key in self._entries:
session, loop, ent_task, ent_close = self._entries[key] session, loop = self._entries[key]
if loop is current_loop and not loop.is_closed(): if loop is current_loop:
self._entries.move_to_end(key) self._entries.move_to_end(key)
return session return session
# Session belongs to a different/closed event loop evict it. # Session belongs to a different event loop evict it.
cm = self._context_managers.pop(key, None)
self._entries.pop(key) self._entries.pop(key)
evicted.append((loop, ent_task, ent_close, False)) if cm is not None:
cms_to_close.append((key, cm))
inflight = self._inflight.get(key)
if inflight is not None and inflight[0] is current_loop and not inflight[0].is_closed():
# Another caller on this loop is already creating the session;
# wait for the same result instead of building a duplicate.
join = inflight[1]
else:
if inflight is not None:
# Stale in-flight creation owned by a different/closed loop.
# Drop the record and tear its owner down; because that owner
# may be blocked inside initialize() (where close_evt cannot
# wake it), it must be cancelled. We then create a fresh
# session here.
self._inflight.pop(key)
evicted.append((inflight[0], inflight[2], inflight[3], True))
# Become the creator: publish an in-flight record before any
# await so concurrent callers join us instead of racing.
ready = current_loop.create_future()
close_evt = asyncio.Event()
task = current_loop.create_task(self._run_session(connection, ready, close_evt))
self._inflight[key] = (current_loop, ready, task, close_evt)
# Evict LRU entries when at capacity. # Evict LRU entries when at capacity.
while len(self._entries) >= self.MAX_SESSIONS: while len(self._entries) >= self.MAX_SESSIONS:
oldest_key, (_, loop, ent_task, ent_close) = next(iter(self._entries.items())) oldest_key = next(iter(self._entries))
cm = self._context_managers.pop(oldest_key, None)
self._entries.pop(oldest_key) self._entries.pop(oldest_key)
evicted.append((loop, ent_task, ent_close, False)) if cm is not None:
cms_to_close.append((oldest_key, cm))
# Phase 2: shut down evicted sessions/creations. Same-loop owners are # Phase 2: async cleanup outside the lock so we never await while holding it.
# awaited so they finish deterministically; foreign-loop owners are for close_key, cm in cms_to_close:
# routed to their own loop. In every case the owner task — never this
# one — runs __aexit__. In-flight owners are cancelled (cancel=True) so a
# blocking initialize() cannot leave them hung.
for loop, ent_task, ent_close, cancel in evicted:
if loop is current_loop and not loop.is_closed():
await self._shutdown(ent_close, ent_task, cancel)
elif cancel:
await self._shutdown_entry(loop, ent_task, ent_close, cancel=True)
else:
self._signal_close(loop, ent_close)
# Phase 2b: a concurrent creation for this key is already in progress on
# this loop — share its result rather than create a duplicate session.
if join is not None:
return await asyncio.shield(join)
assert ready is not None and close_evt is not None and task is not None
# Phase 3: wait for our owner task to publish the initialized session.
try:
session = await asyncio.shield(ready)
except BaseException:
# Two distinct cases reach here:
#
# 1. The owner task failed (e.g. connect/initialize error) and
# reported it via ready.set_exception(). It is *already* in its
# finally block running cm.__aexit__ in its own task, so we must
# NOT cancel it — doing so would interrupt that cleanup. We only
# wait for it to finish unwinding.
# 2. This call itself was cancelled (CancelledError). Because of the
# shield, `ready` is still pending and the owner task is alive and
# blocked. We signal close and cancel it so it exits the cancel
# scope in its own task, then wait for it to finish.
#
# The session is never registered yet, so nobody else can close it;
# waiting here guarantees we never leak a session or owner task.
owner_already_failed = ready.done() and not ready.cancelled() and ready.exception() is not None
if not owner_already_failed:
close_evt.set()
task.cancel()
try: try:
await asyncio.shield(task) await cm.__aexit__(None, None, None)
except BaseException: except Exception:
logger.debug("Owner task ended during get_session unwind", exc_info=True) logger.warning("Error closing MCP session %s", close_key, exc_info=True)
with self._lock:
if self._inflight.get(key) == (current_loop, ready, task, close_evt):
self._inflight.pop(key)
raise
# Phase 4: promote the in-flight creation to a registered entry — but from langchain_mcp_adapters.sessions import create_session
# only if our in-flight record is still the live one. A concurrent
# close_* / close_all may have removed it while we were initializing; in cm = create_session(connection)
# that case we must NOT resurrect the session into _entries. Instead we session = await cm.__aenter__()
# own the teardown: signal our owner task and wait for it to run await session.initialize()
# __aexit__ in its own task, then surface the cancellation.
# Phase 3: register the new session under the lock.
with self._lock: with self._lock:
still_ours = self._inflight.get(key) == (current_loop, ready, task, close_evt) self._entries[key] = (session, current_loop)
if still_ours: self._context_managers[key] = cm
self._inflight.pop(key)
self._entries[key] = (session, current_loop, task, close_evt)
if not still_ours:
await self._shutdown(close_evt, task)
raise asyncio.CancelledError("MCP session pool was closed while the session was being created")
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key) logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
return session return session
@@ -266,169 +108,70 @@ class MCPSessionPool:
# Cleanup helpers # Cleanup helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
def _signal_close(loop: asyncio.AbstractEventLoop, close_evt: asyncio.Event) -> None: """Close a single context manager (must be called WITHOUT the lock)."""
"""Ask an owner task to shut down without waiting.
``asyncio.Event.set`` is not thread-safe, so it is scheduled on the
owning loop. A closed loop means the owner task is already gone.
"""
if loop.is_closed():
return
try: try:
loop.call_soon_threadsafe(close_evt.set) await cm.__aexit__(None, None, None)
except RuntimeError: except Exception:
# Loop was closed between the is_closed() check and now. logger.warning("Error closing MCP session %s", key, exc_info=True)
pass
async def _shutdown(
self,
close_evt: asyncio.Event,
task: asyncio.Task[Any],
cancel: bool = False,
) -> None:
"""Signal an owner task and wait for it to finish (runs on its loop).
``cancel=True`` is used for in-flight creations: the owner task may be
blocked inside ``initialize()`` where ``close_evt`` cannot wake it, so it
must be cancelled. Its ``finally`` block still runs ``__aexit__`` in its
own task, satisfying anyio's same-task cancel-scope requirement.
"""
close_evt.set()
if cancel:
task.cancel()
try:
await task
except (Exception, asyncio.CancelledError):
logger.debug("Owner task ended during shutdown", exc_info=True)
async def _shutdown_entry(
self,
loop: asyncio.AbstractEventLoop,
task: asyncio.Task[Any],
close_evt: asyncio.Event,
cancel: bool = False,
) -> None:
"""Shut down one entry, routing the close to its owning loop."""
if loop.is_closed():
return
current_loop = asyncio.get_running_loop()
if loop is current_loop:
await self._shutdown(close_evt, task, cancel)
elif loop.is_running():
future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop)
try:
await asyncio.wrap_future(future)
except Exception:
logger.warning("Error closing MCP session on owning loop", exc_info=True)
else:
# Owning loop exists but is neither the current loop nor running.
# We are inside an async context here, so run_until_complete() would
# raise "Cannot run the event loop while another loop is running";
# and the loop may belong to another thread, where driving it from
# here is unsafe. This branch is not expected in practice — a
# session's owning loop is either the long-lived gateway loop (which
# is running) or a short-lived asyncio.run loop (which is closed and
# caught above). Fall back to a best-effort thread-safe signal so the
# owner task tears down if/when its loop runs again.
logger.warning("Owning loop for MCP session is idle; signalling close best-effort. Session may leak until the loop runs again.")
self._signal_close(loop, close_evt)
if cancel:
try:
loop.call_soon_threadsafe(task.cancel)
except RuntimeError:
pass
async def close_scope(self, scope_key: str) -> None: async def close_scope(self, scope_key: str) -> None:
"""Close all sessions for a given scope (e.g. thread_id).""" """Close all sessions for a given scope (e.g. thread_id)."""
with self._lock: with self._lock:
keys = [k for k in self._entries if k[1] == scope_key] keys = [k for k in self._entries if k[1] == scope_key]
entries = [(self._entries.pop(k)) for k in keys] cms = [(k, self._context_managers.pop(k, None)) for k in keys]
inflight_keys = [k for k in self._inflight if k[1] == scope_key] for k in keys:
inflight = [self._inflight.pop(k) for k in inflight_keys] self._entries.pop(k, None)
for _session, loop, task, close_evt in entries: for key, cm in cms:
await self._shutdown_entry(loop, task, close_evt) if cm is not None:
for loop, _ready, task, close_evt in inflight: await self._close_cm(key, cm)
await self._shutdown_entry(loop, task, close_evt, cancel=True)
async def close_server(self, server_name: str) -> None: async def close_server(self, server_name: str) -> None:
"""Close all sessions for a given server.""" """Close all sessions for a given server."""
with self._lock: with self._lock:
keys = [k for k in self._entries if k[0] == server_name] keys = [k for k in self._entries if k[0] == server_name]
entries = [(self._entries.pop(k)) for k in keys] cms = [(k, self._context_managers.pop(k, None)) for k in keys]
inflight_keys = [k for k in self._inflight if k[0] == server_name] for k in keys:
inflight = [self._inflight.pop(k) for k in inflight_keys] self._entries.pop(k, None)
for _session, loop, task, close_evt in entries: for key, cm in cms:
await self._shutdown_entry(loop, task, close_evt) if cm is not None:
for loop, _ready, task, close_evt in inflight: await self._close_cm(key, cm)
await self._shutdown_entry(loop, task, close_evt, cancel=True)
async def close_all(self) -> None: async def close_all(self) -> None:
"""Close every managed session.""" """Close every managed session."""
with self._lock: with self._lock:
entries = list(self._entries.values()) cms = list(self._context_managers.items())
self._context_managers.clear()
self._entries.clear() self._entries.clear()
inflight = list(self._inflight.values()) for key, cm in cms:
self._inflight.clear() await self._close_cm(key, cm)
for _session, loop, task, close_evt in entries:
await self._shutdown_entry(loop, task, close_evt)
for loop, _ready, task, close_evt in inflight:
await self._shutdown_entry(loop, task, close_evt, cancel=True)
def close_all_sync(self) -> None: def close_all_sync(self) -> None:
"""Close all sessions on their owning event loops (synchronous). """Close all sessions using their owning event loops (synchronous).
Each session is closed by its owner task on the loop it was created in, Each session is closed on the loop it was created in, avoiding
avoiding cross-loop and cross-task errors. Safe to call from any thread cross-loop resource leaks. Safe to call from any thread without an
without an active event loop. active event loop.
Closing semantics differ by where the owning loop runs:
* Owning loop is idle, or running on another thread — this call blocks
until teardown completes (or ``SESSION_CLOSE_TIMEOUT`` elapses).
* Owning loop is the one currently running on *this* thread — we cannot
block on it without deadlocking, so teardown is only *signalled* here
and completes asynchronously once control returns to that loop. The
caller must therefore keep that loop running afterwards; if it stops
the loop immediately, the owner task's ``__aexit__`` may not run. When
a deterministic close is required from inside a running loop, ``await
close_all()`` instead.
""" """
with self._lock: with self._lock:
entries = list(self._entries.values()) entries = list(self._entries.items())
cms = dict(self._context_managers)
self._entries.clear() self._entries.clear()
inflight = list(self._inflight.values()) self._context_managers.clear()
self._inflight.clear()
# Entries are initialized (gentle close_evt path). In-flight creations for key, (_, loop) in entries:
# may be blocked mid-init, so they are cancelled to unblock teardown. cm = cms.get(key)
owners = [(loop, task, close_evt, False) for _s, loop, task, close_evt in entries] if cm is None or loop.is_closed():
owners += [(loop, task, close_evt, True) for loop, _r, task, close_evt in inflight]
try:
current_running_loop = asyncio.get_running_loop()
except RuntimeError:
current_running_loop = None
for loop, task, close_evt, cancel in owners:
if loop.is_closed():
continue continue
try: try:
if loop is current_running_loop: if loop.is_running():
# We are executing inside this loop's thread, so synchronously # Schedule on the owning loop from this (different) thread.
# waiting on run_coroutine_threadsafe(...).result() would future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
# deadlock until timeout. Signal the owner task directly and
# let it finish once this synchronous call returns control to
# the running loop.
close_evt.set()
if cancel:
task.cancel()
elif loop.is_running():
# Schedule the shutdown on the owning loop from this thread.
future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop)
future.result(timeout=self.SESSION_CLOSE_TIMEOUT) future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
else: else:
loop.run_until_complete(self._shutdown(close_evt, task, cancel)) loop.run_until_complete(cm.__aexit__(None, None, None))
except Exception: except Exception:
logger.debug("Error closing MCP session during sync close", exc_info=True) logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
+1 -10
View File
@@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Mapping
from typing import Any from typing import Any
from langchain_core.tools import BaseTool, StructuredTool from langchain_core.tools import BaseTool, StructuredTool
@@ -138,15 +137,7 @@ def _make_session_pool_tool(
from langchain_mcp_adapters.interceptors import MCPToolCallRequest from langchain_mcp_adapters.interceptors import MCPToolCallRequest
async def base_handler(request: MCPToolCallRequest) -> Any: async def base_handler(request: MCPToolCallRequest) -> Any:
# Preserve interceptor-injected headers for stdio MCP calls by return await session.call_tool(request.name, request.args)
# forwarding them through MCP call meta.
call_kwargs: dict[str, Any] = {}
if request.headers:
if isinstance(request.headers, Mapping):
call_kwargs["meta"] = {"headers": dict(request.headers)}
else:
logger.warning("Ignoring MCP interceptor headers with unsupported type: %s", type(request.headers).__name__)
return await session.call_tool(request.name, request.args, **call_kwargs)
handler = base_handler handler = base_handler
for interceptor in reversed(tool_interceptors): for interceptor in reversed(tool_interceptors):
@@ -47,38 +47,6 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True model_settings_from_config["stream_usage"] = True
# Default chunk-gap budget for OpenAI-compatible streaming responses.
#
# langchain-openai raises ``StreamChunkTimeoutError`` after this many seconds
# without receiving a chunk. Its own default is 60s, which is too aggressive for
# reasoning models (DeepSeek-R1, Doubao-thinking, GPT-5) whose first chunk can
# legitimately take 90~150s. We default to 240s so the streaming layer rarely
# trips on long thinking pauses; the LLMErrorHandlingMiddleware still retries
# (budget=2) if a real stall happens. Users can override per-model in config.yaml.
_DEFAULT_STREAM_CHUNK_TIMEOUT_SECONDS: float = 240.0
def _apply_stream_chunk_timeout_default(model_use_path: str, model_settings_from_config: dict) -> None:
"""Inject a generous ``stream_chunk_timeout`` for OpenAI-compatible clients.
The ``stream_chunk_timeout`` kwarg is specific to ``langchain_openai:ChatOpenAI``
and is rejected by other providers' constructors as an unexpected keyword
argument. Behaviour:
* OpenAI-compatible path: an explicit value in ``config.yaml`` is preserved.
An explicit ``null`` is dropped upstream by ``model_dump(exclude_none=True)``
and therefore treated as "unset", so the default is injected.
* Non-OpenAI path: drop the key so it is never forwarded to an incompatible
constructor (which would raise ``TypeError: unexpected keyword argument``).
"""
if model_use_path != "langchain_openai:ChatOpenAI":
model_settings_from_config.pop("stream_chunk_timeout", None)
return
if "stream_chunk_timeout" in model_settings_from_config:
return
model_settings_from_config["stream_chunk_timeout"] = _DEFAULT_STREAM_CHUNK_TIMEOUT_SECONDS
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, attach_tracing: bool = True, **kwargs) -> BaseChatModel: def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, attach_tracing: bool = True, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config. """Create a chat model instance from the config.
@@ -160,7 +128,6 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
model_settings_from_config.pop("reasoning_effort", None) model_settings_from_config.pop("reasoning_effort", None)
_enable_stream_usage_by_default(model_config.use, model_settings_from_config) _enable_stream_usage_by_default(model_config.use, model_settings_from_config)
_apply_stream_chunk_timeout_default(model_config.use, model_settings_from_config)
# For Codex Responses API models: map thinking mode to reasoning_effort # For Codex Responses API models: map thinking mode to reasoning_effort
from deerflow.models.openai_codex_provider import CodexChatModel from deerflow.models.openai_codex_provider import CodexChatModel
@@ -47,41 +47,6 @@ def _prepare_database_sqlite_checkpointer_path(db_config) -> str:
return conn_str return conn_str
def _build_postgres_pool(conn_string: str):
"""Build an AsyncConnectionPool with TCP keepalive and connection checking."""
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
return AsyncConnectionPool(
conn_string,
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 6,
},
check=AsyncConnectionPool.check_connection,
)
def _ensure_postgres_imports():
"""Import and return (AsyncPostgresSaver, AsyncConnectionPool), raising ImportError on failure."""
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
try:
from psycopg_pool import AsyncConnectionPool
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
return AsyncPostgresSaver, AsyncConnectionPool
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Async factory # Async factory
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -109,13 +74,15 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
return return
if config.type == "postgres": if config.type == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string: if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED) raise ValueError(POSTGRES_CONN_REQUIRED)
AsyncPostgresSaver, _ = _ensure_postgres_imports() async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
pool = _build_postgres_pool(config.connection_string)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
await saver.setup() await saver.setup()
yield saver yield saver
return return
@@ -150,13 +117,15 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
return return
if db_config.backend == "postgres": if db_config.backend == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url: if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend") raise ValueError("database.postgres_url is required for the postgres backend")
AsyncPostgresSaver, _ = _ensure_postgres_imports() async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
pool = _build_postgres_pool(db_config.postgres_url)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
await saver.setup() await saver.setup()
yield saver yield saver
return return
@@ -86,8 +86,6 @@ class RunJournal(BaseCallbackHandler):
self._last_ai_msg: str | None = None self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None self._first_human_msg: str | None = None
self._msg_count = 0 self._msg_count = 0
self._had_llm_error_fallback = False
self._llm_error_fallback_message: str | None = None
# Latency tracking # Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
@@ -258,18 +256,6 @@ class RunJournal(BaseCallbackHandler):
# Token usage from message # Token usage from message
usage = getattr(message, "usage_metadata", None) usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {} usage_dict = dict(usage) if usage else {}
additional_kwargs = getattr(message, "additional_kwargs", None) or {}
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
self._had_llm_error_fallback = True
detail = additional_kwargs.get("error_detail")
reason = additional_kwargs.get("error_reason")
fallback_text = self._message_text(message).strip()
if isinstance(detail, str) and detail.strip():
self._llm_error_fallback_message = detail.strip()
elif isinstance(reason, str) and reason.strip():
self._llm_error_fallback_message = reason.strip()
elif fallback_text:
self._llm_error_fallback_message = fallback_text[:2000]
# Resolve call index # Resolve call index
call_index = self._llm_call_index call_index = self._llm_call_index
@@ -583,11 +569,3 @@ class RunJournal(BaseCallbackHandler):
"last_ai_message": self._last_ai_msg, "last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg, "first_human_message": self._first_human_msg,
} }
@property
def had_llm_error_fallback(self) -> bool:
return self._had_llm_error_fallback
@property
def llm_error_fallback_message(self) -> str | None:
return self._llm_error_fallback_message
@@ -1,16 +1,39 @@
"""Run lifecycle management for LangGraph Platform API compatibility.""" """Run lifecycle management for LangGraph Platform API compatibility."""
from .domain import (
AssistantId,
CancelAction,
DisconnectMode,
EventSeq,
InvalidRunTransition,
MultitaskStrategy,
Run,
RunId,
RunScope,
RunStatus,
ThreadId,
UserId,
)
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import RunContext, run_agent from .worker import RunContext, run_agent
__all__ = [ __all__ = [
"AssistantId",
"CancelAction",
"ConflictError", "ConflictError",
"DisconnectMode", "DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskStrategy",
"Run",
"RunContext", "RunContext",
"RunId",
"RunManager", "RunManager",
"RunRecord", "RunRecord",
"RunScope",
"RunStatus", "RunStatus",
"ThreadId",
"UnsupportedStrategyError", "UnsupportedStrategyError",
"UserId",
"run_agent", "run_agent",
] ]
@@ -0,0 +1,20 @@
"""Application-layer DTOs and services for run runtime use cases."""
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle, StoredRunEvent
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
from .services import RunsApplicationService
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"GetRunQuery",
"JoinRunStreamCommand",
"ListRunMessagesQuery",
"ListRunsQuery",
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"RunsApplicationService",
"StoredRunEvent",
]
@@ -0,0 +1,46 @@
"""Application command DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
from ..domain import AssistantId, CancelAction, DisconnectMode, MultitaskStrategy, RunId, RunScope, ThreadId
@dataclass(frozen=True)
class CreateRunCommand:
thread_id: ThreadId
assistant_id: AssistantId | None = None
input: dict[str, Any] | None = None
command: dict[str, Any] | None = None
metadata: dict[str, Any] = field(default_factory=dict)
config: dict[str, Any] = field(default_factory=dict)
context: dict[str, Any] = field(default_factory=dict)
scope: RunScope = RunScope.stateful
on_disconnect: DisconnectMode = DisconnectMode.cancel
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
stream_mode: list[str] | str | None = None
stream_subgraphs: bool = False
interrupt_before: list[str] | Literal["*"] | None = None
interrupt_after: list[str] | Literal["*"] | None = None
@dataclass(frozen=True)
class CancelRunCommand:
run_id: RunId
action: CancelAction = CancelAction.interrupt
wait: bool = False
@dataclass(frozen=True)
class JoinRunStreamCommand:
run_id: RunId
last_event_id: str | None = None
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"JoinRunStreamCommand",
]
@@ -0,0 +1,76 @@
"""Application output DTOs for run use cases."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
from ..domain import AssistantId, EventSeq, Run, RunId, RunStatus, ThreadId
@dataclass(frozen=True)
class RunSnapshot:
run_id: RunId
thread_id: ThreadId
assistant_id: AssistantId | None = None
status: RunStatus = RunStatus.pending
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
error: str | None = None
model_name: str | None = None
@classmethod
def from_run(cls, run: Run) -> RunSnapshot:
return cls(
run_id=run.run_id,
thread_id=run.thread_id,
assistant_id=run.assistant_id,
status=run.status,
metadata=dict(run.metadata),
kwargs=dict(run.kwargs),
created_at=run.created_at,
updated_at=run.updated_at,
error=run.error,
model_name=run.model_name,
)
@dataclass(frozen=True)
class RunMessageView:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class StoredRunEvent:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
category: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class RunStreamHandle:
run_id: RunId
thread_id: ThreadId
events: AsyncIterator[Any]
__all__ = [
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"StoredRunEvent",
]
@@ -0,0 +1,37 @@
"""Application query DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..domain import RunId, ThreadId, UserId
@dataclass(frozen=True)
class GetRunQuery:
run_id: RunId
thread_id: ThreadId | None = None
user_id: UserId | None = None
@dataclass(frozen=True)
class ListRunsQuery:
thread_id: ThreadId
user_id: UserId | None = None
limit: int = 100
@dataclass(frozen=True)
class ListRunMessagesQuery:
thread_id: ThreadId
run_id: RunId
limit: int = 50
before_seq: int | None = None
after_seq: int | None = None
__all__ = [
"GetRunQuery",
"ListRunMessagesQuery",
"ListRunsQuery",
]
@@ -0,0 +1,74 @@
"""Application service skeleton for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..execution import RunExecutionScheduler, RunSupervisor
from ..repositories import RunEventLog, RunRepository
from ..streams import RunStreamBroker
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
@dataclass
class RunsApplicationService:
"""Use-case orchestration boundary for run runtime operations.
PR1 only introduces the boundary and dependency shape. Existing Gateway
handlers continue to call the legacy service functions until later PRs move
behavior into this class.
"""
run_repository: RunRepository
run_event_log: RunEventLog
stream_broker: RunStreamBroker
scheduler: RunExecutionScheduler
supervisor: RunSupervisor
async def create_background(self, command: CreateRunCommand) -> RunSnapshot:
# PR1 defines the application boundary; later PRs move Gateway runtime
# behavior behind this method.
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_stream(self, command: CreateRunCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_wait(self, command: CreateRunCommand) -> RunSnapshot:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def join_stream(self, command: JoinRunStreamCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def cancel(self, command: CancelRunCommand) -> bool:
return await self.supervisor.cancel(command.run_id, action=command.action)
async def get_run(self, query: GetRunQuery) -> RunSnapshot | None:
run = await self.run_repository.get(query.run_id, user_id=query.user_id)
if run is None:
return None
if query.thread_id is not None and run.thread_id != query.thread_id:
return None
return RunSnapshot.from_run(run)
async def list_runs(self, query: ListRunsQuery) -> list[RunSnapshot]:
return await self.run_repository.list_by_thread(
query.thread_id,
user_id=query.user_id,
limit=query.limit,
)
async def list_run_messages(self, query: ListRunMessagesQuery) -> list[RunMessageView]:
return await self.run_event_log.list_messages_by_run(
query.thread_id,
query.run_id,
limit=query.limit,
before_seq=query.before_seq,
after_seq=query.after_seq,
)
__all__ = [
"RunsApplicationService",
]
@@ -0,0 +1,33 @@
"""Run runtime domain model."""
from .errors import InvalidRunTransition, RunDomainError
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, UserId
from .model import Run
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
__all__ = [
"AssistantId",
"CancelAction",
"CancelPolicy",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskDecision",
"MultitaskPolicy",
"MultitaskStrategy",
"Run",
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunDomainError",
"RunEvent",
"RunFailed",
"RunId",
"RunScope",
"RunStarted",
"RunStatus",
"ThreadId",
"UserId",
]
@@ -0,0 +1,24 @@
"""Domain-level errors for run lifecycle operations."""
from __future__ import annotations
from .value_objects import RunStatus
class RunDomainError(Exception):
"""Base class for run runtime domain errors."""
class InvalidRunTransition(RunDomainError):
"""Raised when a run status transition violates lifecycle rules."""
def __init__(self, current: RunStatus, target: RunStatus) -> None:
super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
self.current = current
self.target = target
__all__ = [
"InvalidRunTransition",
"RunDomainError",
]
@@ -0,0 +1,64 @@
"""Domain events emitted by the run aggregate."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .identifiers import AssistantId, RunId, ThreadId
from .value_objects import CancelAction, RunStatus
@dataclass(frozen=True)
class RunCreated:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
assistant_id: AssistantId | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class RunStarted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunCompleted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunFailed:
run_id: RunId
thread_id: ThreadId
status: RunStatus
occurred_at: str = field(default_factory=now_iso)
error: str | None = None
@dataclass(frozen=True)
class RunCancelled:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
action: CancelAction = CancelAction.interrupt
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
__all__ = [
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunEvent",
"RunFailed",
"RunStarted",
]
@@ -0,0 +1,27 @@
"""Lightweight identifiers for the run runtime domain."""
from __future__ import annotations
from typing import NewType
RunId = NewType("RunId", str)
ThreadId = NewType("ThreadId", str)
AssistantId = NewType("AssistantId", str)
UserId = NewType("UserId", str)
def require_non_empty(value: str, *, field_name: str) -> str:
"""Return a stripped identifier value, rejecting empty identifiers."""
normalized = value.strip()
if not normalized:
raise ValueError(f"{field_name} must not be empty")
return normalized
__all__ = [
"AssistantId",
"RunId",
"ThreadId",
"UserId",
"require_non_empty",
]
@@ -0,0 +1,193 @@
"""Run aggregate root and lifecycle invariants."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .errors import InvalidRunTransition
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
# Keep lifecycle transitions explicit so later application code cannot invent
# ad hoc status moves outside the aggregate.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
RunStatus.pending: frozenset(
{
RunStatus.running,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.running: frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.success: frozenset(),
RunStatus.error: frozenset(),
RunStatus.timeout: frozenset(),
RunStatus.interrupted: frozenset(),
}
@dataclass
class Run:
"""Run aggregate root.
The aggregate owns lifecycle invariants only. Infrastructure concerns such
as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
this model.
"""
run_id: RunId
thread_id: ThreadId
status: RunStatus
assistant_id: AssistantId | None = None
scope: RunScope = RunScope.stateful
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = field(default_factory=now_iso)
updated_at: str = field(default_factory=now_iso)
error: str | None = None
model_name: str | None = None
_pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)
def __post_init__(self) -> None:
self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
if self.assistant_id is not None:
self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))
@classmethod
def create(
cls,
*,
run_id: RunId,
thread_id: ThreadId,
assistant_id: AssistantId | None = None,
scope: RunScope = RunScope.stateful,
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
metadata: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
model_name: str | None = None,
created_at: str | None = None,
) -> Run:
timestamp = created_at or now_iso()
run = cls(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
scope=scope,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=timestamp,
updated_at=timestamp,
model_name=model_name,
)
run._record_event(
RunCreated(
run_id=run.run_id,
thread_id=run.thread_id,
occurred_at=timestamp,
assistant_id=run.assistant_id,
metadata=dict(run.metadata),
)
)
return run
@property
def is_terminal(self) -> bool:
return not _ALLOWED_TRANSITIONS[self.status]
def pull_events(self) -> tuple[RunEvent, ...]:
# Domain events are drained by the application layer after the aggregate
# has accepted a state change.
events = tuple(self._pending_events)
self._pending_events.clear()
return events
def mark_started(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.running, at=at)
def mark_completed(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.success, at=at)
def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.error, error=error, at=at)
def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.timeout, error=error, at=at)
def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
self._transition_to(RunStatus.interrupted, action=action, at=at)
def _transition_to(
self,
target: RunStatus,
*,
error: str | None = None,
action: CancelAction = CancelAction.interrupt,
at: str | None = None,
) -> None:
if target == self.status:
return
if target not in _ALLOWED_TRANSITIONS[self.status]:
raise InvalidRunTransition(self.status, target)
timestamp = at or now_iso()
self.status = target
self.updated_at = timestamp
if error is not None:
self.error = error
self._record_event(self._event_for_transition(target, timestamp, error=error, action=action))
def _event_for_transition(
self,
target: RunStatus,
occurred_at: str,
*,
error: str | None,
action: CancelAction,
) -> RunEvent:
# Keep event construction next to the transition rules so a new status
# cannot be added without an explicit durable event shape.
if target == RunStatus.running:
return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target == RunStatus.success:
return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target in (RunStatus.error, RunStatus.timeout):
return RunFailed(
run_id=self.run_id,
thread_id=self.thread_id,
status=target,
occurred_at=occurred_at,
error=error,
)
if target == RunStatus.interrupted:
return RunCancelled(
run_id=self.run_id,
thread_id=self.thread_id,
occurred_at=occurred_at,
action=action,
)
raise InvalidRunTransition(self.status, target)
def _record_event(self, event: RunEvent) -> None:
self._pending_events.append(event)
__all__ = [
"Run",
"RunStatus",
]
@@ -0,0 +1,50 @@
"""Domain policies for run concurrency and cancellation."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from enum import StrEnum
from .model import Run
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
class MultitaskDecision(StrEnum):
"""Application-level decision produced by a multitask policy."""
allow = "allow"
reject = "reject"
cancel_existing = "cancel_existing"
enqueue = "enqueue"
@dataclass(frozen=True)
class MultitaskPolicy:
strategy: MultitaskStrategy = MultitaskStrategy.reject
def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
inflight = [run for run in active_runs if run.status in (RunStatus.pending, RunStatus.running)]
if not inflight:
return MultitaskDecision.allow
if self.strategy == MultitaskStrategy.reject:
return MultitaskDecision.reject
if self.strategy in (MultitaskStrategy.interrupt, MultitaskStrategy.rollback):
return MultitaskDecision.cancel_existing
return MultitaskDecision.enqueue
@dataclass(frozen=True)
class CancelPolicy:
action: CancelAction = CancelAction.interrupt
@property
def rolls_back_checkpoint(self) -> bool:
return self.action == CancelAction.rollback
__all__ = [
"CancelPolicy",
"MultitaskDecision",
"MultitaskPolicy",
]
@@ -0,0 +1,88 @@
"""Domain value objects for run lifecycle semantics."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
class RunScope(StrEnum):
"""Conversation scope for a run."""
stateful = "stateful"
stateless = "stateless"
temporary_thread = "temporary_thread"
class MultitaskStrategy(StrEnum):
"""Concurrency strategy for a new run on a thread."""
reject = "reject"
interrupt = "interrupt"
rollback = "rollback"
enqueue = "enqueue"
class CancelAction(StrEnum):
"""Cancellation action requested by an API or supervisor."""
interrupt = "interrupt"
rollback = "rollback"
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
)
def is_terminal_status(status: RunStatus) -> bool:
return status in TERMINAL_RUN_STATUSES
@dataclass(frozen=True, order=True)
class EventSeq:
"""Thread-local event sequence number."""
value: int
def __post_init__(self) -> None:
if self.value < 0:
raise ValueError("EventSeq must be non-negative")
def next(self) -> EventSeq:
return EventSeq(self.value + 1)
__all__ = [
"CancelAction",
"DisconnectMode",
"EventSeq",
"MultitaskStrategy",
"RunScope",
"RunStatus",
"TERMINAL_RUN_STATUSES",
"is_terminal_status",
]
@@ -0,0 +1,12 @@
"""Execution contracts for run lifecycle orchestration."""
from .executor import RunExecutor
from .scheduler import RunExecutionHandle, RunExecutionScheduler
from .supervisor import RunSupervisor
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
"RunExecutor",
"RunSupervisor",
]
@@ -0,0 +1,19 @@
"""Run executor contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import Run
class RunExecutor(Protocol):
"""Executes one run against the underlying agent or graph runtime."""
async def execute(self, run: Run) -> None:
pass
__all__ = [
"RunExecutor",
]
@@ -0,0 +1,26 @@
"""Run execution scheduler contract."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunExecutionHandle:
run_id: RunId
class RunExecutionScheduler(Protocol):
"""Starts background execution for an accepted run."""
async def start(self, run_id: RunId) -> RunExecutionHandle:
pass
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
]
@@ -0,0 +1,19 @@
"""Run execution supervision contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import CancelAction, RunId
class RunSupervisor(Protocol):
"""Controls lifecycle operations for already scheduled runs."""
async def cancel(self, run_id: RunId, *, action: CancelAction = CancelAction.interrupt) -> bool:
pass
__all__ = [
"RunSupervisor",
]
@@ -645,98 +645,6 @@ class RunManager:
self._runs.pop(run_id, None) self._runs.pop(run_id, None)
logger.debug("Run record %s cleaned up", run_id) logger.debug("Run record %s cleaned up", run_id)
async def shutdown(self, *, timeout: float = 5.0) -> None:
"""Cancel and bounded-await all in-flight runs on process shutdown.
Chat runs execute in fire-and-forget background ``asyncio`` tasks that
write checkpoints through a shared checkpointer. On shutdown the
checkpointer's resources (e.g. the postgres connection pool owned by the
gateway's ``AsyncExitStack``) are torn down; if a run task is still
mid-graph at that point, langgraph's
``AsyncPregelLoop._checkpointer_put_after_previous`` runs its
``finally: await checkpointer.aput(...)`` against the closed pool. Because
that put runs in a langgraph-internal task (not on ``run_agent``'s call
stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the
worker and surfaces as an unhandled exception during ``asyncio.run()``
shutdown (bytedance/deer-flow issue #3373).
Draining in-flight runs *before* the checkpointer is closed lets each
run that settles within ``timeout`` flush its final checkpoint while
resources are still open. Only runs that do **not** settle on their own
are marked ``interrupted`` — a run that completes (e.g. ``success``)
during the drain keeps its real terminal status instead of being
blanket-overwritten. The whole drain, including the trailing status
persistence, is bounded by ``timeout`` so a run stuck in cleanup (or a
slow store under DB pressure) cannot hang worker shutdown — the
precondition for the signal-reentrancy deadlock guarded by
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active
after ``timeout`` are logged and may still race teardown.
"""
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
async with self._lock:
inflight = [record for record in self._runs.values() if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()]
for record in inflight:
record.abort_action = "interrupt"
record.abort_event.set()
record.task.cancel() # type: ignore[union-attr] # filtered above
# Status is decided AFTER the drain (below), not here: a run that
# completes on its own during the drain must keep its real status.
if not inflight:
return
tasks = [record.task for record in inflight]
_, pending = await asyncio.wait(tasks, timeout=timeout)
# Only mark/persist ``interrupted`` for runs that did not settle on their
# own (still pending after the timeout, or ended cancelled). A run that
# finished normally during the drain keeps the status it set for itself.
to_persist: list[RunRecord] = []
async with self._lock:
for record in inflight:
task = record.task
if task not in pending and not task.cancelled():
# Completed on its own — retrieve any surfaced exception so it
# is not reported as "never retrieved", and keep its status.
task.exception() # type: ignore[union-attr] # done & not cancelled
continue
if record.status in (RunStatus.pending, RunStatus.running):
record.status = RunStatus.interrupted
record.updated_at = _now_iso()
to_persist.append(record)
# Bound the trailing status persistence within the remaining budget so a
# slow store (``_call_store_with_retry`` can back off under DB pressure)
# cannot push shutdown past ``timeout``.
if to_persist:
remaining = deadline - loop.time()
if remaining <= 0:
logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist))
else:
try:
results = await asyncio.wait_for(
asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True),
timeout=remaining,
)
except TimeoutError:
logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist))
else:
# ``_persist_status`` is best-effort: it catches and logs its
# own failures, returning ``False``. Inspect the aggregate so a
# partial failure is surfaced at shutdown level (with the
# run_id) instead of being silently swallowed by the gather.
for record, result in zip(to_persist, results):
if isinstance(result, Exception):
logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result)
elif result is False:
logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id)
if pending:
logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending))
logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout)
class ConflictError(Exception): class ConflictError(Exception):
"""Raised when multitask_strategy=reject and thread has inflight runs.""" """Raised when multitask_strategy=reject and thread has inflight runs."""
@@ -0,0 +1,9 @@
"""Repository contracts for the run runtime application layer."""
from .run_event_log import RunEventLog
from .run_repository import RunRepository
__all__ = [
"RunEventLog",
"RunRepository",
]
@@ -0,0 +1,42 @@
"""Durable run event log contract."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from ..domain import RunEvent, RunId, ThreadId
if TYPE_CHECKING:
from ..application.dto import RunMessageView, StoredRunEvent
class RunEventLog(Protocol):
"""Persistence boundary for run messages and execution trace events."""
async def append(self, events: list[RunEvent]) -> list[StoredRunEvent]:
pass
async def list_messages_by_run(
self,
thread_id: ThreadId,
run_id: RunId,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[RunMessageView]:
pass
async def list_events_by_run(
self,
thread_id: ThreadId,
run_id: RunId,
*,
limit: int = 500,
) -> list[StoredRunEvent]:
pass
__all__ = [
"RunEventLog",
]
@@ -0,0 +1,37 @@
"""Run state repository contract."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from ..domain import Run, RunId, ThreadId, UserId
if TYPE_CHECKING:
from ..application.dto import RunSnapshot
class RunRepository(Protocol):
"""Persistence boundary for run state snapshots."""
async def save(self, run: Run) -> None:
pass
async def get(self, run_id: RunId, *, user_id: UserId | None = None) -> Run | None:
pass
async def list_by_thread(
self,
thread_id: ThreadId,
*,
user_id: UserId | None = None,
limit: int = 100,
) -> list[RunSnapshot]:
pass
async def delete(self, run_id: RunId) -> bool:
pass
__all__ = [
"RunRepository",
]
@@ -1,21 +1,10 @@
"""Run status and disconnect mode enums.""" """Compatibility exports for run status and disconnect mode enums."""
from enum import StrEnum # Existing callers import these enums from ``runs.schemas``. Re-export the
# domain definitions until all imports move to ``runs.domain``.
from .domain import DisconnectMode, RunStatus
__all__ = [
class RunStatus(StrEnum): "DisconnectMode",
"""Lifecycle status of a single run.""" "RunStatus",
]
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
@@ -0,0 +1,8 @@
"""Realtime stream contracts for run application use cases."""
from .run_stream_broker import RunStreamBroker, RunStreamEvent
__all__ = [
"RunStreamBroker",
"RunStreamEvent",
]
@@ -0,0 +1,44 @@
"""Realtime run stream broker contract."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunStreamEvent:
id: str
event: str
data: Any
class RunStreamBroker(Protocol):
"""Realtime publish/subscribe boundary for run streams."""
async def publish(self, run_id: RunId, event: str, data: Any) -> None:
pass
async def publish_terminal(self, run_id: RunId, *, event: str = "end", data: Any = None) -> None:
pass
def subscribe(
self,
run_id: RunId,
*,
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[RunStreamEvent]:
pass
async def cleanup(self, run_id: RunId, *, delay: float = 0) -> None:
pass
__all__ = [
"RunStreamBroker",
"RunStreamEvent",
]
@@ -150,7 +150,6 @@ async def run_agent(
pre_run_checkpoint_id: str | None = None pre_run_checkpoint_id: str | None = None
pre_run_snapshot: dict[str, Any] | None = None pre_run_snapshot: dict[str, Any] | None = None
snapshot_capture_failed = False snapshot_capture_failed = False
llm_error_fallback_message: str | None = None
journal = None journal = None
@@ -313,7 +312,6 @@ async def run_agent(
if record.abort_event.is_set(): if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id) logger.info("Run %s abort requested — stopping", run_id)
break break
llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk)
sse_event = _lg_mode_to_sse_event(single_mode) sse_event = _lg_mode_to_sse_event(single_mode)
await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode)) await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
else: else:
@@ -332,7 +330,6 @@ async def run_agent(
if mode is None: if mode is None:
continue continue
llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk)
sse_event = _lg_mode_to_sse_event(mode) sse_event = _lg_mode_to_sse_event(mode)
await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode)) await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))
@@ -355,12 +352,6 @@ async def run_agent(
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True) logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
else: else:
await run_manager.set_status(run_id, RunStatus.interrupted) await run_manager.set_status(run_id, RunStatus.interrupted)
elif llm_error_fallback_message or (journal is not None and journal.had_llm_error_fallback):
error_msg = llm_error_fallback_message
if error_msg is None and journal is not None:
error_msg = journal.llm_error_fallback_message
error_msg = error_msg or "LLM provider failed after retries"
await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
else: else:
await run_manager.set_status(run_id, RunStatus.success) await run_manager.set_status(run_id, RunStatus.success)
@@ -563,85 +554,6 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode return mode
def _error_fallback_message_from_metadata(metadata: dict[str, Any], content: Any) -> str:
detail = metadata.get("error_detail")
if isinstance(detail, str) and detail.strip():
return detail.strip()
reason = metadata.get("error_reason")
if isinstance(reason, str) and reason.strip():
return reason.strip()
if isinstance(content, str) and content.strip():
return content.strip()[:2000]
return "LLM provider failed after retries"
def _try_extract_from_message(obj: Any) -> str | None:
"""Try to extract fallback marker from a single message object or dict."""
additional_kwargs = getattr(obj, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
return _error_fallback_message_from_metadata(additional_kwargs, getattr(obj, "content", None))
if isinstance(obj, dict):
nested_kwargs = obj.get("additional_kwargs")
if isinstance(nested_kwargs, dict) and nested_kwargs.get("deerflow_error_fallback"):
return _error_fallback_message_from_metadata(nested_kwargs, obj.get("content"))
return None
def _extract_llm_error_fallback_message(value: Any) -> str | None:
"""Find LLM fallback markers in streamed LangGraph chunks.
Error fallback messages returned by model-call middleware are not guaranteed
to pass through LLM end callbacks, but they do appear in graph state chunks.
"""
# Fast path: large state chunks produced by stream_mode="values" have a
# top-level "messages" list. Scanning only that list avoids expensive deep
# recursion into large state dicts.
if isinstance(value, dict):
messages = value.get("messages")
if isinstance(messages, (list, tuple)):
for msg in messages:
result = _try_extract_from_message(msg)
if result is not None:
return result
# Fallback marker is attached to an AI message in the messages
# channel; it will never appear elsewhere in a values chunk.
return None
# No top-level "messages" — this is likely an "updates" chunk (small
# dict keyed by node name). Fall through to deep walk, which is cheap
# for these payloads.
# Deep walk for updates / messages / tuple / list modes. Payloads are
# small, so full recursion is acceptable here.
seen: set[int] = set()
def walk(obj: Any) -> str | None:
oid = id(obj)
if oid in seen:
return None
seen.add(oid)
result = _try_extract_from_message(obj)
if result is not None:
return result
if isinstance(obj, dict):
for item in obj.values():
result = walk(item)
if result is not None:
return result
return None
if isinstance(obj, (list, tuple, set)):
for item in obj:
result = walk(item)
if result is not None:
return result
return None
return walk(value)
def _extract_human_message(graph_input: dict) -> HumanMessage | None: def _extract_human_message(graph_input: dict) -> HumanMessage | None:
"""Extract or construct a HumanMessage from graph_input for event recording. """Extract or construct a HumanMessage from graph_input for event recording.
@@ -1,5 +1,4 @@
import asyncio import asyncio
import os
import posixpath import posixpath
import re import re
import shlex import shlex
@@ -44,16 +43,6 @@ _MAX_GLOB_MAX_RESULTS = 1000
_DEFAULT_GREP_MAX_RESULTS = 100 _DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500 _MAX_GREP_MAX_RESULTS = 500
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000 _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000
# Maximum bytes accepted in a single non-append write_file call (issue #3189).
# Oversized single-shot writes correlate with LLM streaming chunk-gap timeouts
# because the tool-call JSON payload (which the model must emit as one
# continuous stream) grows past the safe window. 80 KB ≈ 20K tokens, a
# comfortable headroom under the factory-default 240s stream_chunk_timeout.
# Deployments can override via env var DEERFLOW_WRITE_FILE_MAX_BYTES; set to
# 0 (or negative) to disable the guard entirely.
_WRITE_FILE_CONTENT_MAX_BYTES = 80 * 1024
_WRITE_FILE_MAX_BYTES_ENV = "DEERFLOW_WRITE_FILE_MAX_BYTES"
_LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"} _LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"}
_LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"} _LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"}
_LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"} _LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"}
@@ -1682,23 +1671,6 @@ async def _read_file_tool_async(
read_file_tool.coroutine = _read_file_tool_async read_file_tool.coroutine = _read_file_tool_async
def _effective_write_file_max_bytes() -> int:
"""Return the active size cap for non-append write_file calls.
Reads ``DEERFLOW_WRITE_FILE_MAX_BYTES`` at call time (not import time)
so tests and runtime tweaks take effect without restart. Falls back to
the default on missing/malformed values. A non-positive value disables
the guard.
"""
raw = os.environ.get(_WRITE_FILE_MAX_BYTES_ENV)
if raw is None:
return _WRITE_FILE_CONTENT_MAX_BYTES
try:
return int(raw)
except ValueError:
return _WRITE_FILE_CONTENT_MAX_BYTES
@tool("write_file", parse_docstring=True) @tool("write_file", parse_docstring=True)
def write_file_tool( def write_file_tool(
runtime: Runtime, runtime: Runtime,
@@ -1707,47 +1679,14 @@ def write_file_tool(
content: str, content: str,
append: bool = False, append: bool = False,
) -> str: ) -> str:
"""Write text content to a file. By default this overwrites the target file; set append=True to add content to the end without replacing existing content. """Write text content to a file. By default this overwrites the target file; set append to true to add content to the end without replacing existing content.
SIZE POLICY (issue #3189):
A single non-append write_file call must not exceed 80 KB of UTF-8 content.
Oversized single-shot writes correlate with LLM streaming chunk-gap
timeouts because the tool-call JSON payload which the model must emit as
one continuous stream grows past the safe window. For larger documents,
use ONE of these strategies (write_file rejects oversized payloads with an
actionable error):
1. INCREMENTAL EDIT (preferred for revisions): after the initial write,
use `str_replace` to surgically update sections. This is the same
pattern Claude Code's Write+Edit and OpenAI Codex's apply_patch use,
and keeps each tool call's payload small.
2. APPEND-IN-CHUNKS (for new long-form content): split the document into
sections, each well under 80 KB. First call uses append=False to
create the file; subsequent calls use append=True. The 80 KB cap does
NOT apply to append=True calls.
Operators can override the cap via env var `DEERFLOW_WRITE_FILE_MAX_BYTES`
(0 disables the guard entirely). Raising it risks streaming timeouts.
Args: Args:
description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND. path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD. content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
append: Whether to append content to the end of the file instead of overwriting it. Defaults to False. append: Whether to append content to the end of the file instead of overwriting it. Defaults to false.
""" """
if not append:
max_bytes = _effective_write_file_max_bytes()
if max_bytes > 0:
content_bytes = len(content.encode("utf-8"))
if content_bytes > max_bytes:
return (
f"Error: write_file content ({content_bytes} bytes) exceeds the "
f"{max_bytes}-byte single-call limit. Split the content into smaller "
"pieces: either (a) write the first section now, then use `str_replace` "
"for further edits, or (b) call write_file again with append=True "
"carrying the next section. See SIZE POLICY in the tool docstring "
"or issue #3189 for the rationale."
)
try: try:
requested_path = path requested_path = path
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
@@ -9,37 +9,6 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _format_yaml_error(skill_file: Path, exc: yaml.YAMLError, source: str) -> str:
"""Render a developer-friendly explanation of a YAML front-matter error."""
lines = [f"Invalid YAML front-matter in {skill_file}: {exc}"]
mark = getattr(exc, "problem_mark", None)
source_lines = source.splitlines()
if mark is not None and 0 <= mark.line < len(source_lines):
offending = source_lines[mark.line]
# mark.line is 0-based within the front-matter body; +1 makes it
# 1-based, +1 more accounts for the leading `---` fence that the
# front-matter regex strips before yaml.safe_load sees it. The
# result matches the line number an author sees in their editor.
file_line_number = mark.line + 2
lines.append(f" line {file_line_number}: {offending}")
# Targeted hint for the most common authoring mistake: an unquoted
# scalar value whose body contains ``: ``. We only surface the hint
# when we are confident it applies, to avoid misleading authors who
# hit unrelated YAML errors.
if getattr(exc, "problem", "") == "mapping values are not allowed here" and ":" in offending:
key, _, value = offending.partition(":")
value = value.strip()
if value and value[0] not in {'"', "'", "|", ">", "[", "{"}:
escaped = value.replace("\\", "\\\\").replace('"', '\\"')
lines.append(f' hint: values containing ":" must be quoted, e.g. {key}: "{escaped}"')
return "\n".join(lines)
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None: def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
"""Parse the optional allowed-tools frontmatter field. """Parse the optional allowed-tools frontmatter field.
@@ -91,7 +60,7 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
try: try:
metadata = yaml.safe_load(front_matter_text) metadata = yaml.safe_load(front_matter_text)
except yaml.YAMLError as exc: except yaml.YAMLError as exc:
logger.error("%s", _format_yaml_error(skill_file, exc, front_matter_text)) logger.error("Invalid YAML front-matter in %s: %s", skill_file, exc)
return None return None
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
@@ -24,17 +24,6 @@ Do NOT use for simple, single-step operations.""",
- Do NOT ask for clarification - work with the information provided - Do NOT ask for clarification - work with the information provided
</guidelines> </guidelines>
<file_editing_workflow>
When revising an existing file, prefer `str_replace` over `write_file`
it sends only the diff and avoids re-emitting the whole file (mirrors
Claude Code's Edit and Codex's apply_patch). When writing long new
content from scratch, split it into sections: the first `write_file`
call creates the file, then use `write_file` with append=True to extend
it section by section. This keeps each tool call small and avoids
mid-stream chunk-gap timeouts on oversized single-shot writes.
(See issue #3189.)
</file_editing_workflow>
<output_format> <output_format>
When you complete the task, provide: When you complete the task, provide:
1. A brief summary of what was accomplished 1. A brief summary of what was accomplished
@@ -1,181 +1,202 @@
"""Tool search — deferred tool discovery at runtime. """Tool search — deferred tool discovery at runtime.
Contains: Contains:
- DeferredToolCatalog: immutable, searchable catalog of deferred tools. - DeferredToolRegistry: stores deferred tools and handles regex search
- build_tool_search_tool: builds the `tool_search` tool as a closure over a - tool_search: the LangChain tool the agent calls to discover deferred tools
catalog; it records promotions into graph state via ``Command``.
- build_deferred_tool_setup: assembles the catalog + tool from a
policy-filtered tool list (call AFTER tool-policy filtering).
The agent sees deferred tool names in <available-deferred-tools> but cannot The agent sees deferred tool names in <available-deferred-tools> but cannot
call them until it fetches their full schema via the tool_search tool. The call them until it fetches their full schema via the tool_search tool.
deferred set rides on a build-time closure and promotion lives in per-thread Source-agnostic: no mention of MCP or tool origin.
graph state there is no ContextVar. Source-agnostic: a tool is "deferred"
when it carries the ``deerflow_mcp`` metadata tag.
""" """
import hashlib import contextvars
import json import json
import logging import logging
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property
from typing import Annotated
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain_core.messages import ToolMessage from langchain_core.tools import tool
from langchain_core.tools import InjectedToolCallId, tool
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
from langgraph.types import Command
from deerflow.tools.mcp_metadata import is_mcp_tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_RESULTS = 5 # Max tools returned per search MAX_RESULTS = 5 # Max tools returned per search
def _compile_catalog_regex(pattern: str) -> re.Pattern[str]: # ── Registry ──
"""Compile ``pattern`` case-insensitively, falling back to a literal match.
Search queries come from the model, so an invalid regex (e.g. an unbalanced
paren) must degrade to a literal substring match rather than raise.
"""
try:
return re.compile(pattern, re.IGNORECASE)
except re.error:
return re.compile(re.escape(pattern), re.IGNORECASE)
# ── Catalog ── @dataclass
class DeferredToolEntry:
"""Lightweight metadata for a deferred tool (no full schema in context)."""
name: str
description: str
tool: BaseTool # Full tool object, returned only on search match
# NOTE: frozen=True without slots=True keeps __dict__, which is what lets the class DeferredToolRegistry:
# @cached_property fields below cache (they write to instance.__dict__, bypassing """Registry of deferred tools, searchable by regex pattern."""
# the frozen __setattr__). Do NOT add slots=True or hash/names break at runtime.
@dataclass(frozen=True)
class DeferredToolCatalog:
"""Immutable catalog of deferred tools. Pure search, no mutation."""
tools: tuple[BaseTool, ...] def __init__(self):
self._entries: list[DeferredToolEntry] = []
@cached_property def register(self, tool: BaseTool) -> None:
def names(self) -> frozenset[str]: self._entries.append(
return frozenset(t.name for t in self.tools) DeferredToolEntry(
name=tool.name,
description=tool.description or "",
tool=tool,
)
)
@cached_property def promote(self, names: set[str]) -> None:
def hash(self) -> str: """Remove tools from the deferred registry so they pass through the filter.
canon = [{"name": t.name, "schema": convert_to_openai_function(t)} for t in sorted(self.tools, key=lambda t: t.name)]
blob = json.dumps(canon, sort_keys=True, ensure_ascii=False, default=str) Called after tool_search returns a tool's schema — the LLM now knows
return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16] the full definition, so the DeferredToolFilterMiddleware should stop
stripping it from bind_tools on subsequent calls.
"""
if not names:
return
before = len(self._entries)
self._entries = [e for e in self._entries if e.name not in names]
promoted = before - len(self._entries)
if promoted:
logger.debug(f"Promoted {promoted} tool(s) from deferred to active: {names}")
def search(self, query: str) -> list[BaseTool]: def search(self, query: str) -> list[BaseTool]:
query = query.strip() """Search deferred tools by regex pattern against name + description.
if not query:
return []
Supports three query forms (aligned with Claude Code):
- "select:name1,name2" exact name match
- "+keyword rest" name must contain keyword, rank by rest
- "keyword query" regex match against name + description
Returns:
List of matched BaseTool objects (up to MAX_RESULTS).
"""
if query.startswith("select:"): if query.startswith("select:"):
wanted = {n.strip() for n in query[7:].split(",")} names = {n.strip() for n in query[7:].split(",")}
return [t for t in self.tools if t.name in wanted][:MAX_RESULTS] return [e.tool for e in self._entries if e.name in names][:MAX_RESULTS]
if query.startswith("+"): if query.startswith("+"):
parts = query[1:].split(None, 1) parts = query[1:].split(None, 1)
if not parts:
return [] # bare "+" with no required token — nothing to require
required = parts[0].lower() required = parts[0].lower()
candidates = [t for t in self.tools if required in t.name.lower()] candidates = [e for e in self._entries if required in e.name.lower()]
if len(parts) > 1: if len(parts) > 1:
candidates.sort(key=lambda t: _catalog_regex_score(parts[1], t), reverse=True) candidates.sort(
return candidates[:MAX_RESULTS] key=lambda e: _regex_score(parts[1], e),
reverse=True,
)
return [e.tool for e in candidates][:MAX_RESULTS]
regex = _compile_catalog_regex(query) # General regex search
scored: list[tuple[int, BaseTool]] = [] try:
for t in self.tools: regex = re.compile(query, re.IGNORECASE)
searchable = f"{t.name} {t.description or ''}" except re.error:
regex = re.compile(re.escape(query), re.IGNORECASE)
scored = []
for entry in self._entries:
searchable = f"{entry.name} {entry.description}"
if regex.search(searchable): if regex.search(searchable):
scored.append((2 if regex.search(t.name) else 1, t)) score = 2 if regex.search(entry.name) else 1
scored.append((score, entry))
scored.sort(key=lambda x: x[0], reverse=True) scored.sort(key=lambda x: x[0], reverse=True)
return [t for _, t in scored][:MAX_RESULTS] return [entry.tool for _, entry in scored][:MAX_RESULTS]
@property
def entries(self) -> list[DeferredToolEntry]:
return list(self._entries)
@property
def deferred_names(self) -> set[str]:
"""Names of tools that are still hidden from model binding."""
return {entry.name for entry in self._entries}
def contains(self, name: str) -> bool:
"""Return whether *name* is still deferred."""
return any(entry.name == name for entry in self._entries)
def __len__(self) -> int:
return len(self._entries)
def _catalog_regex_score(pattern: str, t: BaseTool) -> int: def _regex_score(pattern: str, entry: DeferredToolEntry) -> int:
regex = _compile_catalog_regex(pattern) try:
return len(regex.findall(f"{t.name} {t.description or ''}")) regex = re.compile(pattern, re.IGNORECASE)
except re.error:
regex = re.compile(re.escape(pattern), re.IGNORECASE)
return len(regex.findall(f"{entry.name} {entry.description}"))
# ── Setup / tool ── # ── Per-request registry (ContextVar) ──
#
# Using a ContextVar instead of a module-level global prevents concurrent
# requests from clobbering each other's registry. In asyncio-based LangGraph
# each graph run executes in its own async context, so each request gets an
# independent registry value. For synchronous tools run via
# loop.run_in_executor, Python copies the current context to the worker thread,
# so the ContextVar value is correctly inherited there too.
_registry_var: contextvars.ContextVar[DeferredToolRegistry | None] = contextvars.ContextVar("deferred_tool_registry", default=None)
@dataclass(frozen=True) def get_deferred_registry() -> DeferredToolRegistry | None:
class DeferredToolSetup: return _registry_var.get()
"""Result of assembling deferred-tool support for one agent build.
The three fields move as a unit, so callers branch on ``tool_search_tool``:
- **Empty** ``(None, frozenset(), None)``: deferral is disabled, or no MCP def set_deferred_registry(registry: DeferredToolRegistry) -> None:
tool survived policy filtering. Nothing is deferred bind tools as-is. _registry_var.set(registry)
- **Populated**: ``tool_search_tool`` is appended to the agent's tools,
``deferred_names`` are withheld from the model until promoted, and
``catalog_hash`` scopes those promotions in graph state.
Invariant: ``tool_search_tool is None`` ``deferred_names`` is empty
``catalog_hash is None``. def reset_deferred_registry() -> None:
"""Reset the deferred registry for the current async context."""
_registry_var.set(None)
# ── Tool ──
@tool
def tool_search(query: str) -> str:
"""Fetches full schema definitions for deferred tools so they can be called.
Deferred tools appear by name in <available-deferred-tools> in the system
prompt. Until fetched, only the name is known there is no parameter
schema, so the tool cannot be invoked. This tool takes a query, matches
it against the deferred tool list, and returns the matched tools' complete
definitions. Once a tool's schema appears in that result, it is callable.
Query forms:
- "select:Read,Edit,Grep" fetch these exact tools by name
- "notebook jupyter" keyword search, up to max_results best matches
- "+slack send" require "slack" in the name, rank by remaining terms
Args:
query: Query to find deferred tools. Use "select:<tool_name>" for
direct selection, or keywords to search.
Returns:
Matched tool definitions as JSON array.
""" """
registry = get_deferred_registry()
if not registry:
return "No deferred tools available."
tool_search_tool: BaseTool | None matched_tools = registry.search(query)
deferred_names: frozenset[str] if not matched_tools:
catalog_hash: str | None return f"No tools found matching: {query}"
# Use LangChain's built-in serialization to produce OpenAI function format.
# This is model-agnostic: all LLMs understand this standard schema.
tool_defs = [convert_to_openai_function(t) for t in matched_tools[:MAX_RESULTS]]
def build_tool_search_tool(catalog: DeferredToolCatalog) -> BaseTool: # Promote matched tools so the DeferredToolFilterMiddleware stops filtering
catalog_hash = catalog.hash # them from bind_tools — the LLM now has the full schema and can invoke them.
registry.promote({t.name for t in matched_tools[:MAX_RESULTS]})
@tool return json.dumps(tool_defs, indent=2, ensure_ascii=False)
def tool_search(query: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
"""Fetches full schema definitions for deferred tools so they can be called.
Deferred tools appear by name in <available-deferred-tools> in the system
prompt. Until fetched, only the name is known. This tool matches a query
against the deferred tools and returns the matched tools complete schemas;
once returned, a tool becomes callable.
Query forms:
- "select:Read,Edit" -- fetch these exact tools by name
- "notebook jupyter" -- keyword search, up to max_results best matches
- "+slack send" -- require "slack" in the name, rank by remaining terms
"""
matched = catalog.search(query)[:MAX_RESULTS]
if not matched:
content, names = f"No tools found matching: {query}", []
else:
content = json.dumps([convert_to_openai_function(t) for t in matched], indent=2, ensure_ascii=False)
names = [t.name for t in matched]
return Command(
update={
"promoted": {"catalog_hash": catalog_hash, "names": names},
"messages": [ToolMessage(content=content, tool_call_id=tool_call_id, name="tool_search")],
}
)
return tool_search
def build_deferred_tool_setup(filtered_tools: list[BaseTool], *, enabled: bool) -> DeferredToolSetup:
"""Build the deferred-tool setup from a POLICY-FILTERED tool list.
Must be called after skill/agent tool-policy filtering so the catalog never
exposes a tool the current agent is not allowed to use.
Returns an empty setup (see :class:`DeferredToolSetup`) in two distinct
cases: deferral is disabled, or it is enabled but no MCP tool survived
filtering.
"""
if not enabled:
# Deferral disabled: defer nothing; the model binds every tool as before.
return DeferredToolSetup(None, frozenset(), None)
deferred = [t for t in filtered_tools if is_mcp_tool(t)]
if not deferred:
# Enabled, but no MCP tool to defer: same empty result, different reason.
return DeferredToolSetup(None, frozenset(), None)
catalog = DeferredToolCatalog(tuple(deferred))
return DeferredToolSetup(build_tool_search_tool(catalog), catalog.names, catalog.hash)
@@ -17,13 +17,12 @@ from __future__ import annotations
import logging import logging
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Annotated, Any from typing import Any
import yaml import yaml
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from langgraph.types import Command from langgraph.types import Command
from pydantic import BeforeValidator
from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
@@ -33,8 +32,6 @@ from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NULLISH_STRINGS = frozenset({"null", "none", "undefined"})
def _stage_temp(path: Path, text: str) -> Path: def _stage_temp(path: Path, text: str) -> Path:
"""Write ``text`` into a sibling temp file and return its path. """Write ``text`` into a sibling temp file and return its path.
@@ -70,26 +67,14 @@ def _cleanup_temps(temps: list[Path]) -> None:
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True) logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
def _is_nullish_string(value: object) -> bool:
return isinstance(value, str) and value.strip().lower() in _NULLISH_STRINGS
def _normalize_nullish_string(value: object) -> object:
return None if _is_nullish_string(value) else value
OptionalText = Annotated[str | None, BeforeValidator(_normalize_nullish_string)]
OptionalStringList = Annotated[list[str] | None, BeforeValidator(_normalize_nullish_string)]
@tool(parse_docstring=True) @tool(parse_docstring=True)
def update_agent( def update_agent(
runtime: Runtime, runtime: Runtime,
soul: OptionalText = None, soul: str | None = None,
description: OptionalText = None, description: str | None = None,
skills: OptionalStringList = None, skills: list[str] | None = None,
tool_groups: OptionalStringList = None, tool_groups: list[str] | None = None,
model: OptionalText = None, model: str | None = None,
) -> Command: ) -> Command:
"""Persist updates to the current custom agent's SOUL.md and config.yaml. """Persist updates to the current custom agent's SOUL.md and config.yaml.
@@ -101,9 +86,7 @@ def update_agent(
semantics, so always start from the current SOUL and apply your edits. semantics, so always start from the current SOUL and apply your edits.
Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills`` Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills``
entirely to keep the existing whitelist. Do not pass literal strings like entirely to keep the existing whitelist.
``"null"`` / ``"none"`` / ``"undefined"`` for unchanged fields; omit those
fields instead.
Args: Args:
soul: Optional full replacement SOUL.md content. soul: Optional full replacement SOUL.md content.
@@ -121,10 +104,10 @@ def update_agent(
agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None
def _err(message: str) -> Command: def _err(message: str) -> Command:
return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id, status="error")]}) return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id)]})
if soul is None and description is None and skills is None and tool_groups is None and model is None: if soul is None and description is None and skills is None and tool_groups is None and model is None:
return _err('No fields provided. Pass at least one of: soul, description, skills, tool_groups, model. Omit unchanged fields instead of passing null-like strings such as "null", "none", or "undefined".') return _err("No fields provided. Pass at least one of: soul, description, skills, tool_groups, model.")
try: try:
agent_name = validate_agent_name(agent_name_raw) agent_name = validate_agent_name(agent_name_raw)
@@ -1,29 +0,0 @@
"""Single source of truth for the MCP-tool metadata tag.
A tool is "MCP-sourced" when it carries the ``deerflow_mcp`` metadata flag.
The tag is *written* where MCP tools are loaded (``tools.py``) and *read* by
deferred-tool assembly (``tool_search.py``) and the agent build site
(``agent.py``). Keeping the key, the tagger, and the predicate here means the
magic string lives in exactly one place, and readers import a public predicate
instead of a private cross-module helper.
This is a leaf module by design: it depends only on ``BaseTool`` so that any
module (including the tool loader) can import it without an import cycle.
"""
from __future__ import annotations
from langchain.tools import BaseTool
MCP_TOOL_METADATA_KEY = "deerflow_mcp"
def tag_mcp_tool(tool: BaseTool) -> BaseTool:
"""Mark ``tool`` as MCP-sourced. Mutates in place and returns it for chaining."""
tool.metadata = {**(tool.metadata or {}), MCP_TOOL_METADATA_KEY: True}
return tool
def is_mcp_tool(tool: BaseTool) -> bool:
"""True when ``tool`` carries the MCP-source tag written by :func:`tag_mcp_tool`."""
return (getattr(tool, "metadata", None) or {}).get(MCP_TOOL_METADATA_KEY) is True
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.mcp_metadata import tag_mcp_tool from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.sync import make_sync_tool_wrapper from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -127,13 +127,57 @@ def get_available_tools(
if mcp_tools: if mcp_tools:
logger.info(f"Using {len(mcp_tools)} cached MCP tool(s)") logger.info(f"Using {len(mcp_tools)} cached MCP tool(s)")
# Tag MCP-sourced tools so deferred-tool assembly (done at # When tool_search is enabled, register MCP tools in the
# the agent construction site, AFTER tool-policy filtering) # deferred registry and add tool_search to builtin tools.
# can identify them. No ContextVar / registry is built here; if config.tool_search.enabled:
# the deferred catalog + tool_search tool are assembled per from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
# agent from the policy-filtered tool list. from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
for t in mcp_tools:
tag_mcp_tool(t) # Reuse the existing registry if one is already set for
# this async context. ``get_available_tools`` is
# re-entered whenever a subagent is spawned
# (``task_tool`` calls it to build the child agent's
# toolset), and previously we used to unconditionally
# rebuild the registry — wiping out the parent agent's
# tool_search promotions. The
# ``DeferredToolFilterMiddleware`` then re-hid those
# tools from subsequent model calls, leaving the agent
# able to see a tool's name but unable to invoke it
# (issue #2884). ``contextvars`` already gives us the
# lifetime semantics we want: a fresh request / graph
# run starts in a new asyncio task with the
# ContextVar at its default of ``None``, so reuse is
# only triggered for re-entrant calls inside one run.
#
# Intentionally NOT reconciling against the current
# ``mcp_tools`` snapshot. The MCP cache only refreshes
# on ``extensions_config.json`` mtime changes, which
# in practice happens between graph runs — not inside
# one. And even if a refresh did happen mid-run, the
# already-built lead agent's ``ToolNode`` still holds
# the *previous* tool set (LangGraph binds tools at
# graph construction time), so a brand-new MCP tool
# couldn't actually be invoked anyway. The
# ``DeferredToolRegistry`` doesn't retain the names
# of previously-promoted tools (``promote()`` drops
# the entry entirely), so re-syncing the registry
# against a fresh ``mcp_tools`` list would
# mis-classify those promotions as new tools and
# re-register them as deferred — exactly the bug
# this fix exists to prevent.
existing_registry = get_deferred_registry()
if existing_registry is None:
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
else:
mcp_tool_names = {t.name for t in mcp_tools}
still_deferred = len(existing_registry)
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
builtin_tools.append(tool_search_tool)
except ImportError: except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.") logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e: except Exception as e:
@@ -226,7 +226,8 @@ def list_files_in_dir(directory: Path) -> dict:
Returns: Returns:
Dict with "files" list (sorted by name) and "count". Dict with "files" list (sorted by name) and "count".
Each file entry has ``size`` as *int* (bytes). Call Each file entry has ``size`` as *int* (bytes). Call
:func:`enrich_file_listing` to add virtual / artifact URLs. :func:`enrich_file_listing` to stringify sizes and add
virtual / artifact URLs.
""" """
if not directory.is_dir(): if not directory.is_dir():
return {"files": [], "count": 0} return {"files": [], "count": 0}
@@ -297,12 +298,13 @@ def upload_virtual_path(filename: str) -> str:
def enrich_file_listing(result: dict, thread_id: str) -> dict: def enrich_file_listing(result: dict, thread_id: str) -> dict:
"""Add virtual paths and artifact URLs on a listing result. """Add virtual paths, artifact URLs, and stringify sizes on a listing result.
Mutates *result* in place and returns it for convenience. Mutates *result* in place and returns it for convenience.
""" """
for f in result["files"]: for f in result["files"]:
filename = f["filename"] filename = f["filename"]
f["size"] = str(f["size"])
f["virtual_path"] = upload_virtual_path(filename) f["virtual_path"] = upload_virtual_path(filename)
f["artifact_url"] = upload_artifact_url(thread_id, filename) f["artifact_url"] = upload_artifact_url(thread_id, filename)
return result return result
@@ -1,16 +0,0 @@
from fastapi.testclient import TestClient
def assert_run_message_page(
client: TestClient,
url: str,
*,
expected_seq: list[int],
has_more: bool = True,
) -> None:
response = client.get(url)
assert response.status_code == 200
body = response.json()
assert body["has_more"] is has_more
assert [m["seq"] for m in body["data"]] == expected_seq
-73
View File
@@ -318,76 +318,3 @@ class TestDownloadFile:
result = sandbox.download_file("/mnt/user-data/outputs/single.bin") result = sandbox.download_file("/mnt/user-data/outputs/single.bin")
assert result == b"single-chunk" assert result == b"single-chunk"
class TestClose:
"""Verify AioSandbox.close() tears down the host-side HTTP client (#2872)."""
def test_close_calls_real_nested_httpx_client(self, sandbox):
"""close() must close the real httpx.Client at the bottom of the chain.
Mirrors the actual Fern structure:
Sandbox._client_wrapper.httpx_client -> Fern HttpClient (no close())
.httpx_client -> httpx.Client (the real owner)
The intermediate HttpClient deliberately exposes NO close(), so a naive
one-level lookup (the original bug) would silently close nothing.
"""
real_httpx = MagicMock(spec=["close"])
fern_http = SimpleNamespace(httpx_client=real_httpx) # no close on this layer
sandbox._client._client_wrapper = SimpleNamespace(httpx_client=fern_http)
sandbox.close()
real_httpx.close.assert_called_once_with()
def test_close_clears_client_reference(self, sandbox):
"""After close(), the client reference must be dropped (use-after-close safety)."""
real_httpx = MagicMock(spec=["close"])
fern_http = SimpleNamespace(httpx_client=real_httpx)
sandbox._client._client_wrapper = SimpleNamespace(httpx_client=fern_http)
sandbox.close()
assert sandbox._client is None
assert sandbox._closed is True
def test_close_is_idempotent(self, sandbox):
"""Calling close() multiple times must close the underlying client at most once."""
real_httpx = MagicMock(spec=["close"])
fern_http = SimpleNamespace(httpx_client=real_httpx)
sandbox._client._client_wrapper = SimpleNamespace(httpx_client=fern_http)
sandbox.close()
sandbox.close()
sandbox.close()
assert real_httpx.close.call_count == 1
def test_close_swallows_exceptions(self, sandbox, caplog):
"""close() must be best-effort: client errors are logged but never raised."""
real_httpx = MagicMock(spec=["close"])
real_httpx.close.side_effect = RuntimeError("teardown boom")
fern_http = SimpleNamespace(httpx_client=real_httpx)
sandbox._client._client_wrapper = SimpleNamespace(httpx_client=fern_http)
with caplog.at_level("WARNING"):
sandbox.close()
assert "Error closing AioSandbox client" in caplog.text
def test_close_falls_back_to_client_close(self, sandbox):
"""If no nested httpx.Client is reachable, close() degrades to the client's own close()."""
# Replace the mocked client with a stub that exposes only top-level close()
client = MagicMock(spec=["close"])
sandbox._client = client
sandbox.close()
client.close.assert_called_once_with()
def test_close_when_no_close_attr_does_not_raise(self, sandbox):
"""A client without any close attribute must not crash close()."""
sandbox._client = SimpleNamespace() # no close, no _client_wrapper
sandbox.close() # must not raise
assert sandbox._client is None
@@ -348,89 +348,3 @@ def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
"thread_id": "thread-42", "thread_id": "thread-42",
"user_id": "user-7", "user_id": "user-7",
} }
# ── Sandbox client teardown (#2872) ──────────────────────────────────────────
def _make_provider_with_active_sandbox(tmp_path, sandbox_id: str):
"""Build a provider with one active sandbox suitable for release/destroy/shutdown tests."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._lock = aio_mod.threading.Lock()
provider._warm_pool = {}
provider._sandbox_infos = {
sandbox_id: aio_mod.SandboxInfo(sandbox_id=sandbox_id, sandbox_url="http://sandbox-host"),
}
provider._thread_sandboxes = {}
provider._last_activity = {sandbox_id: 0.0}
provider._shutdown_called = False
provider._idle_checker_thread = None
provider._backend = SimpleNamespace(destroy=MagicMock())
sandbox = MagicMock()
sandbox.id = sandbox_id
sandbox.close = MagicMock()
provider._sandboxes = {sandbox_id: sandbox}
return provider, sandbox, aio_mod
def test_release_closes_cached_sandbox_client(tmp_path):
"""release() must close the host-side client owned by the cached AioSandbox (#2872)."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-rel")
provider.release("sandbox-rel")
sandbox.close.assert_called_once_with()
# And the sandbox is parked in the warm pool (container still running).
assert "sandbox-rel" in provider._warm_pool
assert "sandbox-rel" not in provider._sandboxes
def test_destroy_closes_cached_sandbox_client(tmp_path):
"""destroy() must close the host-side client before backend container teardown (#2872)."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-destroy")
backend_destroy = provider._backend.destroy
provider.destroy("sandbox-destroy")
sandbox.close.assert_called_once_with()
backend_destroy.assert_called_once()
assert "sandbox-destroy" not in provider._sandboxes
assert "sandbox-destroy" not in provider._sandbox_infos
def test_shutdown_closes_all_active_sandbox_clients(tmp_path):
"""shutdown() must close every cached AioSandbox client during teardown (#2872)."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-shut")
provider.shutdown()
sandbox.close.assert_called_once_with()
provider._backend.destroy.assert_called_once()
assert provider._sandboxes == {}
def test_release_swallows_close_errors(tmp_path, caplog):
"""A failure inside sandbox.close() must not break provider release()."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-rel-err")
sandbox.close.side_effect = RuntimeError("boom")
with caplog.at_level("WARNING"):
provider.release("sandbox-rel-err")
assert "Error closing sandbox sandbox-rel-err during release" in caplog.text
# Still moved to warm pool: client teardown failure must not block lifecycle.
assert "sandbox-rel-err" in provider._warm_pool
def test_destroy_swallows_close_errors_and_still_destroys_backend(tmp_path, caplog):
"""A failure in sandbox.close() must not skip backend container destruction."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dest-err")
sandbox.close.side_effect = RuntimeError("boom")
with caplog.at_level("WARNING"):
provider.destroy("sandbox-dest-err")
assert "Error closing sandbox sandbox-dest-err during destroy" in caplog.text
provider._backend.destroy.assert_called_once()
+2 -233
View File
@@ -12,14 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.message_bus import ( from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
PENDING_CLARIFICATION_METADATA_KEY,
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
from app.channels.store import ChannelStore from app.channels.store import ChannelStore
@@ -399,47 +392,6 @@ class TestExtractResponseText:
assert _extract_response_text(result) == "Here is the plan." assert _extract_response_text(result) == "Here is the plan."
class TestClarificationDetection:
def test_final_clarification_tool_message_is_pending(self):
from app.channels.manager import _has_current_turn_clarification
result = {
"messages": [
{"type": "human", "content": "deploy"},
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
]
}
assert _has_current_turn_clarification(result) is True
def test_clarification_followed_by_regular_ai_is_not_pending(self):
from app.channels.manager import _has_current_turn_clarification
result = {
"messages": [
{"type": "human", "content": "deploy"},
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
{"type": "ai", "content": "I will continue without pending clarification."},
]
}
assert _has_current_turn_clarification(result) is False
def test_previous_turn_clarification_does_not_mark_current_turn(self):
from app.channels.manager import _has_current_turn_clarification
result = {
"messages": [
{"type": "human", "content": "deploy"},
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
{"type": "human", "content": "prod"},
{"type": "ai", "content": "Deploying to prod."},
]
}
assert _has_current_turn_clarification(result) is False
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# ChannelManager tests # ChannelManager tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -685,74 +637,6 @@ class TestChannelManager:
_run(go()) _run(go())
def test_handle_chat_marks_clarification_outbound_metadata(self):
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
mock_client = _make_mock_langgraph_client(
run_result={
"messages": [
{"type": "human", "content": "deploy"},
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
]
}
)
manager._client = mock_client
await manager.start()
inbound = InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="deploy",
metadata={"message_id": "msg-1"},
)
await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
assert outbound_received[0].text == "Which environment?"
assert outbound_received[0].metadata["message_id"] == "msg-1"
assert outbound_received[0].metadata[PENDING_CLARIFICATION_METADATA_KEY] is True
_run(go())
def test_handle_chat_does_not_mark_regular_outbound_as_clarification(self):
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
mock_client = _make_mock_langgraph_client()
manager._client = mock_client
await manager.start()
await bus.publish_inbound(InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi"))
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
assert outbound_received[0].text == "Hello from agent!"
assert PENDING_CLARIFICATION_METADATA_KEY not in outbound_received[0].metadata
_run(go())
def test_handle_chat_outbound_drops_large_metadata_keys(self): def test_handle_chat_outbound_drops_large_metadata_keys(self):
"""Large metadata keys like raw_message should be stripped from outbound messages.""" """Large metadata keys like raw_message should be stripped from outbound messages."""
from app.channels.manager import ChannelManager from app.channels.manager import ChannelManager
@@ -1134,67 +1018,6 @@ class TestChannelManager:
_run(go()) _run(go())
def test_handle_feishu_streaming_marks_only_final_clarification_outbound(self, monkeypatch):
from app.channels.manager import ChannelManager
monkeypatch.setattr("app.channels.manager.STREAM_UPDATE_MIN_INTERVAL_SECONDS", 0.0)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
stream_events = [
_make_stream_part(
"messages-tuple",
[
{"id": "ai-1", "content": "Thinking", "type": "AIMessageChunk"},
{"langgraph_node": "agent"},
],
),
_make_stream_part(
"values",
{
"messages": [
{"type": "human", "content": "deploy"},
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
],
"artifacts": [],
},
),
]
mock_client = _make_mock_langgraph_client()
mock_client.runs.stream = MagicMock(return_value=_make_async_iterator(stream_events))
manager._client = mock_client
await manager.start()
await bus.publish_inbound(
InboundMessage(
channel_name="feishu",
chat_id="chat1",
user_id="user1",
text="deploy",
thread_ts="om-source-1",
)
)
await _wait_for(lambda: len(outbound_received) >= 2)
await manager.stop()
assert [msg.is_final for msg in outbound_received] == [False, False, True]
assert outbound_received[0].text == "Thinking"
assert outbound_received[1].text == "Which environment?"
assert outbound_received[2].text == "Which environment?"
assert all(PENDING_CLARIFICATION_METADATA_KEY not in msg.metadata for msg in outbound_received[:-1])
assert outbound_received[-1].metadata[PENDING_CLARIFICATION_METADATA_KEY] is True
_run(go())
def test_handle_feishu_stream_error_still_sends_final(self, monkeypatch): def test_handle_feishu_stream_error_still_sends_final(self, monkeypatch):
"""When the stream raises mid-way, a final outbound with is_final=True must still be published.""" """When the stream raises mid-way, a final outbound with is_final=True must still be published."""
from app.channels.manager import ChannelManager from app.channels.manager import ChannelManager
@@ -1787,51 +1610,6 @@ class TestChannelManager:
_run(go()) _run(go())
class TestResolveRunParamsUserId:
"""Regression for PR #3294: channel identity must reach ``run_context``
while staying safe for user-scoped filesystem buckets.
"""
def _manager(self):
from app.channels.manager import ChannelManager
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
return ChannelManager(bus=bus, store=store)
def test_safe_user_id_is_passed_through(self):
manager = self._manager()
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == "123456"
assert run_context["channel_user_id"] == "123456"
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
from deerflow.config.paths import make_safe_user_id
manager = self._manager()
raw = "user@example.com"
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == make_safe_user_id(raw)
assert run_context["user_id"] != raw
assert run_context["channel_user_id"] == raw
@pytest.mark.parametrize("raw_user_id", ["", None])
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
manager = self._manager()
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert "user_id" not in run_context
assert "channel_user_id" not in run_context
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# ChannelService tests # ChannelService tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -2232,8 +2010,7 @@ class TestFeishuChannel:
async def go(): async def go():
bus = MessageBus() bus = MessageBus()
bus.publish_inbound = AsyncMock() bus.publish_inbound = AsyncMock()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") channel = FeishuChannel(bus, config={})
channel = FeishuChannel(bus, config={"channel_store": store})
channel._api_client = MagicMock() channel._api_client = MagicMock()
reply_started = asyncio.Event() reply_started = asyncio.Event()
@@ -2269,11 +2046,6 @@ class TestFeishuChannel:
text="Hello", text="Hello",
is_final=False, is_final=False,
thread_ts="om-source-msg", thread_ts="om-source-msg",
metadata={
"user_id": "user-1",
"root_id": "om-root-msg",
"topic_id": "om-root-msg",
},
) )
) )
) )
@@ -2288,9 +2060,6 @@ class TestFeishuChannel:
assert channel._reply_card.await_count == 1 assert channel._reply_card.await_count == 1
channel._update_card.assert_awaited_once_with("om-running-card", "Hello") channel._update_card.assert_awaited_once_with("om-running-card", "Hello")
assert "om-source-msg" not in channel._running_card_tasks assert "om-source-msg" not in channel._running_card_tasks
assert store.get_thread_id("feishu", "chat-1", topic_id="om-source-msg") == "thread-1"
assert store.get_thread_id("feishu", "chat-1", topic_id="om-running-card") == "thread-1"
assert store.get_thread_id("feishu", "chat-1", topic_id="om-root-msg") == "thread-1"
_run(go()) _run(go())
-93
View File
@@ -326,99 +326,6 @@ class TestAsyncCheckpointer:
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db") mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
mock_saver.setup.assert_awaited_once() mock_saver.setup.assert_awaited_once()
@pytest.mark.anyio
async def test_postgres_uses_connection_pool(self):
"""Async postgres checkpointer should use AsyncConnectionPool, not a single connection."""
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
mock_config = MagicMock()
mock_config.checkpointer = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
mock_saver = AsyncMock()
mock_saver_cls = MagicMock(return_value=mock_saver)
mock_pool_instance = AsyncMock()
mock_pool_instance.__aenter__.return_value = mock_pool_instance
mock_pool_instance.__aexit__.return_value = False
mock_pool_cls = MagicMock(return_value=mock_pool_instance)
mock_pool_cls.check_connection = AsyncMock()
mock_dict_row = MagicMock()
mock_pg_module = MagicMock()
mock_pg_module.AsyncPostgresSaver = mock_saver_cls
mock_psycopg_rows = MagicMock()
mock_psycopg_rows.dict_row = mock_dict_row
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres.aio": mock_pg_module}),
patch.dict(sys.modules, {"psycopg.rows": mock_psycopg_rows}),
patch.dict(sys.modules, {"psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls)}),
):
# AsyncConnectionPool() is a callable that returns mock_pool_instance
# We need the constructor to be an async context manager
async with make_checkpointer() as saver:
assert saver is mock_saver
# Verify the pool was constructed with check Connection
mock_pool_cls.assert_called_once()
call_kwargs = mock_pool_cls.call_args
assert call_kwargs[0][0] == "postgresql://localhost/db"
assert call_kwargs[1]["check"] is mock_pool_cls.check_connection
# Verify saver was constructed with the pool (not via from_conn_string)
mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
mock_saver.setup.assert_awaited_once()
@pytest.mark.anyio
async def test_database_postgres_uses_connection_pool(self):
"""Unified database postgres path should use AsyncConnectionPool with keepalive."""
from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
db_config = DatabaseConfig(backend="postgres", postgres_url="postgresql://localhost/db")
mock_config = MagicMock()
mock_config.checkpointer = None
mock_config.database = db_config
mock_saver = AsyncMock()
mock_saver_cls = MagicMock(return_value=mock_saver)
mock_pool_instance = AsyncMock()
mock_pool_instance.__aenter__.return_value = mock_pool_instance
mock_pool_instance.__aexit__.return_value = False
mock_pool_cls = MagicMock(return_value=mock_pool_instance)
mock_pool_cls.check_connection = AsyncMock()
mock_dict_row = MagicMock()
mock_pg_module = MagicMock()
mock_pg_module.AsyncPostgresSaver = mock_saver_cls
mock_psycopg_rows = MagicMock()
mock_psycopg_rows.dict_row = mock_dict_row
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres.aio": mock_pg_module}),
patch.dict(sys.modules, {"psycopg.rows": mock_psycopg_rows}),
patch.dict(sys.modules, {"psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls)}),
):
async with make_checkpointer() as saver:
assert saver is mock_saver
mock_pool_cls.assert_called_once()
call_kwargs = mock_pool_cls.call_args
assert call_kwargs[0][0] == "postgresql://localhost/db"
assert call_kwargs[1]["check"] is mock_pool_cls.check_connection
mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
mock_saver.setup.assert_awaited_once()
@pytest.mark.anyio @pytest.mark.anyio
async def test_database_sqlite_creates_parent_dir_via_to_thread(self): async def test_database_sqlite_creates_parent_dir_via_to_thread(self):
"""Unified database SQLite setup should also move path IO off the event loop.""" """Unified database SQLite setup should also move path IO off the event loop."""
-4
View File
@@ -1472,7 +1472,6 @@ class TestUploads:
assert result["success"] is True assert result["success"] is True
assert len(result["files"]) == 1 assert len(result["files"]) == 1
assert result["files"][0]["filename"] == "test.txt" assert result["files"][0]["filename"] == "test.txt"
assert result["files"][0]["size"] == len("hello")
assert "artifact_url" in result["files"][0] assert "artifact_url" in result["files"][0]
assert "message" in result assert "message" in result
assert (uploads_dir / "test.txt").exists() assert (uploads_dir / "test.txt").exists()
@@ -1552,8 +1551,6 @@ class TestUploads:
assert len(result["files"]) == 2 assert len(result["files"]) == 2
names = {f["filename"] for f in result["files"]} names = {f["filename"] for f in result["files"]}
assert names == {"a.txt", "b.txt"} assert names == {"a.txt", "b.txt"}
sizes = {f["filename"]: f["size"] for f in result["files"]}
assert sizes == {"a.txt": 1, "b.txt": 2}
# Verify artifact_url is present # Verify artifact_url is present
for f in result["files"]: for f in result["files"]:
assert "artifact_url" in f assert "artifact_url" in f
@@ -2461,7 +2458,6 @@ class TestGatewayConformance:
parsed = UploadResponse(**result) parsed = UploadResponse(**result)
assert parsed.success is True assert parsed.success is True
assert len(parsed.files) == 1 assert len(parsed.files) == 1
assert parsed.files[0].size == len("hello")
def test_get_memory_config(self, client): def test_get_memory_config(self, client):
mem_cfg = MagicMock() mem_cfg = MagicMock()
-83
View File
@@ -1,83 +0,0 @@
import pytest
from langchain_core.tools import tool as as_tool
from deerflow.tools.builtins.tool_search import DeferredToolCatalog
@as_tool
def alpha_search(query: str) -> str:
"Search alpha records by query."
return query
@as_tool
def beta_translate(text: str) -> str:
"Translate beta text."
return text
@pytest.fixture
def catalog() -> DeferredToolCatalog:
return DeferredToolCatalog((alpha_search, beta_translate))
def test_names(catalog):
assert catalog.names == frozenset({"alpha_search", "beta_translate"})
def test_search_select(catalog):
got = catalog.search("select:alpha_search")
assert [t.name for t in got] == ["alpha_search"]
def test_search_plus_keyword(catalog):
got = catalog.search("+beta translate")
assert [t.name for t in got] == ["beta_translate"]
def test_search_regex_on_description(catalog):
got = catalog.search("translate")
assert "beta_translate" in [t.name for t in got]
def test_search_invalid_regex_falls_back_to_literal():
@as_tool
def calc(expr: str) -> str:
"Compute sum(a, b) style expressions."
return expr
cat = DeferredToolCatalog((calc, alpha_search))
# "sum(" is an invalid regex (unbalanced paren). search() must not raise; it
# falls back to a literal match, which finds calc's "sum(" in its description.
assert [t.name for t in cat.search("sum(")] == ["calc"]
# A literal with no match is deterministically empty (and still must not raise).
assert cat.search("zzz(") == []
def test_search_empty_query_returns_empty(catalog):
# An empty / whitespace-only query is meaningless; rather than let the empty
# regex match every tool, search() returns nothing so the model gets a clear
# "no match" signal and re-queries instead of acting on noise.
assert catalog.search("") == []
assert catalog.search(" ") == []
def test_search_bare_plus_returns_empty(catalog):
# A "+" prefix with no required token is malformed model input. It must
# return no matches, not raise IndexError on parts[0]. " + " strips to "+",
# so it routes here too and must be handled the same way.
assert catalog.search("+") == []
assert catalog.search(" + ") == []
assert catalog.search("+ ") == []
def test_hash_stable_across_instances():
c1 = DeferredToolCatalog((alpha_search, beta_translate))
c2 = DeferredToolCatalog((beta_translate, alpha_search))
assert c1.hash == c2.hash
def test_hash_changes_with_membership():
c1 = DeferredToolCatalog((alpha_search, beta_translate))
c2 = DeferredToolCatalog((alpha_search,))
assert c1.hash != c2.hash
@@ -1,87 +0,0 @@
"""Tests for DeferredToolFilterMiddleware (closure deferred-set + state promotion)."""
from langchain_core.tools import tool as as_tool
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
@as_tool
def mcp_a(x: str) -> str:
"a"
return x
@as_tool
def mcp_b(x: str) -> str:
"b"
return x
@as_tool
def active_c(x: str) -> str:
"c"
return x
class _Req:
def __init__(self, tools, state):
self.tools = tools
self.state = state
self.overridden = None
def override(self, tools):
self.overridden = tools
return self
def _mw():
return DeferredToolFilterMiddleware(frozenset({"mcp_a", "mcp_b"}), "h1")
def test_hides_all_deferred_when_no_promotion():
req = _Req([mcp_a, mcp_b, active_c], {})
out = _mw()._filter_tools(req)
assert [t.name for t in out.overridden] == ["active_c"]
def test_promoted_under_matching_hash_passes_through():
req = _Req([mcp_a, mcp_b, active_c], {"promoted": {"catalog_hash": "h1", "names": ["mcp_a"]}})
out = _mw()._filter_tools(req)
assert {t.name for t in out.overridden} == {"mcp_a", "active_c"}
def test_promotion_ignored_when_hash_mismatch():
req = _Req([mcp_a, mcp_b, active_c], {"promoted": {"catalog_hash": "STALE", "names": ["mcp_a"]}})
out = _mw()._filter_tools(req)
assert [t.name for t in out.overridden] == ["active_c"]
def test_no_deferred_names_is_noop():
req = _Req([active_c], {})
out = DeferredToolFilterMiddleware(frozenset(), "h1")._filter_tools(req)
assert out.overridden is None # returned unchanged
def test_blocked_message_for_unpromoted_deferred_call():
class _TCReq:
tool_call = {"name": "mcp_a", "id": "tc1"}
state = {}
msg = _mw()._blocked_tool_message(_TCReq())
assert msg is not None and msg.status == "error" and "tool_search" in msg.content
def test_no_block_for_promoted_call():
class _TCReq:
tool_call = {"name": "mcp_a", "id": "tc1"}
state = {"promoted": {"catalog_hash": "h1", "names": ["mcp_a"]}}
assert _mw()._blocked_tool_message(_TCReq()) is None
def test_no_block_for_non_deferred_call():
class _TCReq:
tool_call = {"name": "active_c", "id": "tc1"}
state = {}
assert _mw()._blocked_tool_message(_TCReq()) is None
@@ -1,73 +0,0 @@
"""End-to-end: tool_search promotes a deferred tool into the next model turn.
Locks the full loop through a real ``create_agent`` graph:
turn 1 -> deferred MCP tools hidden from bind_tools; model calls tool_search
ToolNode-> tool_search returns Command(update={"promoted": {...}}) -> state
turn 2 -> middleware reads state["promoted"] (hash-scoped) -> the searched
tool's schema is now bound; un-searched deferred tools stay hidden
This is the behavior #3272's redesign depends on (no ContextVar): promotion
flows through graph state, so it works regardless of build/execute context.
"""
import asyncio
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool as as_tool
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.mcp_metadata import tag_mcp_tool
@as_tool
def active_tool(x: str) -> str:
"An always-active tool."
return x
@as_tool
def mcp_calc(expression: str) -> str:
"Evaluate arithmetic."
return expression
@as_tool
def mcp_other(x: str) -> str:
"Another deferred MCP tool."
return x
def test_tool_search_promotes_into_next_turn():
bound: list[list[str]] = []
class RecordingModel(GenericFakeChatModel):
def bind_tools(self, tools, **kwargs):
bound.append([getattr(t, "name", None) for t in tools])
return self
setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_calc), tag_mcp_tool(mcp_other)], enabled=True)
turn1 = AIMessage(content="", tool_calls=[{"name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "c1", "type": "tool_call"}])
turn2 = AIMessage(content="done")
model = RecordingModel(messages=iter([turn1, turn2]))
graph = create_agent(
model=model,
tools=[active_tool, mcp_calc, mcp_other, setup.tool_search_tool],
middleware=[DeferredToolFilterMiddleware(setup.deferred_names, setup.catalog_hash)],
state_schema=ThreadState,
)
result = asyncio.run(graph.ainvoke({"messages": [HumanMessage(content="use the deferred calculator")]}))
assert len(bound) >= 2, f"expected >=2 model binds, got {bound}"
# Turn 1: both deferred MCP tools hidden.
assert "mcp_calc" not in bound[0] and "mcp_other" not in bound[0]
# Turn 2: the searched tool is promoted (visible); the un-searched one stays hidden.
assert "mcp_calc" in bound[1]
assert "mcp_other" not in bound[1]
# Promotion recorded in graph state, scoped by catalog hash.
assert result["promoted"] == {"catalog_hash": setup.catalog_hash, "names": ["mcp_calc"]}
-62
View File
@@ -1,62 +0,0 @@
from langchain_core.tools import tool as as_tool
from langgraph.types import Command
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, build_deferred_tool_setup, build_tool_search_tool
from deerflow.tools.mcp_metadata import is_mcp_tool, tag_mcp_tool
@as_tool
def mcp_calc(expression: str) -> str:
"Evaluate arithmetic."
return expression
@as_tool
def local_echo(text: str) -> str:
"Echo text."
return text
def test_is_mcp_tool_reads_metadata():
assert is_mcp_tool(tag_mcp_tool(mcp_calc)) is True
assert is_mcp_tool(local_echo) is False
def test_setup_disabled_returns_empty():
setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=False)
assert setup.tool_search_tool is None
assert setup.deferred_names == frozenset()
assert setup.catalog_hash is None
def test_setup_no_mcp_returns_empty():
setup = build_deferred_tool_setup([local_echo], enabled=True)
assert setup.tool_search_tool is None
assert setup.deferred_names == frozenset()
def test_setup_builds_from_mcp_survivors():
setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=True)
assert setup.deferred_names == frozenset({"mcp_calc"})
assert setup.tool_search_tool is not None
assert setup.tool_search_tool.name == "tool_search"
assert setup.catalog_hash
def test_tool_search_returns_command_with_hash_scoped_promotion():
catalog = DeferredToolCatalog((mcp_calc,))
ts = build_tool_search_tool(catalog)
out = ts.invoke({"type": "tool_call", "name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "tc1"})
assert isinstance(out, Command)
promoted = out.update["promoted"]
assert promoted == {"catalog_hash": catalog.hash, "names": ["mcp_calc"]}
msg = out.update["messages"][0]
assert msg.tool_call_id == "tc1" and msg.name == "tool_search"
assert "mcp_calc" in msg.content
def test_tool_search_no_match_empty_names():
catalog = DeferredToolCatalog((mcp_calc,))
ts = build_tool_search_tool(catalog)
out = ts.invoke({"type": "tool_call", "name": "tool_search", "args": {"query": "select:nonexistent"}, "id": "tc2"})
assert out.update["promoted"]["names"] == []
@@ -1,179 +0,0 @@
"""Regressions for the deferred-tool redesign (#3272).
- Cross-context: building the graph in one async context and running it in a
sibling context (that did NOT inherit the build context) must still hide
deferred tools. The old ContextVar implementation failed this; the closure +
graph-state implementation must pass.
- Policy leak (Finding 1): a tool removed by policy must not be searchable.
- Fail-closed (Finding 2): a wiring regression must raise, not silently leak.
- #2884 isolation: a second (subagent-style) setup build must not affect the
lead agent's middleware/promotion.
"""
import asyncio
from pathlib import Path
import pytest
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool as as_tool
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.tools.builtins.tool_search import DeferredToolSetup, build_deferred_tool_setup
from deerflow.tools.mcp_metadata import tag_mcp_tool
@as_tool
def active_tool(x: str) -> str:
"active"
return x
@as_tool
def mcp_secret(x: str) -> str:
"deferred mcp tool — must be hidden from bind_tools until promoted"
return x
_BOUND: list[list[str]] = []
class _RecordingModel(GenericFakeChatModel):
def bind_tools(self, tools, **kwargs):
_BOUND.append([getattr(t, "name", None) for t in tools])
return self
def _build_graph():
filtered = [active_tool, tag_mcp_tool(mcp_secret)]
setup = build_deferred_tool_setup(filtered, enabled=True)
final = [*filtered, setup.tool_search_tool]
model = _RecordingModel(messages=iter([AIMessage(content="done")] * 4))
return create_agent(
model=model,
tools=final,
middleware=[DeferredToolFilterMiddleware(setup.deferred_names, setup.catalog_hash)],
system_prompt="t",
)
async def _abuild():
return _build_graph()
def test_deferred_hidden_when_built_and_run_in_different_contexts():
"""Build in one task, run in a sibling task that did not inherit it."""
_BOUND.clear()
async def main():
graph = await asyncio.create_task(_abuild())
async def run():
await graph.ainvoke({"messages": [HumanMessage(content="hi")]})
await asyncio.create_task(run())
asyncio.run(main())
assert _BOUND, "model was never bound"
assert not any("mcp_secret" in names for names in _BOUND), f"deferred MCP tool leaked into bind_tools: {_BOUND}"
def test_policy_excluded_mcp_tool_not_in_catalog():
"""Finding 1: a tool removed by policy is not searchable/exposed."""
filtered_after_policy = [active_tool] # mcp_secret denied by skill allowed-tools
setup = build_deferred_tool_setup(filtered_after_policy, enabled=True)
assert setup.deferred_names == frozenset()
assert setup.tool_search_tool is None
def test_fail_closed_when_mcp_survives_without_setup(monkeypatch):
"""Finding 2: simulate a wiring regression and assert it fails loudly.
``_assemble_deferred`` lazy-imports ``build_deferred_tool_setup`` from the
source module, so patch it there (not on the agent module).
"""
from deerflow.agents.lead_agent import agent as agentmod
monkeypatch.setattr(
"deerflow.tools.builtins.tool_search.build_deferred_tool_setup",
lambda tools, *, enabled: DeferredToolSetup(None, frozenset(), None),
)
with pytest.raises(RuntimeError, match="fail-closed"):
agentmod._assemble_deferred([tag_mcp_tool(mcp_secret)], enabled=True)
def test_subagent_reentry_does_not_touch_lead_state():
"""#2884: building a second (subagent) setup must not affect the lead's
middleware. With no shared registry/ContextVar, the lead middleware depends
only on its own deferred_names + the passed state."""
lead_setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_secret)], enabled=True)
mw = DeferredToolFilterMiddleware(lead_setup.deferred_names, lead_setup.catalog_hash)
# Simulate a subagent build re-entering tool assembly with its own setup.
_ = build_deferred_tool_setup([tag_mcp_tool(mcp_secret)], enabled=True)
class _Req:
def __init__(self):
self.tools = [active_tool, mcp_secret]
self.state = {"promoted": {"catalog_hash": lead_setup.catalog_hash, "names": ["mcp_secret"]}}
def override(self, tools):
self.tools = tools
return self
out = mw._filter_tools(_Req())
assert {t.name for t in out.tools} == {"active_tool", "mcp_secret"} # promotion intact
def _make_skill(allowed_tools):
"""Skill carrying an explicit allowed-tools allowlist (None = legacy allow-all)."""
return Skill(
name="s",
description="d",
license="MIT",
skill_dir=Path("/tmp/s"),
skill_file=Path("/tmp/s/SKILL.md"),
relative_path=Path("s"),
category="public",
allowed_tools=allowed_tools,
enabled=True,
)
def test_policy_denied_mcp_yields_no_tool_search_end_to_end():
"""An allowlist that denies the MCP tool gates it end-to-end: after the real
policy filter no MCP tool survives, so ``_assemble_deferred`` adds no
tool_search (and does not fail-closed, because no MCP tool leaked through)."""
from deerflow.agents.lead_agent import agent as agentmod
filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(["active_tool"])])
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
assert [t.name for t in final_tools] == ["active_tool"]
assert "tool_search" not in {t.name for t in final_tools}
assert setup.deferred_names == frozenset()
def test_tool_search_appended_after_policy_but_never_exposes_denied_tool():
"""Intentional behavior change vs. upstream (Copilot review on PR #3342).
``tool_search`` is appended AFTER skill-allowlist filtering, so an allowlist
can no longer deny ``tool_search`` by name. This is safe by construction: the
tool only appears when allowed MCP tools survive the filter, and its catalog
is derived from the already policy-filtered list so it can never expose a
tool the allowlist denied. Locks that contract so the ordering cannot regress.
"""
from deerflow.agents.lead_agent import agent as agentmod
allowed = ["active_tool", "mcp_secret"] # permits the MCP tool, does NOT list tool_search
filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(allowed)])
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
names = {t.name for t in final_tools}
assert "tool_search" in names # appended despite not being in the allowlist
assert setup.deferred_names == frozenset({"mcp_secret"})
assert set(setup.deferred_names) <= set(allowed) # catalog never exceeds the allowlist
@@ -82,6 +82,15 @@ def fake_translator(text: str, target_lang: str) -> str:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
@@ -136,7 +145,6 @@ async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch:
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.tools import get_available_tools from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator]) _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
@@ -150,17 +158,18 @@ async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch:
Use this whenever the user asks you to delegate work pass a short Use this whenever the user asks you to delegate work pass a short
description as ``prompt``. description as ``prompt``.
""" """
# ``task_tool`` does this internally. With the closure + graph-state # ``task_tool`` does this internally. Whether the registry-reset that
# design there is no shared registry/ContextVar, so a re-entrant # used to happen here actually leaks back to the parent task depends
# ``get_available_tools`` call here cannot affect the lead agent's # on asyncio's implicit context-copying semantics (gather creates
# deferred middleware or its promotion state. # child tasks with copied contexts, so reset_deferred_registry is
# task-local) — but the fix in this PR is what GUARANTEES the
# promotion sticks regardless of which integration path triggers a
# re-entrant ``get_available_tools`` call.
get_available_tools(subagent_enabled=False) get_available_tools(subagent_enabled=False)
_calls.append(f"fake_subagent_trigger:{prompt}") _calls.append(f"fake_subagent_trigger:{prompt}")
return "subagent completed" return "subagent completed"
raw_tools = get_available_tools() + [fake_subagent_trigger] tools = get_available_tools() + [fake_subagent_trigger]
setup = build_deferred_tool_setup(raw_tools, enabled=True)
tools = [*raw_tools, setup.tool_search_tool] if setup.tool_search_tool else raw_tools
model = ChatOpenAI( model = ChatOpenAI(
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"), model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
@@ -186,7 +195,7 @@ async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch:
graph = create_agent( graph = create_agent(
model=model, model=model,
tools=tools, tools=tools,
middleware=[DeferredToolFilterMiddleware(setup.deferred_names, setup.catalog_hash)], middleware=[DeferredToolFilterMiddleware()],
system_prompt=system_prompt, system_prompt=system_prompt,
) )
@@ -0,0 +1,390 @@
"""Reproduce + regression-guard issue #2884.
Hypothesis from the issue:
``tools.tools.get_available_tools`` unconditionally calls
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
every time it is invoked. If anything calls ``get_available_tools`` again
during the same async context (after the agent has promoted tools via
``tool_search``), the promotion is wiped and the next model call hides the
tool's schema again.
These tests pin two things:
A. **At the unit boundary** verify the failure mode directly. Promote a
tool in the registry, then call ``get_available_tools`` again and observe
that the ContextVar registry is reset and the promotion is lost.
B. **At the graph-execution boundary** drive a real ``create_agent`` graph
with the real ``DeferredToolFilterMiddleware`` through two model turns.
The first turn calls ``tool_search`` which promotes a tool. The second
turn must see that tool's schema in ``request.tools``. If
``get_available_tools`` were to run again between the two turns and reset
the registry, the second turn's filter would strip the tool.
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
unmodified; mock only the LLM and the MCP tool source. Patch
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
``get_available_tools`` resolves via lazy import) to return our fixture
tools so we don't need a real MCP server.
"""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
# ---------------------------------------------------------------------------
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
# ---------------------------------------------------------------------------
@as_tool
def fake_mcp_search(query: str) -> str:
"""Pretend to search a knowledge base for the given query."""
return f"results for {query}"
@as_tool
def fake_mcp_fetch(url: str) -> str:
"""Pretend to fetch a page at the given URL."""
return f"content of {url}"
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
"""Each test must start with a clean ContextVar.
The registry lives in a module-level ContextVar with no per-task isolation
in a synchronous test runner, so one test's promotion can leak into the
next and silently break filter assertions.
"""
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
"""Make get_available_tools believe an MCP server is registered.
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
that both ``AppConfig.from_file`` (which calls
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
see a valid instance. Then point the MCP tool cache at our fixture tools.
"""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Force config.tool_search.enabled=True without touching the yaml.
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
which permanently mutates module-level singletons (``_memory_config``,
``_title_config``, ) to match the developer's ``config.yaml`` — even
after pytest restores our patch. That leaks across tests later in the
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
require ``_memory_config.enabled = True``, which is the dataclass default
but FALSE in the actual yaml).
Build a minimal mock AppConfig instead and never call the real loader.
"""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Section A — direct unit-level reproduction
# ---------------------------------------------------------------------------
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
Step 1: call get_available_tools() registers MCP tools as deferred.
Step 2: simulate the agent calling tool_search by promoting one tool.
Step 3: call get_available_tools() again (the same code path
``task_tool`` exercises mid-run).
Assertion: after step 3, the promoted tool is STILL promoted (not
re-deferred). On ``main`` before the fix, step 3's
``reset_deferred_registry()`` wiped the promotion and re-registered
every MCP tool as deferred this assertion fired with
``REGRESSION (#2884)``.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# Step 1: first call — both MCP tools start deferred
get_available_tools()
reg1 = get_deferred_registry()
assert reg1 is not None
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
# Step 2: simulate tool_search promoting one of them
reg1.promote({"fake_mcp_search"})
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
# Step 3: second call — registry must NOT silently undo the promotion
get_available_tools()
reg2 = get_deferred_registry()
assert reg2 is not None
deferred_after = {e.name for e in reg2.entries}
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
# ---------------------------------------------------------------------------
# Section B — graph-execution reproduction
# ---------------------------------------------------------------------------
class _ToolSearchPromotingModel(FakeToolCallingModel):
"""Two-turn model that:
Turn 1 emit a tool_call for ``tool_search`` (the real one)
Turn 2 emit a tool_call for ``fake_mcp_search`` (the promoted tool)
Records the tools it received on each turn so the test can inspect what
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
"""
bound_tools_per_turn: list[list[str]] = []
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
# Record the tool names the model would see in this turn
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
self.bound_tools_per_turn.append(names)
return self
def _build_promoting_model() -> _ToolSearchPromotingModel:
return _ToolSearchPromotingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
}
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
"""End-to-end: drive a real create_agent graph through two turns.
Without the fix, the second-turn bind_tools call should NOT contain
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
registry and strips it). With the fix, the model sees the schema and can
invoke it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
tools = get_available_tools()
# Sanity: the assembled tool list includes the deferred tools (they're in
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
# they reach the model)
tool_names = {getattr(t, "name", "") for t in tools}
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
model = _build_promoting_model()
model.bound_tools_per_turn = [] # reset class-level recorder
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
turn1 = set(model.bound_tools_per_turn[0])
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
# This is the load-bearing assertion for issue #2884.
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
turn2 = set(model.bound_tools_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
# ---------------------------------------------------------------------------
# Section C — the actual issue #2884 trigger: a re-entrant
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
# wipe the parent's promotion.
# ---------------------------------------------------------------------------
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
(the same pattern that happens when ``task_tool`` builds a subagent's
toolset mid-run) must not wipe the parent agent's tool_search promotions.
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
``get_available_tools`` again exactly what ``task_tool`` does when it
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
promoted tool. Without the fix, the re-entry wipes the registry and
the filter re-hides it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# The trigger tool simulates what task_tool does internally: rebuild the
# toolset by calling get_available_tools while the registry is live.
@as_tool
def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
get_available_tools(subagent_enabled=False)
return f"spawned subagent for: {prompt}"
tools = get_available_tools() + [fake_subagent_trigger]
bound_per_turn: list[list[str]] = []
class _Model(FakeToolCallingModel):
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
return self
model = _Model(
responses=[
# Turn 1: do both in one batch — promote AND trigger the
# subagent-style rebuild. LangGraph executes them in order in the
# same agent step.
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
},
{
"name": "fake_subagent_trigger",
"args": {"prompt": "go"},
"id": "call_trigger_1",
"type": "tool_call",
},
],
),
# Turn 2: try to invoke the promoted tool. The model gets this
# turn only if turn 1's bind_tools recorded what the filter sent.
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-subagent-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1 sanity: deferred tool not visible yet
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
# re-entrant get_available_tools call that happened in turn 1's tool batch.
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
turn2 = set(bound_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
+1 -245
View File
@@ -1,38 +1,12 @@
import asyncio import asyncio
import json import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.feishu import FeishuChannel from app.channels.feishu import FeishuChannel
from app.channels.message_bus import ( from app.channels.message_bus import InboundMessage, MessageBus
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
InboundMessage,
MessageBus,
OutboundMessage,
)
from app.channels.store import ChannelStore
def _pending(
topic_id: str,
*,
thread_id: str | None = None,
source_message_id: str | None = None,
card_message_id: str | None = None,
created_at: float = 9999999999,
) -> dict:
return {
"thread_id": thread_id or f"deer-thread-{topic_id}",
"topic_id": topic_id,
"source_message_id": source_message_id or topic_id,
"card_message_id": card_message_id or f"card-{topic_id}",
"created_at": created_at,
}
def _run(coro): def _run(coro):
@@ -164,224 +138,6 @@ def test_feishu_on_message_extracts_image_and_file_keys():
assert "[file]" in mock_make_inbound.call_args[1]["text"] assert "[file]" in mock_make_inbound.call_args[1]["text"]
def test_feishu_on_message_reuses_stored_parent_topic_for_card_replies():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
store.set_thread_id(
"feishu",
"chat_1",
"deer-thread-1",
topic_id="om_clarification_card",
user_id="user_1",
)
channel = FeishuChannel(
bus,
{"app_id": "test", "app_secret": "test", "channel_store": store},
)
event = MagicMock()
event.event.message.chat_id = "chat_1"
event.event.message.message_id = "msg_reply"
event.event.message.root_id = "om_unknown_root"
event.event.message.parent_id = "om_clarification_card"
event.event.message.thread_id = None
event.event.sender.sender_id.open_id = "user_1"
event.event.message.content = json.dumps({"text": "prod"})
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(event)
inbound = mock_make_inbound.return_value
assert inbound.topic_id == "om_clarification_card"
assert mock_make_inbound.call_args.kwargs["metadata"]["topic_id"] == "om_clarification_card"
def _make_text_event(
text: str,
*,
chat_id: str = "chat_1",
message_id: str = "msg_1",
user_id: str = "user_1",
root_id: str | None = None,
parent_id: str | None = None,
thread_id: str | None = None,
):
event = MagicMock()
event.event.message.chat_id = chat_id
event.event.message.message_id = message_id
event.event.message.root_id = root_id
event.event.message.parent_id = parent_id
event.event.message.thread_id = thread_id
event.event.sender.sender_id.open_id = user_id
event.event.message.content = json.dumps({"text": text})
return event
def test_feishu_plain_reply_consumes_pending_clarification_topic():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
store.set_thread_id("feishu", "chat_1", "deer-thread-1", topic_id="om_original", user_id="user_1")
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test", "channel_store": store})
channel._pending_clarifications[channel._pending_key("chat_1", "user_1")] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card")]
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(_make_text_event("2", message_id="msg_plain_2"))
inbound = mock_make_inbound.return_value
metadata = mock_make_inbound.call_args.kwargs["metadata"]
assert inbound.topic_id == "om_original"
assert metadata["topic_id"] == "om_original"
assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is True
assert channel._pending_key("chat_1", "user_1") not in channel._pending_clarifications
def test_feishu_pending_clarification_is_consumed_once():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
channel._pending_clarifications[channel._pending_key("chat_1", "user_1")] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card")]
with pytest.MonkeyPatch.context() as m:
created = []
def fake_make_inbound(**kwargs):
inbound = InboundMessage(channel_name="feishu", **kwargs)
created.append(inbound)
return inbound
mock_make_inbound = MagicMock(side_effect=fake_make_inbound)
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(_make_text_event("2", message_id="msg_first"))
channel._on_message(_make_text_event("next", message_id="msg_second"))
first_inbound = created[0]
second_inbound = created[1]
first_metadata = mock_make_inbound.call_args_list[0].kwargs["metadata"]
second_metadata = mock_make_inbound.call_args_list[1].kwargs["metadata"]
assert first_inbound.topic_id == "om_original"
assert second_inbound.topic_id == "msg_second"
assert first_metadata["topic_id"] == "om_original"
assert first_metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is True
assert second_metadata["topic_id"] == "msg_second"
assert second_metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False
def test_feishu_expired_pending_clarification_is_ignored(monkeypatch):
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
monkeypatch.setattr("app.channels.feishu.time.time", lambda: 10_000.0)
channel._pending_clarifications[channel._pending_key("chat_1", "user_1")] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card", created_at=0.0)]
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(_make_text_event("2", message_id="msg_plain_2"))
metadata = mock_make_inbound.call_args.kwargs["metadata"]
assert metadata["topic_id"] == "msg_plain_2"
assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False
assert channel._pending_key("chat_1", "user_1") not in channel._pending_clarifications
def test_feishu_command_does_not_consume_pending_clarification():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
key = channel._pending_key("chat_1", "user_1")
channel._pending_clarifications[key] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card")]
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(_make_text_event("/status", message_id="msg_command"))
metadata = mock_make_inbound.call_args.kwargs["metadata"]
assert mock_make_inbound.call_args.kwargs["msg_type"].value == "command"
assert metadata["topic_id"] == "msg_command"
assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False
assert key in channel._pending_clarifications
def test_feishu_remembers_pending_clarification_only_after_final_card_success():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
outbound = OutboundMessage(
channel_name="feishu",
chat_id="chat_1",
thread_id="deer-thread-1",
text="clarify?",
thread_ts="om_original",
metadata={
PENDING_CLARIFICATION_METADATA_KEY: True,
"user_id": "user_1",
"topic_id": "om_original",
"message_id": "om_original",
},
)
channel._remember_pending_clarification(outbound, None)
assert channel._pending_clarifications == {}
channel._remember_pending_clarification(outbound, "om_card")
pending = channel._pending_clarifications[channel._pending_key("chat_1", "user_1")][0]
assert pending["topic_id"] == "om_original"
assert pending["thread_id"] == "deer-thread-1"
assert pending["card_message_id"] == "om_card"
def test_feishu_multiple_pending_clarifications_are_consumed_in_order():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
key = channel._pending_key("chat_1", "user_1")
channel._pending_clarifications[key] = [
_pending("om_first", thread_id="deer-thread-1"),
_pending("om_second", thread_id="deer-thread-2"),
]
with pytest.MonkeyPatch.context() as m:
created = []
def fake_make_inbound(**kwargs):
inbound = InboundMessage(channel_name="feishu", **kwargs)
created.append(inbound)
return inbound
m.setattr(channel, "_make_inbound", MagicMock(side_effect=fake_make_inbound))
channel._on_message(_make_text_event("first answer", message_id="msg_first"))
channel._on_message(_make_text_event("second answer", message_id="msg_second"))
assert [msg.topic_id for msg in created] == ["om_first", "om_second"]
assert key not in channel._pending_clarifications
def test_feishu_explicit_reply_prefers_stored_mapping_over_pending():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
store.set_thread_id("feishu", "chat_1", "deer-thread-card", topic_id="om_card", user_id="user_1")
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test", "channel_store": store})
key = channel._pending_key("chat_1", "user_1")
channel._pending_clarifications[key] = [_pending("om_pending", thread_id="deer-thread-pending")]
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(
_make_text_event(
"answer",
message_id="msg_reply",
root_id="om_unknown",
parent_id="om_card",
)
)
metadata = mock_make_inbound.call_args.kwargs["metadata"]
assert metadata["topic_id"] == "om_card"
assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False
assert key in channel._pending_clarifications
@pytest.mark.parametrize("command", sorted(KNOWN_CHANNEL_COMMANDS)) @pytest.mark.parametrize("command", sorted(KNOWN_CHANNEL_COMMANDS))
def test_feishu_recognizes_all_known_slash_commands(command): def test_feishu_recognizes_all_known_slash_commands(command):
"""Every entry in KNOWN_CHANNEL_COMMANDS must be classified as a command.""" """Every entry in KNOWN_CHANNEL_COMMANDS must be classified as a command."""
@@ -1,353 +0,0 @@
"""Regression tests for graceful run-task drain on Gateway shutdown.
Guards bytedance/deer-flow issue #3373:
psycopg_pool.PoolClosed: the pool 'pool-1' is already closed
Root cause: chat runs are fire-and-forget background ``asyncio`` tasks
(``app/gateway/services.py`` -> ``asyncio.create_task(run_agent(...))``) owned
by nobody. On shutdown, ``langgraph_runtime``'s ``AsyncExitStack`` tore down the
checkpointer's postgres pool while those tasks were still mid-graph. langgraph's
``AsyncPregelLoop._checkpointer_put_after_previous`` then ran its
``finally: await checkpointer.aput(...)`` against the already-closed pool.
Fix: ``RunManager.shutdown()`` cancels and *bounded*-awaits every in-flight run,
and ``langgraph_runtime`` calls it BEFORE the ``AsyncExitStack`` closes the
checkpointer so the final checkpoint write lands while the pool is still open.
The drain must stay bounded (a stuck run must not hang the worker, the
precondition for the signal-reentrancy deadlock guarded by
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``).
"""
from __future__ import annotations
import asyncio
import operator
from contextlib import asynccontextmanager, suppress
from types import SimpleNamespace
from typing import Annotated, TypedDict
import pytest
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime import RunManager, RunStatus
# Module-level so langgraph's get_type_hints (which resolves annotations against
# module globals under `from __future__ import annotations`) can see Annotated.
class _CountState(TypedDict):
count: Annotated[int, operator.add]
class _CloseableSaver(InMemorySaver):
"""InMemorySaver that fails writes once closed, like a closed pool."""
def __init__(self) -> None:
super().__init__()
self._closed = False
self.writes_after_close: list[str] = []
def close(self) -> None:
self._closed = True
async def aput(self, *args, **kwargs):
if self._closed:
self.writes_after_close.append("aput")
raise RuntimeError("checkpointer is closed")
return await super().aput(*args, **kwargs)
async def aput_writes(self, *args, **kwargs):
if self._closed:
self.writes_after_close.append("aput_writes")
raise RuntimeError("checkpointer is closed")
return await super().aput_writes(*args, **kwargs)
@pytest.mark.asyncio
async def test_shutdown_cancels_and_awaits_inflight_run():
"""shutdown() cancels the in-flight task, waits for it, marks it interrupted."""
rm = RunManager()
record = await rm.create("t-drain")
await rm.set_status(record.run_id, RunStatus.running)
started = asyncio.Event()
cancelled = asyncio.Event()
async def worker() -> None:
try:
started.set()
await asyncio.Event().wait()
except asyncio.CancelledError:
cancelled.set()
raise
record.task = asyncio.create_task(worker())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
await rm.shutdown(timeout=5.0)
assert record.task.done()
assert cancelled.is_set()
assert record.status == RunStatus.interrupted
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_is_bounded_when_run_ignores_cancellation():
"""A run that swallows cancellation must not make shutdown() hang."""
rm = RunManager()
record = await rm.create("t-stubborn")
await rm.set_status(record.run_id, RunStatus.running)
started = asyncio.Event()
stop = asyncio.Event()
async def stubborn() -> None:
started.set()
while not stop.is_set():
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
if stop.is_set():
raise
# else: swallow — simulates a run stuck in slow cleanup
record.task = asyncio.create_task(stubborn())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
loop = asyncio.get_running_loop()
t0 = loop.time()
await rm.shutdown(timeout=0.3)
elapsed = loop.time() - t0
assert elapsed < 2.0, f"shutdown took {elapsed:.2f}s; drain is not bounded"
finally:
# cleanup the deliberately-stubborn task
stop.set()
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_is_noop_without_inflight_runs():
"""shutdown() on an idle manager completes cleanly and is idempotent."""
rm = RunManager()
await rm.shutdown(timeout=1.0)
# already-finished runs must not be re-cancelled or error out
record = await rm.create("t-done")
await rm.set_status(record.run_id, RunStatus.success)
await rm.shutdown(timeout=1.0)
@pytest.mark.asyncio
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
"""The wiring order lock for #3373: drain in-flight runs, THEN close the pool.
Patches every ``langgraph_runtime`` collaborator down to trivial stand-ins so
only the bootstrap/teardown ordering runs. The checkpointer probe records when
its context manager exits (pool close); a ``RunManager.shutdown`` spy records
when the drain happens. The drain MUST come first.
"""
from fastapi import FastAPI
from app.gateway.deps import langgraph_runtime
events: list[str] = []
@asynccontextmanager
async def probe_checkpointer(_config):
try:
yield object()
finally:
events.append("checkpointer_closed")
@asynccontextmanager
async def fake_stream_bridge(_config):
yield object()
@asynccontextmanager
async def fake_store(_config):
yield object()
async def fake_init_engine(_db):
return None
async def fake_close_engine():
return None
async def spy_shutdown(self, *, timeout): # noqa: ANN001
events.append("runs_drained")
monkeypatch.setattr("deerflow.runtime.checkpointer.async_provider.make_checkpointer", probe_checkpointer)
monkeypatch.setattr("deerflow.runtime.make_stream_bridge", fake_stream_bridge)
monkeypatch.setattr("deerflow.runtime.make_store", fake_store)
monkeypatch.setattr("deerflow.persistence.engine.init_engine_from_config", fake_init_engine)
monkeypatch.setattr("deerflow.persistence.engine.close_engine", fake_close_engine)
monkeypatch.setattr("deerflow.persistence.engine.get_session_factory", lambda: None)
monkeypatch.setattr("deerflow.runtime.events.store.make_run_event_store", lambda _cfg: object())
monkeypatch.setattr("deerflow.persistence.thread_meta.make_thread_store", lambda _sf, _store: object())
monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)
app = FastAPI()
startup_config = SimpleNamespace(database=SimpleNamespace(backend="memory"), run_events=None)
async with langgraph_runtime(app, startup_config):
pass
assert "runs_drained" in events, "langgraph_runtime never drained in-flight runs on shutdown"
assert "checkpointer_closed" in events
assert events.index("runs_drained") < events.index("checkpointer_closed"), f"runs must be drained before the checkpointer pool is closed; got order {events}"
@pytest.mark.asyncio
async def test_drain_flushes_real_graph_checkpoint_before_close():
"""End-to-end #3373 guard with a REAL langgraph graph + checkpointer.
A real run is driven through ``graph.astream`` in a background task, then
``RunManager.shutdown()`` drains it. The checkpointer raises once closed
(mirroring ``psycopg_pool.PoolClosed``). Closing only happens AFTER the
drain as the gateway's AsyncExitStack does. The drain must let langgraph
flush its final checkpoint while the checkpointer is still open, so no write
lands against a closed checkpointer.
Unlike the unit/spy tests above, this exercises the real langgraph
checkpoint-put machinery, so a future langgraph change that cancels (rather
than awaits) its checkpoint-put task on executor exit would fail this test
instead of silently regressing #3373.
"""
from langgraph.graph import END, START, StateGraph
async def slow(_state: _CountState) -> dict:
await asyncio.sleep(0.1)
return {"count": 1}
saver = _CloseableSaver()
builder = StateGraph(_CountState)
for name in ("a", "b", "c"):
builder.add_node(name, slow)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
builder.add_edge("c", END)
graph = builder.compile(checkpointer=saver)
rm = RunManager()
record = await rm.create("t-e2e")
await rm.set_status(record.run_id, RunStatus.running)
thread_cfg = {"configurable": {"thread_id": "t-e2e"}}
started = asyncio.Event()
async def run() -> None:
started.set()
async for _ in graph.astream({"count": 0}, config=thread_cfg):
pass
record.task = asyncio.create_task(run())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
# Deterministically wait until the run is genuinely in-flight — poll for
# the first persisted checkpoint instead of a fixed sleep (avoids CI
# flakiness on slow runners / under event-loop contention).
async def _await_first_checkpoint() -> None:
while (await saver.aget_tuple(thread_cfg)) is None:
await asyncio.sleep(0.01)
await asyncio.wait_for(_await_first_checkpoint(), timeout=5.0)
# The fix: drain while the checkpointer is still open ...
await rm.shutdown(timeout=5.0)
# ... and only then close it (mirrors langgraph_runtime's ExitStack).
saver.close()
assert saver.writes_after_close == [], f"a checkpoint write raced a closed checkpointer: {saver.writes_after_close}"
# The final checkpoint landed before close.
snapshot = await saver.aget_tuple(thread_cfg)
assert snapshot is not None
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_preserves_status_of_run_completed_during_drain():
"""A run that finishes (e.g. success) during the drain window must keep its
real terminal status shutdown must not blanket-overwrite it to
``interrupted`` in memory or in the store (Copilot review on PR #3381)."""
from deerflow.runtime.runs.store.memory import MemoryRunStore
store = MemoryRunStore()
rm = RunManager(store=store)
record = await rm.create("t-complete")
await rm.set_status(record.run_id, RunStatus.running)
async def worker() -> None:
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
# The run had effectively finished; swallow the cancellation and
# record success, like a run that completed in the same tick the
# shutdown cancelled it.
pass
await rm.set_status(record.run_id, RunStatus.success)
record.task = asyncio.create_task(worker())
try:
await asyncio.sleep(0) # let the task reach its await point
await rm.shutdown(timeout=5.0)
assert record.status == RunStatus.success, f"shutdown overwrote in-memory status: {record.status}"
persisted = await store.get(record.run_id)
assert persisted is not None and persisted["status"] == "success", f"shutdown overwrote persisted status: {persisted}"
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_surfaces_failed_interrupted_persist(caplog):
"""A failed interrupted-status persist during the drain must be surfaced (with
the run_id), not silently swallowed by the gather (maintainer review on
PR #3381)."""
import logging
from deerflow.runtime.runs.store.memory import MemoryRunStore
class _FailingStore(MemoryRunStore):
async def update_status(self, *args, **kwargs):
raise RuntimeError("store unavailable")
rm = RunManager(store=_FailingStore())
record = await rm.create("t-failpersist")
record.status = RunStatus.running # set in memory; the failing store is exercised by the drain
started = asyncio.Event()
async def worker() -> None:
started.set()
await asyncio.Event().wait() # blocks until cancelled by the drain
record.task = asyncio.create_task(worker())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
with caplog.at_level(logging.WARNING, logger="deerflow.runtime.runs.manager"):
await rm.shutdown(timeout=5.0)
assert "Could not persist interrupted status for run" in caplog.text, caplog.text
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@@ -32,7 +32,6 @@ class _FakeRunManager:
self.store = store self.store = store
self.reconcile_calls: list[dict] = [] self.reconcile_calls: list[dict] = []
self.list_by_thread_calls: list[dict] = [] self.list_by_thread_calls: list[dict] = []
self.shutdown_calls: int = 0
_FakeRunManager.instances.append(self) _FakeRunManager.instances.append(self)
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None): async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
@@ -43,11 +42,6 @@ class _FakeRunManager:
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit}) self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit]) return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
async def shutdown(self, *, timeout: float = 5.0) -> None:
# No in-flight tasks in these startup-recovery tests; langgraph_runtime
# drains the manager on teardown, so the double must accept the call.
self.shutdown_calls += 1
class _FakeThreadStore: class _FakeThreadStore:
def __init__(self) -> None: def __init__(self) -> None:

Some files were not shown because too many files have changed in this diff Show More