mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-14 03:15:58 +00:00
Compare commits
96 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d61718c80 | |||
| 474c89bac2 | |||
| f43aa78107 | |||
| 47e9570d86 | |||
| 1783da42f4 | |||
| d23eac227f | |||
| 554017a89f | |||
| 6e839342a7 | |||
| 8955b3222a | |||
| c91dacc8e2 | |||
| cad6e89a19 | |||
| 094296440f | |||
| 839fa99237 | |||
| 3475f7cdad | |||
| 83bc2fb1ae | |||
| a17d2ff8f8 | |||
| 420a886e1d | |||
| 579e416459 | |||
| c002596ab4 | |||
| a838546a2b | |||
| bbce6c0ac0 | |||
| 0d3bfe0a76 | |||
| 503eeac788 | |||
| aa015462a7 | |||
| b8f5ed360f | |||
| 76136d22b4 | |||
| dc2ababf00 | |||
| 330a2ff8c5 | |||
| 0367fe6c7a | |||
| c733d3c917 | |||
| b6fbf0d105 | |||
| f401e7baa6 | |||
| 919d8bc279 | |||
| 2d5f0787de | |||
| 5819bd8a59 | |||
| b3c2cc42cf | |||
| 167ef4512f | |||
| ba9cc5e972 | |||
| 05ae4467ae | |||
| 2b795265e7 | |||
| a57d05fe0a | |||
| ae9e8bc0bf | |||
| 16391e35ab | |||
| 18bbb82f07 | |||
| b62c5a7b5b | |||
| 5b81588b87 | |||
| 63ce88f874 | |||
| 37337b77f9 | |||
| 8db16bb3d8 | |||
| 93e3281cbf | |||
| 0fb18e368c | |||
| 90e23bfd09 | |||
| f92a26d56f | |||
| 3b6dd0a4e3 | |||
| 3c2b60aaae | |||
| 67ad6e232f | |||
| cd5bedaa74 | |||
| 1651d1f1f5 | |||
| 799bef6d9d | |||
| 3b105d1e5f | |||
| 88759015e4 | |||
| 64d923b0fd | |||
| 519200728a | |||
| 40a371b88c | |||
| f725a963d5 | |||
| 3b4c9ff733 | |||
| 10c1d9f417 | |||
| 7679f21edf | |||
| 8d2e55a05f | |||
| d8b728f7cb | |||
| befe334f10 | |||
| d133b1119a | |||
| 88e36d9686 | |||
| 268fdd6968 | |||
| 9a5de8d6a5 | |||
| 1aac408dd0 | |||
| dd8f9bf5f0 | |||
| 2bbc7879fa | |||
| 28b1da2172 | |||
| 3fddc24c5f | |||
| 0d0968a364 | |||
| 89ae74d4f4 | |||
| 9a53f9dfbb | |||
| 8fca56cf43 | |||
| 0ffa995fe9 | |||
| f97b0c0f74 | |||
| aca7acc105 | |||
| 3ae82dc663 | |||
| 5dc2d6cbf5 | |||
| d9f4724950 | |||
| 74e3e80cf6 | |||
| 019bd16a06 | |||
| 031d6fbcbe | |||
| d6a604d5a1 | |||
| 46ddc346ad | |||
| 79cc227917 |
@@ -0,0 +1,141 @@
|
|||||||
|
---
|
||||||
|
name: blocking-io-guard
|
||||||
|
description: Ensure async-path backend code that could block the asyncio event loop is protected by a teeth-verified runtime anchor in tests/blocking_io/. Use when changing backend Python under app/, packages/harness/deerflow/, or scripts/, when running a blocking-IO triage round over the whole repo, or when a reviewer/CI asks for blocking-IO coverage. Runs a deterministic scan (changed-lines or full-repo), routes each candidate, drafts/extends an anchor, and proves it fails when the blocking IO regresses.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Blocking-IO Guard Skill
|
||||||
|
|
||||||
|
Help a contributor ship backend async changes together with the runtime anchor
|
||||||
|
that lets DeerFlow's blocking-IO CI gate actually see the new code. The dynamic
|
||||||
|
detector only catches blocking IO on paths a test executes — this skill closes
|
||||||
|
that gap, either for your own diff or for a repo-wide triage round.
|
||||||
|
|
||||||
|
Read `references/good-anchor-rules.md` before writing any anchor.
|
||||||
|
Only read `references/sop-skeleton.md` when generalizing this SOP to another
|
||||||
|
detector domain — it is not needed to execute the steps below.
|
||||||
|
|
||||||
|
## When to use
|
||||||
|
|
||||||
|
- Your change touches Python under `backend/app/`,
|
||||||
|
`backend/packages/harness/deerflow/`, or `backend/scripts/` and may run on
|
||||||
|
the async event loop (Mode A). If unsure, run Step 0 — it answers
|
||||||
|
deterministically.
|
||||||
|
- You are doing a maintenance triage round over the existing codebase
|
||||||
|
(Mode B).
|
||||||
|
|
||||||
|
## SOP (router)
|
||||||
|
|
||||||
|
### Step 0 — Scope (deterministic)
|
||||||
|
|
||||||
|
**Mode A — your own diff** (default, pre-PR). From repo root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run --project backend python scripts/scan_changed_blocking_io.py --base origin/main
|
||||||
|
```
|
||||||
|
|
||||||
|
Lists blocking-IO candidates your change introduces: findings on lines the
|
||||||
|
diff added, **plus** findings that are new versus the merge base — the latter
|
||||||
|
catches a new async caller exposing an old sync helper whose blocking line is
|
||||||
|
not in the diff. The diff is `<base>...HEAD`, so **commit your work first** —
|
||||||
|
uncommitted lines are not selected.
|
||||||
|
|
||||||
|
If the list is empty, this change introduces no blocking-IO surface *that the
|
||||||
|
static detector can see in the changed files*. One residual blind spot
|
||||||
|
remains: reachability is same-file only, so a new async caller of a sync
|
||||||
|
helper **defined in another file** is invisible to both selections. If your
|
||||||
|
diff adds an async call into a helper that lives elsewhere, check that helper
|
||||||
|
manually (codegraph or `git grep`) before stopping.
|
||||||
|
|
||||||
|
**Mode B — full-repo triage round.** From repo root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make detect-blocking-io
|
||||||
|
```
|
||||||
|
|
||||||
|
Prints a summary and writes the complete structured finding list to
|
||||||
|
`.deer-flow/blocking-io-findings.json`. Work HIGH priority first; do not start
|
||||||
|
MEDIUM until every HIGH is dispositioned (fixed, guarded, or recorded
|
||||||
|
NO-ACTION).
|
||||||
|
|
||||||
|
**Batching policy (PR sizing).** One **fix unit** per PR while any HIGH
|
||||||
|
remains: a fix unit is one root cause — usually a single HIGH, but two HIGHs
|
||||||
|
resolved by the same one-place fix belong together. Once no HIGH remains,
|
||||||
|
MEDIUM/LOW may be batched (about five per round, grouped by module or by
|
||||||
|
disposition) so each PR stays reviewable. A new Blockbuster rule is never
|
||||||
|
batched with anything — it always ships alone (see Step 5).
|
||||||
|
|
||||||
|
Both modes emit the same JSON shape per finding: `priority`, `location`
|
||||||
|
(path/line/function), `blocking_call` (category/operation/symbol),
|
||||||
|
`event_loop_exposure`, `reason`, `code`. Priority is a deterministic review
|
||||||
|
ordering, not proof of a bug — Step 1 makes the actual call.
|
||||||
|
|
||||||
|
### Step 1 — Judge each candidate (router)
|
||||||
|
|
||||||
|
Read the code around each candidate and route it:
|
||||||
|
|
||||||
|
- **Already offloaded** (`asyncio.to_thread`, `run_in_executor`, async client) →
|
||||||
|
**GUARD**: add/extend an anchor that locks the offload so a future edit cannot
|
||||||
|
move it back onto the loop.
|
||||||
|
- **On the loop, not offloaded** → **FIX+ANCHOR**: offload the production code
|
||||||
|
(your fix), then add an anchor that guards it.
|
||||||
|
- **Not actually exposed / acceptable** (rare: scanner false positive,
|
||||||
|
startup-only code) → **NO-ACTION**: record one line of why.
|
||||||
|
- **Cross-file caveat**: the scanner's async reachability is same-file only
|
||||||
|
(`ASYNC_REACHABLE_SAME_FILE`). If the candidate is a *sync helper*, check for
|
||||||
|
async callers in other files (codegraph or `git grep`) before deciding
|
||||||
|
NO-ACTION.
|
||||||
|
|
||||||
|
### Step 2 — Apply the fix, then re-scan (FIX+ANCHOR only)
|
||||||
|
|
||||||
|
Offload the blocking call in production code, then re-run the Step 0 scan and
|
||||||
|
confirm the candidate no longer appears. If the offloaded call sits in a
|
||||||
|
`finally` / cleanup path, keep it best-effort and bounded (swallow-and-log,
|
||||||
|
`asyncio.wait_for`) so a failing or hung cleanup cannot mask the primary
|
||||||
|
exception. Match by the stable key
|
||||||
|
**(path, function, symbol)** — line numbers shift after edits, so never
|
||||||
|
compare by line.
|
||||||
|
|
||||||
|
- The finding must disappear. If it still shows, the fix did not remove the
|
||||||
|
blocking pattern (e.g. the call is still a direct call, not offloaded) —
|
||||||
|
go back before touching any test.
|
||||||
|
- GUARD / NO-ACTION routes skip this step: a residual finding there is
|
||||||
|
*expected* (the raw call still exists inside a sync helper with the offload
|
||||||
|
at the caller, or the exposure was judged acceptable).
|
||||||
|
|
||||||
|
This is pattern-level feedback in seconds; it complements but never replaces
|
||||||
|
Step 5 — only the runtime gate proves the event loop is actually protected.
|
||||||
|
|
||||||
|
### Step 3 — Check existing anchors
|
||||||
|
|
||||||
|
Look in `backend/tests/blocking_io/` for a test that drives the production async
|
||||||
|
entry point reaching this candidate's branch.
|
||||||
|
|
||||||
|
- Covers this branch already → go to Step 5 (re-verify teeth).
|
||||||
|
- Covers the entry point but not this branch (e.g. happy path covered,
|
||||||
|
cleanup/404/409 not) → **extend** that anchor.
|
||||||
|
- None → create one from `templates/anchor.template.py`.
|
||||||
|
|
||||||
|
### Step 4 — Generate / extend the anchor
|
||||||
|
|
||||||
|
Follow `references/good-anchor-rules.md`. Drive the *specific* branch (e.g. force
|
||||||
|
the create failure that hits the cleanup `shutil.rmtree`). Never bypass the
|
||||||
|
blocking surface with a test-only `asyncio.to_thread` wrapper.
|
||||||
|
|
||||||
|
### Step 5 — Verify teeth (mandatory; also the anchor-vs-rule discriminator)
|
||||||
|
|
||||||
|
1. Reintroduce the block (GUARD: temporarily revert the offload; FIX+ANCHOR: run
|
||||||
|
against the pre-fix code).
|
||||||
|
2. Run `cd backend && make test-blocking-io` (or target the one test). It **must
|
||||||
|
go RED**.
|
||||||
|
3. Restore the fix. It **must go GREEN**.
|
||||||
|
|
||||||
|
A real block that stays GREEN means Blockbuster has no rule for that
|
||||||
|
primitive — that is the **RULE** route; see `references/good-anchor-rules.md`
|
||||||
|
for the admission criteria before adding one.
|
||||||
|
|
||||||
|
### Step 6 — Deliver
|
||||||
|
|
||||||
|
Commit the anchor(s) with your change; `make test-blocking-io` green. In the PR,
|
||||||
|
note: candidates found, each disposition, the re-scan result (Step 2), and
|
||||||
|
the teeth evidence (red→green). Include the reason for any NO-ACTION. A new
|
||||||
|
Blockbuster rule, if any, goes in its own commit with the evidence from Step 5.
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
# Good anchor rules + teeth (blocking-IO fill)
|
||||||
|
|
||||||
|
Distilled from `backend/docs/BLOCKING_IO_DETECTION.md`. An anchor lives in
|
||||||
|
`backend/tests/blocking_io/`; the suite's conftest runs each test under the
|
||||||
|
strict Blockbuster gate scoped to `app.*` / `deerflow.*`.
|
||||||
|
|
||||||
|
The examples in this file and in `templates/` are all filesystem-flavored.
|
||||||
|
They demonstrate how to *write* the test, not what the SOP covers: the same
|
||||||
|
rules apply to every category the detector reports (FILE_IO, HTTP,
|
||||||
|
SUBPROCESS, SLEEP), and the acceptance criterion is always the teeth check
|
||||||
|
below — never similarity to an example.
|
||||||
|
|
||||||
|
## A good anchor
|
||||||
|
|
||||||
|
- Calls the **real production async entry point** — not a low-level helper,
|
||||||
|
unless that helper *is* the entry point production executes.
|
||||||
|
- Does **not** bypass the blocking surface with a test-only
|
||||||
|
`asyncio.to_thread` / `run_in_executor` wrapper.
|
||||||
|
- Uses **real local filesystem** inputs when the bug shape is filesystem IO.
|
||||||
|
- Mocks **only** the external dependency boundary (network service, third-party
|
||||||
|
saver), never the offload being guarded.
|
||||||
|
- Drives the **specific branch** you are protecting (error / cleanup / 404 /
|
||||||
|
409), not just the happy path.
|
||||||
|
|
||||||
|
## Teeth (the acceptance test)
|
||||||
|
|
||||||
|
An anchor only counts if the gate actually fires when the code blocks:
|
||||||
|
|
||||||
|
1. Reintroduce the block (revert the offload, or run pre-fix code).
|
||||||
|
2. `cd backend && make test-blocking-io` → the anchor **must fail** (RED).
|
||||||
|
3. Restore the fix → the anchor **must pass** (GREEN).
|
||||||
|
|
||||||
|
A green-on-happy-path anchor with no proven red is fake coverage. Don't ship it.
|
||||||
|
|
||||||
|
## The RULE route (rare; strict admission criteria)
|
||||||
|
|
||||||
|
Blockbuster's built-in rules cover the common blocking primitives well. The
|
||||||
|
two deliberate openings in this SOP are:
|
||||||
|
|
||||||
|
1. **Coverage opening** (the normal case): the rules already see the
|
||||||
|
primitive — you only need an anchor so runtime detection executes the real
|
||||||
|
business path and CI prevents regression.
|
||||||
|
2. **Rule opening** (rare): you reintroduced a *real* block and the gate
|
||||||
|
stayed GREEN — Blockbuster has no rule for that primitive.
|
||||||
|
|
||||||
|
A project rule lives in `_PROJECT_BLOCKING_RULES` inside
|
||||||
|
`backend/tests/support/detectors/blocking_io_runtime.py` and changes detection
|
||||||
|
for the **entire** blocking-IO suite — global blast radius. Admission criteria
|
||||||
|
for adding one:
|
||||||
|
|
||||||
|
- You have the **fails-to-fail anchor** as evidence: a good anchor (per the
|
||||||
|
rules above) that drives a genuinely blocking path and stays green. No
|
||||||
|
evidence, no rule.
|
||||||
|
- The primitive is a real blocking call (verified against its implementation
|
||||||
|
or docs), not a false positive of the static detector.
|
||||||
|
- The rule ships in its **own commit**, naming the primitive, the anchor that
|
||||||
|
exposed the gap, and the suite-wide impact. Run the full
|
||||||
|
`make test-blocking-io` suite after adding it — a new rule can turn other
|
||||||
|
previously-green tests red, and each such red is either a real latent bug
|
||||||
|
(fix it) or rule overreach (narrow the rule).
|
||||||
|
- If you are not in a position to own that blast radius (e.g. external
|
||||||
|
contributor), escalate to a maintainer with the evidence instead.
|
||||||
|
|
||||||
|
**Never add a runtime rule just because a path is untested** — that case needs
|
||||||
|
an anchor, not a rule.
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
# SOP skeleton (generic shape — extraction seam)
|
||||||
|
|
||||||
|
This is the domain-agnostic shape the blocking-IO skill instantiates. It exists
|
||||||
|
so a second detector/gate domain can reuse the flow without copying it. Do not
|
||||||
|
add machinery for that until a second domain actually appears (YAGNI).
|
||||||
|
|
||||||
|
A domain provides:
|
||||||
|
- a **static detector** that can scan a diff (or the whole tree) and emit
|
||||||
|
located candidates,
|
||||||
|
- a **CI gate** that fails when the bad pattern executes,
|
||||||
|
- a **test location** for guard tests,
|
||||||
|
- **good-test rules** for that gate,
|
||||||
|
- a **teeth definition** (how to make the gate fire on purpose).
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. **Scope (deterministic):** intersect the diff's added lines with the
|
||||||
|
detector's findings → candidates this change introduced/touched. (Or, in
|
||||||
|
triage mode, take the full finding list ordered by priority.)
|
||||||
|
2. **Judge (router):** per candidate — guard existing fix / fix + guard /
|
||||||
|
no-action / rule (the gate cannot see the primitive).
|
||||||
|
3. **Fix + re-scope (fixes only):** apply the fix, re-run the detector; the
|
||||||
|
fixed candidate must vanish from the findings (match by a stable key, not
|
||||||
|
line numbers). Pattern-level feedback in seconds — complements, never
|
||||||
|
replaces, step 5.
|
||||||
|
4. **Generate:** draft or extend a guard test per the good-test rules, driving
|
||||||
|
the specific branch.
|
||||||
|
5. **Verify teeth:** make the bad pattern happen → gate must fail; restore →
|
||||||
|
gate must pass. A pattern that stays green while genuinely bad is the
|
||||||
|
"rule" signal, not a coverage success.
|
||||||
|
6. **Deliver:** commit the verified guard test; any gate-rule change ships in
|
||||||
|
its own commit with the fails-to-fail evidence attached.
|
||||||
|
|
||||||
|
To add a domain: supply a new fill doc (like `good-anchor-rules.md`) + detector,
|
||||||
|
and promote this file into a parent skill the instances point at.
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
"""Template: a tests/blocking_io/ runtime anchor.
|
||||||
|
|
||||||
|
Copy into backend/tests/blocking_io/test_<area>.py and adapt. The suite's
|
||||||
|
conftest already wraps every test here in the strict Blockbuster gate, so you do
|
||||||
|
NOT import or activate the detector — just drive the real async entry point.
|
||||||
|
|
||||||
|
Teeth check before you commit (see references/good-anchor-rules.md):
|
||||||
|
1. reintroduce the block -> `cd backend && make test-blocking-io` must FAIL
|
||||||
|
2. restore the fix -> it must PASS
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# from app.<module> import <real_async_entry_point>
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_<entry_point>_offloads_blocking_io_on_<branch>(tmp_path: Path) -> None:
|
||||||
|
# Arrange: real inputs at the boundary the code blocks on (FS -> tmp_path;
|
||||||
|
# HTTP/subprocess -> stub the external service). Mock ONLY the external
|
||||||
|
# boundary, never the offload under test.
|
||||||
|
|
||||||
|
# Act + Assert: call the REAL production async entry point and drive the
|
||||||
|
# specific branch you are guarding (e.g. force a failure to hit the cleanup
|
||||||
|
# path). If the entry point performs blocking IO on the loop, the gate fails.
|
||||||
|
# await <real_async_entry_point>(...)
|
||||||
|
raise NotImplementedError("Replace with the real async entry point call.")
|
||||||
@@ -0,0 +1,237 @@
|
|||||||
|
---
|
||||||
|
name: deerflow-maintainer-orchestrator
|
||||||
|
description: "Use when a DeerFlow maintainer needs comment-only GitHub issue or PR handling: resolve issue/PR scopes with gh, analyze issues, post or draft issue comments, perform PR review comments, give fix strategy, risk classification, and validation guidance. Intended for maintainers and trusted local agents, not general contributors."
|
||||||
|
---
|
||||||
|
|
||||||
|
# DeerFlow Maintainer Orchestrator
|
||||||
|
|
||||||
|
## Core Rule
|
||||||
|
|
||||||
|
This is a comment-plane skill: resolve GitHub scope, inspect evidence, and prepare or post DeerFlow issue comments and PR review comments. Keep the work comment-scoped; do not turn it into coding, branch management, release work, artifact closure, or other maintainer operations.
|
||||||
|
|
||||||
|
When the maintainer asks to process, handle, comment on, or review a bounded set of issues or PRs, proceed without asking follow-up questions. Treat that request as authorization for one public issue comment per selected non-skipped issue and one PR review comment per selected PR with high-confidence findings. If a PR has no high-confidence findings, do not post a public comment; report that result to the maintainer only. If the maintainer explicitly asks for analysis only, return comment-ready drafts without posting.
|
||||||
|
|
||||||
|
The maintainer's normal interaction should be: provide scope; receive posted comment URLs, PR review URLs, clean results, skipped items, failures, or drafts. Do not offload technical analysis to the maintainer. Make the best evidence-backed recommendation in the comment itself: describe the risk, impact, likely fix, and validation path. Ask the reporter or PR author for missing evidence only when the artifact lacks enough data to diagnose.
|
||||||
|
|
||||||
|
Output only the maintainer run result or comment draft. Do not announce the skill name, mode, or that no code was edited unless the user asks for process details.
|
||||||
|
|
||||||
|
Match the dominant language of the issue or PR unless the maintainer asks for another language. Chinese issue or PR text gets Chinese output; English issue or PR text gets English output. For mixed artifacts, use the body language, not logs or code.
|
||||||
|
|
||||||
|
## Artifact Resolution
|
||||||
|
|
||||||
|
Use GitHub tooling to resolve artifact type and scope. Do not ask the maintainer to clarify when `gh` or GitHub API can determine the answer.
|
||||||
|
|
||||||
|
1. Default repository is `bytedance/deer-flow` unless a URL or explicit repo says otherwise.
|
||||||
|
2. For URLs, route `/issues/<number>` to Issue Flow and `/pull/<number>` to PR Review Flow.
|
||||||
|
3. For typed numbers, use the typed command:
|
||||||
|
- Issue: `gh issue view <number> --repo <repo> --json number,title,url,state,body,labels,author,comments`
|
||||||
|
- PR: `gh pr view <number> --repo <repo> --json number,title,url,state,body,author,files,comments,reviews,statusCheckRollup,baseRefName,headRefName`
|
||||||
|
4. Normalize multiple explicit references such as `#123`, `# 123`, and bare `123` into a number list, preserving order and de-duplicating exact repeats.
|
||||||
|
5. For untyped numbers, try `gh pr view <number> --repo <repo> --json number,url` first. If it fails, use `gh issue view <number> --repo <repo> --json number,url`. Do not ask which type it is.
|
||||||
|
6. For issue batches, use `gh issue list`, not the mixed GitHub issues endpoint. For PR batches, use `gh pr list`.
|
||||||
|
7. Respect maintainer-provided count or time window. There is no hard five-item cap. If the scope is broad and underspecified, choose a practical recent slice, state the slice used, prioritize newest and highest-risk items, and report any unprocessed remainder.
|
||||||
|
8. For "recent/latest" wording without a count, use a small default recent slice. For "recent hours" wording without a number, use six hours. Do not ask.
|
||||||
|
9. Use `gh api` when `gh issue/pr view/list` lacks required fields such as timeline events, review threads, or precise search filters.
|
||||||
|
10. Use GitHub search only as a fallback for natural-language filters that cannot be represented by view/list/API calls. Do not use web search for artifact routing unless GitHub tooling is unavailable.
|
||||||
|
11. If no artifact type, number, URL, count, time window, or searchable GitHub scope can be resolved, stop with a compact "scope unresolved" report. Do not ask a follow-up question.
|
||||||
|
|
||||||
|
Use concise repo-local references such as `#123` and `PR #123` in maintainer reports and comments. Include full GitHub URLs only for posted comment/review links returned by GitHub or when the maintainer supplied an explicit URL.
|
||||||
|
|
||||||
|
## Issue Flow
|
||||||
|
|
||||||
|
Use Issue Flow for GitHub issues, bug reports, feature requests, support questions, and issue batches.
|
||||||
|
|
||||||
|
Start every issue with a cheap duplicate-opinion precheck:
|
||||||
|
|
||||||
|
1. Fetch issue metadata, labels, author, body, and existing comments.
|
||||||
|
2. If labels, title, or body mark the issue as RFC (`rfc`, `[RFC]`, `RFC:`, or `Request for Comments`), classify it as `rfc-no-comment`, skip deep analysis, and do not post anything public unless the maintainer explicitly overrides the RFC skip for that item.
|
||||||
|
3. If an existing maintainer or trusted-agent issue comment already gives a materially equivalent diagnosis, modification suggestion, information request, or blocking decision, skip deep analysis and do not post anything public for that issue.
|
||||||
|
4. Treat ordinary reporter replies, thanks, unrelated discussion, or incomplete guesses as non-blocking.
|
||||||
|
5. Report skipped issues to the maintainer only as compact identifiers plus the skipped reason or existing comment URL when available.
|
||||||
|
|
||||||
|
For non-skipped issues:
|
||||||
|
|
||||||
|
1. Read enough context to avoid guessing: issue body, comments, screenshots, logs, reproduction details, linked artifacts, and relevant DeerFlow code/docs.
|
||||||
|
2. Classify the surface:
|
||||||
|
- Frontend UI
|
||||||
|
- Backend API
|
||||||
|
- Agents / LangGraph
|
||||||
|
- Sandbox
|
||||||
|
- Skills
|
||||||
|
- MCP
|
||||||
|
- Dependencies
|
||||||
|
- Default behavior
|
||||||
|
- Docs / tests / CI only
|
||||||
|
3. Classify actionability:
|
||||||
|
- `ready-to-fix`: bounded, evidence sufficient, validation path clear.
|
||||||
|
- `needs-more-evidence`: repro, logs, environment, screenshots, exact expected behavior, or failing case missing.
|
||||||
|
- `defer-or-close`: duplicate, stale, unsupported, unactionable, or out of scope.
|
||||||
|
- `rfc-no-comment`: RFC issue; skip public comments by default.
|
||||||
|
4. Produce a public-safe comment from the analysis, not the analysis labels:
|
||||||
|
- Start with one natural opener that connects to the issue context. Prefer `Thanks @author.` for reporter-authored issues when it reads naturally; omit the mention for bots, maintainer-authored tracking issues, or cases where it would add noise.
|
||||||
|
- The opener must say something specific about the next step or boundary, not a generic assessment. Do not use generic phrases such as "This is actionable", "I would treat this as", "ready to fix", or surface/actionability/risk labels.
|
||||||
|
- Use the smallest stable template that fits:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Thanks @author. <one specific sentence that frames the fix, investigation, or missing evidence.>
|
||||||
|
|
||||||
|
Recommended solution:
|
||||||
|
- ...
|
||||||
|
|
||||||
|
Validation:
|
||||||
|
- ...
|
||||||
|
```
|
||||||
|
|
||||||
|
- Add `Evidence:` only when citing concrete code, logs, reproduction details, or other proof helps the author act.
|
||||||
|
- Add `Risk:` only when architecture, security, public API, default behavior, or compatibility impact must be called out explicitly; make the risk specific.
|
||||||
|
- Add `Missing info:` only when the issue cannot be diagnosed without more evidence; ask for the smallest useful data.
|
||||||
|
- Put relevant files/components inside `Evidence:` or `Recommended solution:` bullets instead of separate metadata fields.
|
||||||
|
- Every posted issue comment should contain concrete modification guidance and validation guidance unless the only useful response is `Missing info:`.
|
||||||
|
5. Immediately before posting, refresh comments and skip if an equivalent maintainer or trusted-agent comment appeared during analysis.
|
||||||
|
6. Post one issue comment when posting is authorized; otherwise return the same text as `Reply draft`.
|
||||||
|
|
||||||
|
Do not expose private reasoning, credentials, internal-only context, or unsupported promises. Do not say a fix was made unless a separate coding workflow actually changed code.
|
||||||
|
|
||||||
|
## PR Review Flow
|
||||||
|
|
||||||
|
Use PR Review Flow for GitHub pull requests and PR batches.
|
||||||
|
|
||||||
|
Start every PR with a cheap duplicate-review precheck:
|
||||||
|
|
||||||
|
1. Fetch PR metadata, changed file list, checks summary, existing PR reviews, existing PR comments, and review threads when available.
|
||||||
|
2. If an existing maintainer or trusted-agent review already gives materially equivalent findings or a blocking decision, skip deep review and do not post anything public for that PR.
|
||||||
|
3. Treat author replies, thanks, unrelated discussion, or incomplete guesses as non-blocking.
|
||||||
|
4. Report skipped PRs to the maintainer only as compact identifiers plus the existing review/comment URL when available.
|
||||||
|
|
||||||
|
### Diff Base Rule
|
||||||
|
|
||||||
|
Before reviewing a local PR branch or local diff, fetch the base repository's target branch and compare against that fresh remote-tracking ref, not a possibly stale local `main`.
|
||||||
|
|
||||||
|
- For fork checkouts, prefer `upstream/<base-branch>` when `upstream` points to the base repository.
|
||||||
|
- For direct upstream checkouts, use the base remote's fetched branch, usually `origin/<base-branch>`.
|
||||||
|
- Prefer GitHub PR base metadata for the target branch. For non-PR local diffs, use the base repository default branch. If metadata is unavailable, default to `main` only after fetching the base remote.
|
||||||
|
- Refresh the comparison ref explicitly, for example `git fetch <base-remote> +refs/heads/<base-branch>:refs/remotes/<base-remote>/<base-branch>`, then inspect `BASE=$(git merge-base HEAD <base-remote>/<base-branch>)` and `git diff "$BASE"...HEAD`.
|
||||||
|
- If using `FETCH_HEAD` from a single-branch fetch instead, diff against that verified `FETCH_HEAD` immediately and do not later substitute a possibly stale remote-tracking ref.
|
||||||
|
- For uncommitted local changes, review committed branch changes against the fresh base first, then include working-tree changes separately.
|
||||||
|
- If the base remote or base branch cannot be established, use the GitHub PR files/diff as the source of truth. If neither local nor GitHub diff can be read, return a compact failure report and do not post a review.
|
||||||
|
|
||||||
|
Before posting a PR review comment:
|
||||||
|
|
||||||
|
1. Review only the current diff against the fresh base and changed files. Do not comment on unrelated pre-existing code unless the diff makes it newly risky.
|
||||||
|
2. Do not report low-confidence guesses. If evidence is insufficient, omit the finding.
|
||||||
|
3. Prioritize correctness, safety, maintainability, production risk, compatibility, and missing critical tests over style.
|
||||||
|
4. Report concrete architecture, security, public API, default-behavior, and compatibility problems as findings when the diff causes or exposes them.
|
||||||
|
5. Check changed behavior, edge cases, error paths, state mutation, transactions, locks, cache invalidation, cleanup, security boundaries, missing tests, performance/reliability, and API compatibility.
|
||||||
|
6. Immediately before posting, refresh reviews/comments and skip if an equivalent maintainer or trusted-agent review appeared during analysis.
|
||||||
|
7. If there are high-confidence findings, post a PR review comment using the PR language. If there are no high-confidence findings, do not post a public PR review/comment; report `No high-confidence review findings.` to the maintainer in the run result.
|
||||||
|
|
||||||
|
For public PR reviews with findings, start with one short opener that fits the review context and matches the finding count. Use singular wording only for exactly one finding, for example `Thanks @author. I found one issue that should be addressed before this is ready.` Use plural wording for multiple findings, for example `Thanks @author. I found a few issues that should be addressed before this is ready.` Omit the mention for bots or when it adds noise.
|
||||||
|
|
||||||
|
For each finding, use:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[P0/P1/P2] Title
|
||||||
|
|
||||||
|
- Location: file and line/range
|
||||||
|
- Problem: what can go wrong
|
||||||
|
- Evidence: why the diff causes it
|
||||||
|
- Suggested fix: concrete minimal fix
|
||||||
|
- Test: what test should cover it
|
||||||
|
```
|
||||||
|
|
||||||
|
Severity guide:
|
||||||
|
|
||||||
|
- `P0`: causes outage, data loss, security breach, or build failure.
|
||||||
|
- `P1`: likely production bug, serious regression, broken compatibility, or high-risk security/architecture issue.
|
||||||
|
- `P2`: correctness, maintainability, or test concern with lower risk.
|
||||||
|
|
||||||
|
Do not produce compliments, summaries, or general advice. For sensitive security issues, describe impact and remediation without exploit instructions.
|
||||||
|
|
||||||
|
## No-Question Policy
|
||||||
|
|
||||||
|
Do not ask the maintainer routine clarification questions. The skill should save maintainer time by turning scope into comments through a fixed workflow.
|
||||||
|
|
||||||
|
Stop without asking only when:
|
||||||
|
|
||||||
|
- no issue/PR scope can be resolved through URLs, numbers, `gh` view/list, `gh api`, or GitHub search fallback;
|
||||||
|
- GitHub authentication, repository access, or comment posting fails;
|
||||||
|
- the requested action is outside comment-only scope;
|
||||||
|
- posting would require private credentials, private security details, or non-public context.
|
||||||
|
|
||||||
|
In these cases, return a compact failure report with the attempted command path and the smallest next action. Do not phrase it as a question unless the maintainer explicitly asked to be prompted.
|
||||||
|
|
||||||
|
## DeerFlow Review Heuristics
|
||||||
|
|
||||||
|
Treat these as high-signal areas for issue comments and PR findings:
|
||||||
|
|
||||||
|
- `backend/packages/harness/deerflow/` must not import `app.*`.
|
||||||
|
- App may depend on harness; harness must stay publishable and app-agnostic.
|
||||||
|
- Frontend thread/message behavior and Gateway/LangGraph-compatible SSE are contract surfaces.
|
||||||
|
- Sandbox permissions, bash/file-write tools, skill installation, and remote execution are security-sensitive.
|
||||||
|
- Default model/provider behavior, config migration, persistence schema, public API/SSE, and LangGraph thread/run lifecycle are compatibility-sensitive.
|
||||||
|
- Runtime docs should track user-facing or developer-facing behavior changes.
|
||||||
|
- Security-sensitive comments should provide proof and remediation, not vague assertions.
|
||||||
|
|
||||||
|
## Validation Guidance
|
||||||
|
|
||||||
|
Recommend the checks matching the touched surface:
|
||||||
|
|
||||||
|
| Surface | Suggested validation |
|
||||||
|
| --- | --- |
|
||||||
|
| Backend API / harness / agents / MCP / skills runtime | `cd backend && make lint && make test` |
|
||||||
|
| Blocking IO or async file/network work | `cd backend && make test-blocking-io` or a focused blocking-IO regression |
|
||||||
|
| Harness/app boundary | `cd backend && uv run pytest tests/test_harness_boundary.py` |
|
||||||
|
| Frontend UI/core | `cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test` |
|
||||||
|
| Front/back thread or SSE contract | backend replay golden and full-stack replay render where feasible |
|
||||||
|
| Frontend user workflow | Playwright E2E or browser proof with screenshot/DOM assertion |
|
||||||
|
| Docker/sandbox/provisioner | focused backend tests plus Docker/provisioner smoke when feasible |
|
||||||
|
| Docs-only | targeted markdown review |
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
For Issue Flow:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Run result:
|
||||||
|
Posted:
|
||||||
|
Skipped:
|
||||||
|
Failed:
|
||||||
|
Per issue:
|
||||||
|
Issue:
|
||||||
|
Surface:
|
||||||
|
Actionability:
|
||||||
|
Risk:
|
||||||
|
Comment:
|
||||||
|
Validation:
|
||||||
|
Comment status:
|
||||||
|
```
|
||||||
|
|
||||||
|
For PR Review Flow:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Run result:
|
||||||
|
Reviewed:
|
||||||
|
Skipped:
|
||||||
|
Clean:
|
||||||
|
Failed:
|
||||||
|
Per PR:
|
||||||
|
PR:
|
||||||
|
Public review:
|
||||||
|
Findings:
|
||||||
|
Review status:
|
||||||
|
```
|
||||||
|
|
||||||
|
For analysis-only requests, replace `Posted`/`Reviewed` with `Drafted` and include the comment/review text without posting.
|
||||||
|
|
||||||
|
For batches, prefer a compact maintainer-facing table after the headline counts:
|
||||||
|
|
||||||
|
```text
|
||||||
|
| Artifact | Status | Public action | Notes |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| #123 | posted | comment URL | short reason |
|
||||||
|
| PR #456 | reviewed | review URL | P1: finding title |
|
||||||
|
| PR #789 | clean | none | No high-confidence review findings. |
|
||||||
|
| #321 | skipped | none | existing maintainer comment |
|
||||||
|
```
|
||||||
|
|
||||||
|
Omit empty categories, no-op fields, routine command output, and raw logs. Report meaningful changes, evidence, and options.
|
||||||
@@ -21,6 +21,7 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
# DEEPSEEK_API_KEY=your-deepseek-api-key
|
# DEEPSEEK_API_KEY=your-deepseek-api-key
|
||||||
# NOVITA_API_KEY=your-novita-api-key # OpenAI-compatible, see https://novita.ai
|
# NOVITA_API_KEY=your-novita-api-key # OpenAI-compatible, see https://novita.ai
|
||||||
# MINIMAX_API_KEY=your-minimax-api-key # OpenAI-compatible, see https://platform.minimax.io
|
# MINIMAX_API_KEY=your-minimax-api-key # OpenAI-compatible, see https://platform.minimax.io
|
||||||
|
# STEPFUN_API_KEY=your-stepfun-api-key # OpenAI-compatible, see https://platform.stepfun.com
|
||||||
# VLLM_API_KEY=your-vllm-api-key # OpenAI-compatible
|
# VLLM_API_KEY=your-vllm-api-key # OpenAI-compatible
|
||||||
# FEISHU_APP_ID=your-feishu-app-id
|
# FEISHU_APP_ID=your-feishu-app-id
|
||||||
# FEISHU_APP_SECRET=your-feishu-app-secret
|
# FEISHU_APP_SECRET=your-feishu-app-secret
|
||||||
@@ -65,3 +66,18 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
# alias, or behind a different port). docker-compose already sets these.
|
# alias, or behind a different port). docker-compose already sets these.
|
||||||
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
||||||
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
||||||
|
|
||||||
|
# ── Claude Code / Codex CLI subscription as a model provider (optional) ───────
|
||||||
|
# If you configure a ClaudeChatModel / Codex model provider (or an ACP agent)
|
||||||
|
# that reuses your CLI subscription login, prefer passing a token via env over
|
||||||
|
# bind-mounting your whole ~/.claude / ~/.codex into the container. The Gateway
|
||||||
|
# credential loader reads these first, so no directory mount is needed.
|
||||||
|
# CLAUDE_CODE_CREDENTIALS_PATH points at a single .credentials.json (Claude)
|
||||||
|
# rather than the whole dir. docker-compose.cli-auth.yaml is the opt-in
|
||||||
|
# directory-mount fallback for adapters that need the full CLI config.
|
||||||
|
# ACP adapters often take their own env API key (e.g. ANTHROPIC_API_KEY) and
|
||||||
|
# need no mount at all — check the adapter's docs. See SECURITY.md.
|
||||||
|
# CLAUDE_CODE_OAUTH_TOKEN=your-claude-code-oauth-token
|
||||||
|
# ANTHROPIC_AUTH_TOKEN=your-anthropic-auth-token
|
||||||
|
# CLAUDE_CODE_CREDENTIALS_PATH=/path/to/.claude/.credentials.json
|
||||||
|
# CODEX_AUTH_PATH=/path/to/codex/auth.json
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
name: 🐛 Bug report
|
||||||
|
description: Report something that isn't working so maintainers can reproduce and fix it.
|
||||||
|
title: "[bug] "
|
||||||
|
labels: ["bug"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Thanks for taking the time to file a bug. A clear, reproducible report is the
|
||||||
|
single biggest factor in how fast it gets fixed.
|
||||||
|
|
||||||
|
Please fill in every required field — especially **reproduction steps** and **logs**.
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: preflight
|
||||||
|
attributes:
|
||||||
|
label: Before you start
|
||||||
|
options:
|
||||||
|
- label: I searched [existing issues](https://github.com/bytedance/deer-flow/issues?q=is%3Aissue) and this is not a duplicate.
|
||||||
|
required: true
|
||||||
|
- label: I can reproduce this on the latest `main`.
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: summary
|
||||||
|
attributes:
|
||||||
|
label: Problem summary
|
||||||
|
description: One sentence describing the bug.
|
||||||
|
placeholder: e.g. make dev fails to start the gateway service
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: areas
|
||||||
|
attributes:
|
||||||
|
label: Affected area(s)
|
||||||
|
description: Which part of DeerFlow does this touch? Select all that apply.
|
||||||
|
multiple: true
|
||||||
|
options:
|
||||||
|
- Frontend (UI / Next.js)
|
||||||
|
- Backend API (gateway / endpoints / SSE)
|
||||||
|
- Agents / LangGraph (graph, prompts, langgraph.json)
|
||||||
|
- Sandbox / Docker
|
||||||
|
- Skills
|
||||||
|
- MCP
|
||||||
|
- Config / setup (make, config.yaml, env)
|
||||||
|
- Docs
|
||||||
|
- Not sure
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: actual
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: The actual behavior. Include the key error lines verbatim.
|
||||||
|
placeholder: When I do X, I expected Y but I got Z.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: expected
|
||||||
|
attributes:
|
||||||
|
label: Expected behavior
|
||||||
|
placeholder: What did you expect to happen instead?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: Steps to reproduce
|
||||||
|
description: Exact commands and sequence. Minimal steps that reliably reproduce the problem.
|
||||||
|
placeholder: |
|
||||||
|
1. make check
|
||||||
|
2. make install
|
||||||
|
3. make dev
|
||||||
|
4. ...
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: logs
|
||||||
|
attributes:
|
||||||
|
label: Relevant logs
|
||||||
|
description: Paste key lines from logs (for example `logs/gateway.log`, `logs/frontend.log`). Redact secrets.
|
||||||
|
render: shell
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: run_mode
|
||||||
|
attributes:
|
||||||
|
label: How are you running DeerFlow?
|
||||||
|
options:
|
||||||
|
- Local (make dev)
|
||||||
|
- Docker (make docker-start)
|
||||||
|
- CI
|
||||||
|
- Other
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating system
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Other
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: platform_details
|
||||||
|
attributes:
|
||||||
|
label: Platform details
|
||||||
|
description: Architecture and shell, if relevant.
|
||||||
|
placeholder: e.g. arm64, zsh
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: python_version
|
||||||
|
attributes:
|
||||||
|
label: Python version
|
||||||
|
placeholder: e.g. Python 3.12.9
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: node_version
|
||||||
|
attributes:
|
||||||
|
label: Node.js version
|
||||||
|
placeholder: e.g. v22.11.0
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: pnpm_version
|
||||||
|
attributes:
|
||||||
|
label: pnpm version
|
||||||
|
placeholder: e.g. 10.26.2
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: uv_version
|
||||||
|
attributes:
|
||||||
|
label: uv version
|
||||||
|
placeholder: e.g. 0.7.20
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: git_info
|
||||||
|
attributes:
|
||||||
|
label: Git state
|
||||||
|
description: Output of `git branch --show-current` and the latest commit SHA.
|
||||||
|
placeholder: |
|
||||||
|
branch: feature/my-branch
|
||||||
|
commit: abcdef1
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: additional
|
||||||
|
attributes:
|
||||||
|
label: Additional context
|
||||||
|
description: Screenshots, related issues, config snippets (redacted), or anything else that helps triage.
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
blank_issues_enabled: false
|
||||||
|
contact_links:
|
||||||
|
- name: 💬 Questions & usage help
|
||||||
|
url: https://github.com/bytedance/deer-flow/discussions/categories/q-a
|
||||||
|
about: "How do I use X? Why does Y behave like that? Ask in Discussions — it gets answered faster and stays searchable."
|
||||||
|
- name: 💡 Ideas & proposals
|
||||||
|
url: https://github.com/bytedance/deer-flow/discussions/categories/ideas
|
||||||
|
about: Have a half-formed idea? Float it in Discussions before opening a formal feature request.
|
||||||
|
- name: 🔒 Report a security vulnerability
|
||||||
|
url: https://github.com/bytedance/deer-flow/security/policy
|
||||||
|
about: Do not open a public issue for security problems. Follow the security policy instead.
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
name: 💡 Feature request
|
||||||
|
description: Propose a new capability or an improvement to an existing one.
|
||||||
|
title: "[feat] "
|
||||||
|
labels: ["enhancement"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Thanks for the suggestion. For non-trivial features, please open a
|
||||||
|
[Discussion](https://github.com/bytedance/deer-flow/discussions/categories/ideas)
|
||||||
|
first to align on scope before writing code.
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: preflight
|
||||||
|
attributes:
|
||||||
|
label: Before you start
|
||||||
|
options:
|
||||||
|
- label: I searched [existing issues](https://github.com/bytedance/deer-flow/issues?q=is%3Aissue) and this is not a duplicate.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: Problem / motivation
|
||||||
|
description: What problem does this solve? What is painful today, or what does it unblock?
|
||||||
|
placeholder: "I'm always frustrated when ..."
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: Describe the change from a user's / caller's perspective.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: areas
|
||||||
|
attributes:
|
||||||
|
label: Affected area(s)
|
||||||
|
description: Which part of DeerFlow would this touch? Select all that apply.
|
||||||
|
multiple: true
|
||||||
|
options:
|
||||||
|
- Frontend (UI / Next.js)
|
||||||
|
- Backend API (gateway / endpoints / SSE)
|
||||||
|
- Agents / LangGraph (graph, prompts, langgraph.json)
|
||||||
|
- Sandbox / Docker
|
||||||
|
- Skills
|
||||||
|
- MCP
|
||||||
|
- Config / setup
|
||||||
|
- Docs
|
||||||
|
- Not sure
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: alternatives
|
||||||
|
attributes:
|
||||||
|
label: Alternatives considered
|
||||||
|
description: Other approaches you weighed and why you discarded them.
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: additional
|
||||||
|
attributes:
|
||||||
|
label: Additional context
|
||||||
|
description: Mockups, links, related issues, or anything else that helps.
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
name: Runtime Information
|
|
||||||
description: Report runtime/environment details to help reproduce an issue.
|
|
||||||
title: "[runtime] "
|
|
||||||
labels:
|
|
||||||
- needs-triage
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: |
|
|
||||||
Thanks for sharing runtime details.
|
|
||||||
Complete this form so maintainers can quickly reproduce and diagnose the problem.
|
|
||||||
|
|
||||||
- type: input
|
|
||||||
id: summary
|
|
||||||
attributes:
|
|
||||||
label: Problem summary
|
|
||||||
description: Short summary of the issue.
|
|
||||||
placeholder: e.g. make dev fails to start gateway service
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: expected
|
|
||||||
attributes:
|
|
||||||
label: Expected behavior
|
|
||||||
placeholder: What did you expect to happen?
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: actual
|
|
||||||
attributes:
|
|
||||||
label: Actual behavior
|
|
||||||
placeholder: What happened instead? Include key error lines.
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: dropdown
|
|
||||||
id: os
|
|
||||||
attributes:
|
|
||||||
label: Operating system
|
|
||||||
options:
|
|
||||||
- macOS
|
|
||||||
- Linux
|
|
||||||
- Windows
|
|
||||||
- Other
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: input
|
|
||||||
id: platform_details
|
|
||||||
attributes:
|
|
||||||
label: Platform details
|
|
||||||
description: Add architecture and shell if relevant.
|
|
||||||
placeholder: e.g. arm64, zsh
|
|
||||||
|
|
||||||
- type: input
|
|
||||||
id: python_version
|
|
||||||
attributes:
|
|
||||||
label: Python version
|
|
||||||
placeholder: e.g. Python 3.12.9
|
|
||||||
|
|
||||||
- type: input
|
|
||||||
id: node_version
|
|
||||||
attributes:
|
|
||||||
label: Node.js version
|
|
||||||
placeholder: e.g. v23.11.0
|
|
||||||
|
|
||||||
- type: input
|
|
||||||
id: pnpm_version
|
|
||||||
attributes:
|
|
||||||
label: pnpm version
|
|
||||||
placeholder: e.g. 10.26.2
|
|
||||||
|
|
||||||
- type: input
|
|
||||||
id: uv_version
|
|
||||||
attributes:
|
|
||||||
label: uv version
|
|
||||||
placeholder: e.g. 0.7.20
|
|
||||||
|
|
||||||
- type: dropdown
|
|
||||||
id: run_mode
|
|
||||||
attributes:
|
|
||||||
label: How are you running DeerFlow?
|
|
||||||
options:
|
|
||||||
- Local (make dev)
|
|
||||||
- Docker (make docker-dev)
|
|
||||||
- CI
|
|
||||||
- Other
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: reproduce
|
|
||||||
attributes:
|
|
||||||
label: Reproduction steps
|
|
||||||
description: Provide exact commands and sequence.
|
|
||||||
placeholder: |
|
|
||||||
1. make check
|
|
||||||
2. make install
|
|
||||||
3. make dev
|
|
||||||
4. ...
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: logs
|
|
||||||
attributes:
|
|
||||||
label: Relevant logs
|
|
||||||
description: Paste key lines from logs (for example logs/gateway.log, logs/frontend.log).
|
|
||||||
render: shell
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: git_info
|
|
||||||
attributes:
|
|
||||||
label: Git state
|
|
||||||
description: Share output of git branch and latest commit SHA.
|
|
||||||
placeholder: |
|
|
||||||
branch: feature/my-branch
|
|
||||||
commit: abcdef1
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: additional
|
|
||||||
attributes:
|
|
||||||
label: Additional context
|
|
||||||
description: Add anything else that might help triage.
|
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
# Declarative label source of truth for DeerFlow.
|
||||||
|
#
|
||||||
|
# This file is the single source of truth for repository labels used by the
|
||||||
|
# auto-labeling workflows (.github/workflows/pr-labeler.yml, pr-triage.yml,
|
||||||
|
# issue-triage.yml). Auto-labelers can only apply labels that already exist,
|
||||||
|
# so every label referenced by a workflow MUST be declared here.
|
||||||
|
#
|
||||||
|
# Apply with: uv run --with pyyaml python scripts/sync_labels.py [--repo OWNER/NAME]
|
||||||
|
# CI keeps it in sync via .github/workflows/label-sync.yml (runs on changes here).
|
||||||
|
#
|
||||||
|
# Sync is additive/update-only: it creates or updates the labels listed below
|
||||||
|
# and never deletes labels that are not listed.
|
||||||
|
#
|
||||||
|
# Color = 6-digit hex without the leading '#'.
|
||||||
|
|
||||||
|
labels:
|
||||||
|
# ── Type ─────────────────────────────────────────────────────────────────
|
||||||
|
# Mostly GitHub defaults; declared here so colors/descriptions stay stable
|
||||||
|
# and so issue templates can rely on them existing.
|
||||||
|
- name: bug
|
||||||
|
color: d73a4a
|
||||||
|
description: Something isn't working
|
||||||
|
- name: enhancement
|
||||||
|
color: a2eeef
|
||||||
|
description: New feature or request
|
||||||
|
- name: documentation
|
||||||
|
color: 0075ca
|
||||||
|
description: Improvements or additions to documentation
|
||||||
|
- name: question
|
||||||
|
color: d876e3
|
||||||
|
description: Further information is requested
|
||||||
|
|
||||||
|
# ── Area (auto, by changed paths — see .github/labeler.yml) ───────────────
|
||||||
|
# Mirrors the "Surface area" section of the pull request template.
|
||||||
|
- name: "area:frontend"
|
||||||
|
color: c5def5
|
||||||
|
description: Next.js frontend under frontend/
|
||||||
|
- name: "area:backend"
|
||||||
|
color: c5def5
|
||||||
|
description: Gateway / runtime / core backend under backend/
|
||||||
|
- name: "area:agents"
|
||||||
|
color: c5def5
|
||||||
|
description: Agents, subagents, graph wiring, prompts, langgraph.json
|
||||||
|
- name: "area:sandbox"
|
||||||
|
color: c5def5
|
||||||
|
description: Sandboxed execution and docker/
|
||||||
|
- name: "area:skills"
|
||||||
|
color: c5def5
|
||||||
|
description: Skills under skills/ or the skills harness
|
||||||
|
- name: "area:mcp"
|
||||||
|
color: c5def5
|
||||||
|
description: Model Context Protocol integration
|
||||||
|
- name: "area:ci"
|
||||||
|
color: c5def5
|
||||||
|
description: GitHub Actions, CI config, repo tooling
|
||||||
|
- name: "area:docs"
|
||||||
|
color: c5def5
|
||||||
|
description: Documentation and Markdown only
|
||||||
|
- name: "area:deps"
|
||||||
|
color: c5def5
|
||||||
|
description: Dependency manifests / lockfiles
|
||||||
|
|
||||||
|
# ── Size (auto, by additions + deletions — see pr-triage.yml) ─────────────
|
||||||
|
- name: "size/XS"
|
||||||
|
color: "009900"
|
||||||
|
description: PR changes < 20 lines
|
||||||
|
- name: "size/S"
|
||||||
|
color: 77bb00
|
||||||
|
description: PR changes 20-100 lines
|
||||||
|
- name: "size/M"
|
||||||
|
color: eebb00
|
||||||
|
description: PR changes 100-300 lines
|
||||||
|
- name: "size/L"
|
||||||
|
color: ee9900
|
||||||
|
description: PR changes 300-700 lines
|
||||||
|
- name: "size/XL"
|
||||||
|
color: ee5500
|
||||||
|
description: PR changes 700+ lines
|
||||||
|
|
||||||
|
# ── Risk (auto, by changed paths — see pr-triage.yml) ─────────────────────
|
||||||
|
- name: "risk:low"
|
||||||
|
color: 0e8a16
|
||||||
|
description: "Low risk: docs / i18n / assets only"
|
||||||
|
- name: "risk:medium"
|
||||||
|
color: fbca04
|
||||||
|
description: "Medium risk: regular code changes"
|
||||||
|
- name: "risk:high"
|
||||||
|
color: b60205
|
||||||
|
description: "High risk: backend API, agents, sandbox, auth, deps, CI"
|
||||||
|
|
||||||
|
# ── Priority (manual) ─────────────────────────────────────────────────────
|
||||||
|
- name: P0
|
||||||
|
color: b60205
|
||||||
|
description: Critical priority
|
||||||
|
- name: P1
|
||||||
|
color: d93f0b
|
||||||
|
description: Major priority
|
||||||
|
- name: P2
|
||||||
|
color: e99695
|
||||||
|
description: Normal priority
|
||||||
|
|
||||||
|
# ── Status (auto + manual) ────────────────────────────────────────────────
|
||||||
|
- name: needs-triage
|
||||||
|
color: fef2c0
|
||||||
|
description: Awaiting maintainer triage
|
||||||
|
- name: needs-validation
|
||||||
|
color: d4c5f9
|
||||||
|
description: Touches front/back contract surface; needs real-path validation
|
||||||
|
- name: skip-validation
|
||||||
|
color: cccccc
|
||||||
|
description: "Maintainer override: do not auto-add needs-validation on this PR"
|
||||||
|
- name: reviewing
|
||||||
|
color: 5319e7
|
||||||
|
description: A maintainer is reviewing this PR
|
||||||
|
|
||||||
|
# ── Contributor ───────────────────────────────────────────────────────────
|
||||||
|
- name: first-time-contributor
|
||||||
|
color: c2e0c6
|
||||||
|
description: First contribution to this repository — be welcoming
|
||||||
@@ -59,3 +59,17 @@ Fixes #
|
|||||||
Frontend: cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test
|
Frontend: cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test
|
||||||
Frontend E2E (if you touched frontend/): cd frontend && make test-e2e -->
|
Frontend E2E (if you touched frontend/): cd frontend && make test-e2e -->
|
||||||
|
|
||||||
|
|
||||||
|
## AI assistance
|
||||||
|
|
||||||
|
<!-- DeerFlow is an AI project — most PRs here use AI coding tools, and that's
|
||||||
|
welcome. Disclosing it just helps reviewers calibrate how closely to read the
|
||||||
|
diff. Please fill all three; don't delete the section. -->
|
||||||
|
|
||||||
|
**Tool(s) used:** <!-- e.g. Claude Code, Cursor, GitHub Copilot, Codex, Windsurf, or "none" -->
|
||||||
|
|
||||||
|
**How you used it:** <!-- e.g. "generated the module from a spec", "autocomplete only",
|
||||||
|
"AI wrote tests, I wrote the impl". A prompt or conversation link is great too. -->
|
||||||
|
|
||||||
|
- [ ] I've read and understand every line of this change and take responsibility for it — it's not unreviewed AI output.
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
name: Label Sync
|
||||||
|
|
||||||
|
# Keeps repository labels in sync with the declarative source of truth
|
||||||
|
# (.github/labels.yml). Runs whenever that file changes on main, and can be
|
||||||
|
# triggered manually. Additive/update-only — never deletes labels.
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
paths:
|
||||||
|
- ".github/labels.yml"
|
||||||
|
- "scripts/sync_labels.py"
|
||||||
|
- ".github/workflows/label-sync.yml"
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: label-sync
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
sync:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Sync labels
|
||||||
|
run: uv run --with pyyaml python scripts/sync_labels.py
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GH_REPO: ${{ github.repository }}
|
||||||
@@ -10,7 +10,7 @@ permissions:
|
|||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint-backend:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
name: Replay E2E (front-back contract)
|
||||||
|
|
||||||
|
# Guards the front-back contract via record/replay (no API key in CI):
|
||||||
|
# Layer 1 — backend golden: replay a recorded trace through the real gateway,
|
||||||
|
# assert the SSE event sequence matches the committed golden.
|
||||||
|
# Layer 2 — full-stack render: real Next.js frontend + real gateway (replay
|
||||||
|
# model) + Chromium; assert the replayed turns render in the browser.
|
||||||
|
# Triggered by changes on EITHER side of the contract so a backend change can no
|
||||||
|
# longer pass without the frontend-facing checks running.
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: ["main"]
|
||||||
|
paths:
|
||||||
|
- "frontend/**"
|
||||||
|
- "backend/app/gateway/**"
|
||||||
|
- "backend/packages/harness/**"
|
||||||
|
- "backend/tests/fixtures/replay/**"
|
||||||
|
- "backend/tests/replay_provider.py"
|
||||||
|
- "backend/tests/_replay_fixture.py"
|
||||||
|
- "backend/tests/seed_runs_router.py"
|
||||||
|
- "backend/tests/test_replay_golden.py"
|
||||||
|
- "backend/scripts/run_replay_gateway.py"
|
||||||
|
- ".github/workflows/replay-e2e.yml"
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
|
paths:
|
||||||
|
- "frontend/**"
|
||||||
|
- "backend/app/gateway/**"
|
||||||
|
- "backend/packages/harness/**"
|
||||||
|
- "backend/tests/fixtures/replay/**"
|
||||||
|
- "backend/tests/replay_provider.py"
|
||||||
|
- "backend/tests/_replay_fixture.py"
|
||||||
|
- "backend/tests/seed_runs_router.py"
|
||||||
|
- "backend/tests/test_replay_golden.py"
|
||||||
|
- "backend/scripts/run_replay_gateway.py"
|
||||||
|
- ".github/workflows/replay-e2e.yml"
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: replay-e2e-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
backend-replay-golden:
|
||||||
|
name: Layer 1 — backend golden (no API key)
|
||||||
|
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 15
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
- name: Install backend dependencies
|
||||||
|
working-directory: backend
|
||||||
|
run: uv sync --group dev
|
||||||
|
- name: Replay golden (backend SSE contract)
|
||||||
|
working-directory: backend
|
||||||
|
run: PYTHONPATH=. uv run pytest tests/test_replay_golden.py -v
|
||||||
|
|
||||||
|
fullstack-replay-render:
|
||||||
|
name: Layer 2 — full-stack render (no API key)
|
||||||
|
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 25
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
- name: Install backend dependencies (replay gateway)
|
||||||
|
working-directory: backend
|
||||||
|
run: uv sync --group dev
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: "22"
|
||||||
|
- name: Enable Corepack
|
||||||
|
run: corepack enable
|
||||||
|
- name: Use pinned pnpm version
|
||||||
|
run: corepack prepare pnpm@10.26.2 --activate
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
working-directory: frontend
|
||||||
|
run: pnpm install --frozen-lockfile
|
||||||
|
- name: Install Playwright Chromium
|
||||||
|
working-directory: frontend
|
||||||
|
run: npx playwright install chromium --with-deps
|
||||||
|
- name: Full-stack replay render (DOM assertions are the gate)
|
||||||
|
working-directory: frontend
|
||||||
|
run: pnpm exec playwright test -c playwright.real-backend.config.ts
|
||||||
|
- name: Upload report + render artifact
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
if: ${{ !cancelled() }}
|
||||||
|
with:
|
||||||
|
name: replay-render
|
||||||
|
path: |
|
||||||
|
frontend/playwright-report/
|
||||||
|
frontend/test-results/
|
||||||
|
retention-days: 7
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
name: Triage
|
||||||
|
|
||||||
|
# One workflow for all event-driven PR/issue labeling. Replaces the former
|
||||||
|
# pr-labeler / pr-triage / issue-triage workflows (and drops actions/labeler).
|
||||||
|
#
|
||||||
|
# Design notes:
|
||||||
|
# * All jobs are pure-metadata: they read changed-file lists / PR fields / the
|
||||||
|
# review payload via the API and write labels. PR code is NEVER checked out
|
||||||
|
# or executed, so pull_request_target is safe here.
|
||||||
|
# * Each job only reconciles labels in namespaces IT owns
|
||||||
|
# (area:* / size/* / risk:* / needs-validation). It never touches labels
|
||||||
|
# applied by maintainers or other tools (bug, priority, etc.). first-time-
|
||||||
|
# contributor and reviewing are add-only.
|
||||||
|
# * State is read LIVE (listFiles + listLabelsOnIssue) at run time, not from
|
||||||
|
# the (stale) event payload, so rapid synchronize events converge instead
|
||||||
|
# of thrashing.
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
|
pull_request_review:
|
||||||
|
types: [submitted]
|
||||||
|
issues:
|
||||||
|
types: [opened]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
issues: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
# ── PR: area / size / risk / needs-validation / first-time ─────────────────
|
||||||
|
pr-labels:
|
||||||
|
if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
concurrency:
|
||||||
|
group: triage-pr-${{ github.event.pull_request.number }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
steps:
|
||||||
|
- name: Apply PR labels from live state
|
||||||
|
uses: actions/github-script@v8
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const pr = context.payload.pull_request;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
const num = pr.number;
|
||||||
|
|
||||||
|
// ---- live changed files ----
|
||||||
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
|
owner, repo, pull_number: num, per_page: 100,
|
||||||
|
});
|
||||||
|
const paths = files.map(f => f.filename);
|
||||||
|
const m = (re) => paths.some(p => re.test(p));
|
||||||
|
|
||||||
|
// ---- area: replaces .github/labeler.yml (path -> area) ----
|
||||||
|
const AREA_RULES = [
|
||||||
|
['area:frontend', [/^frontend\//]],
|
||||||
|
['area:backend', [/^backend\/app\//, /^backend\/packages\/harness\/deerflow\/(runtime|persistence|config|tools|guardrails|tracing|models|utils|uploads)\//]],
|
||||||
|
['area:agents', [/^backend\/packages\/harness\/deerflow\/(agents|subagents|reflection)\//, /(^|\/)langgraph\.json$/, /^backend\/.*\/prompts\//]],
|
||||||
|
['area:sandbox', [/^docker\//, /^backend\/packages\/harness\/deerflow\/sandbox\//, /(^|\/)Dockerfile$/]],
|
||||||
|
['area:skills', [/^skills\//, /^backend\/packages\/harness\/deerflow\/skills\//, /^frontend\/src\/core\/skills\//]],
|
||||||
|
['area:mcp', [/^backend\/packages\/harness\/deerflow\/mcp\//, /^frontend\/src\/core\/mcp\//]],
|
||||||
|
['area:ci', [/^\.github\//, /^scripts\//]],
|
||||||
|
['area:docs', [/^docs\//, /\.mdx?$/]],
|
||||||
|
['area:deps', [/(^|\/)(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/]],
|
||||||
|
];
|
||||||
|
const areaLabels = AREA_RULES
|
||||||
|
.filter(([, res]) => res.some(re => m(re)))
|
||||||
|
.map(([label]) => label);
|
||||||
|
|
||||||
|
// ---- size: additions+deletions, excluding lockfiles/snapshots ----
|
||||||
|
const EXCLUDE_SIZE = /(^|\/)(uv\.lock|pnpm-lock\.yaml|package-lock\.json)$|\.snap$/;
|
||||||
|
const churn = files
|
||||||
|
.filter(f => !EXCLUDE_SIZE.test(f.filename))
|
||||||
|
.reduce((s, f) => s + (f.additions || 0) + (f.deletions || 0), 0);
|
||||||
|
const sizeLabel =
|
||||||
|
churn < 20 ? 'size/XS' :
|
||||||
|
churn < 100 ? 'size/S' :
|
||||||
|
churn < 300 ? 'size/M' :
|
||||||
|
churn < 700 ? 'size/L' : 'size/XL';
|
||||||
|
|
||||||
|
// ---- risk ----
|
||||||
|
const docsOnly = paths.length > 0 && paths.every(p =>
|
||||||
|
/\.(md|mdx|txt)$/i.test(p) || p.startsWith('docs/') ||
|
||||||
|
/\.(png|jpe?g|gif|svg|webp|ico)$/i.test(p));
|
||||||
|
const highRisk =
|
||||||
|
m(/^backend\/app\/gateway\//) ||
|
||||||
|
m(/^backend\/packages\/harness\/deerflow\/(agents|subagents|sandbox)\//) ||
|
||||||
|
m(/(^|\/)langgraph\.json$/) ||
|
||||||
|
m(/(^|\/)(auth|authz|security)/i) ||
|
||||||
|
m(/(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/) ||
|
||||||
|
m(/^docker\//) ||
|
||||||
|
m(/^\.github\/workflows\//);
|
||||||
|
const riskLabel = docsOnly ? 'risk:low' : (highRisk ? 'risk:high' : 'risk:medium');
|
||||||
|
|
||||||
|
// ---- needs-validation: front/back contract surface ----
|
||||||
|
const contract =
|
||||||
|
m(/^backend\/app\/gateway\//) ||
|
||||||
|
m(/^backend\/packages\/harness\/deerflow\/(agents|subagents)\//) ||
|
||||||
|
m(/(^|\/)langgraph\.json$/) ||
|
||||||
|
m(/^frontend\/src\/core\/(api|threads|messages)\//);
|
||||||
|
|
||||||
|
// ---- live current labels (NOT the stale event payload) ----
|
||||||
|
const current = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
|
||||||
|
owner, repo, issue_number: num, per_page: 100,
|
||||||
|
})).map(l => l.name);
|
||||||
|
const hasSkip = current.includes('skip-validation');
|
||||||
|
|
||||||
|
// Reconcile ONLY namespaces we own; never touch others.
|
||||||
|
const owned = (n) =>
|
||||||
|
n.startsWith('area:') || n.startsWith('size/') ||
|
||||||
|
n.startsWith('risk:') || n === 'needs-validation';
|
||||||
|
const desired = new Set([...areaLabels, sizeLabel, riskLabel]);
|
||||||
|
if (contract && !hasSkip) desired.add('needs-validation');
|
||||||
|
|
||||||
|
const toRemove = current.filter(n => owned(n) && !desired.has(n));
|
||||||
|
const toAdd = [...desired].filter(n => !current.includes(n));
|
||||||
|
|
||||||
|
// first-time-contributor: add-only, on opened, real users only.
|
||||||
|
if (context.payload.action === 'opened' &&
|
||||||
|
pr.user.type === 'User' &&
|
||||||
|
['FIRST_TIME_CONTRIBUTOR', 'FIRST_TIMER'].includes(pr.author_association) &&
|
||||||
|
!current.includes('first-time-contributor')) {
|
||||||
|
toAdd.push('first-time-contributor');
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const name of toRemove) {
|
||||||
|
try {
|
||||||
|
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name });
|
||||||
|
} catch (e) {
|
||||||
|
if (e.status !== 404) throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (toAdd.length) {
|
||||||
|
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: toAdd });
|
||||||
|
}
|
||||||
|
core.info(`area=[${areaLabels.join(',')}] ${sizeLabel} ${riskLabel} churn=${churn} ` +
|
||||||
|
`validation=${desired.has('needs-validation')} ` +
|
||||||
|
`(+${toAdd.join(',') || '-'} / -${toRemove.join(',') || '-'})`);
|
||||||
|
|
||||||
|
# ── PR: reviewing label on a maintainer's human review ─────────────────────
|
||||||
|
reviewing:
|
||||||
|
if: github.event_name == 'pull_request_review'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
concurrency:
|
||||||
|
group: triage-review-${{ github.event.pull_request.number }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
steps:
|
||||||
|
- name: Add reviewing label for maintainer reviews
|
||||||
|
uses: actions/github-script@v8
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
const num = context.payload.pull_request.number;
|
||||||
|
const review = context.payload.review;
|
||||||
|
const assoc = review.author_association; // payload field; no API call
|
||||||
|
const type = review.user && review.user.type;
|
||||||
|
|
||||||
|
// author_association is NONE for every automated reviewer
|
||||||
|
// (Copilot, CodeRabbit, Codex, Sourcery, ...), so this allowlist
|
||||||
|
// drops them all without a denylist — and never calls the
|
||||||
|
// collaborators API that 404s on "Copilot is not a user".
|
||||||
|
// user.type === 'User' guards the rare bot-added-as-collaborator case.
|
||||||
|
if (!['OWNER', 'MEMBER', 'COLLABORATOR'].includes(assoc) || type !== 'User') {
|
||||||
|
core.info(`reviewer ${review.user && review.user.login} assoc=${assoc} type=${type}; skipping.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const labels = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
|
||||||
|
owner, repo, issue_number: num, per_page: 100,
|
||||||
|
})).map(l => l.name);
|
||||||
|
if (labels.includes('reviewing')) {
|
||||||
|
core.info('Already labeled reviewing; skipping.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
await github.rest.issues.addLabels({
|
||||||
|
owner, repo, issue_number: num, labels: ['reviewing'],
|
||||||
|
});
|
||||||
|
core.info('Added "reviewing".');
|
||||||
|
} catch (e) {
|
||||||
|
if (e.status === 403) core.info('No permission to label (expected on some fork PRs).');
|
||||||
|
else throw e;
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Issue: needs-triage on every new issue ────────────────────────────────
|
||||||
|
issue-triage:
|
||||||
|
if: github.event_name == 'issues'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
concurrency:
|
||||||
|
group: triage-issue-${{ github.event.issue.number }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
steps:
|
||||||
|
- name: Add needs-triage label
|
||||||
|
uses: actions/github-script@v8
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
const issue_number = context.payload.issue.number;
|
||||||
|
|
||||||
|
// Read live labels (not the event payload) so labels added at creation
|
||||||
|
// time via the API or by another automation are seen — consistent with
|
||||||
|
// the live-state reads in the PR jobs above.
|
||||||
|
const current = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
|
||||||
|
owner, repo, issue_number, per_page: 100,
|
||||||
|
})).map(l => l.name);
|
||||||
|
if (current.includes('needs-triage')) {
|
||||||
|
core.info('Issue already has needs-triage; nothing to do.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Self-heal: create the label if it does not exist yet.
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createLabel({
|
||||||
|
owner, repo, name: 'needs-triage', color: 'fef2c0',
|
||||||
|
description: 'Awaiting maintainer triage',
|
||||||
|
});
|
||||||
|
} catch (e) {
|
||||||
|
if (e.status !== 422) throw e; // 422 = already exists
|
||||||
|
}
|
||||||
|
await github.rest.issues.addLabels({
|
||||||
|
owner, repo, issue_number, labels: ['needs-triage'],
|
||||||
|
});
|
||||||
|
core.info(`Added needs-triage to #${issue_number}.`);
|
||||||
@@ -287,6 +287,21 @@ Nginx (port 2026) ← Unified entry point
|
|||||||
git push origin feature/your-feature-name
|
git push origin feature/your-feature-name
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## AI assistance disclosure
|
||||||
|
|
||||||
|
DeerFlow is an AI project and we welcome AI-assisted contributions. To help
|
||||||
|
reviewers calibrate how closely to read a change, **every pull request must
|
||||||
|
complete the "AI assistance" section of the
|
||||||
|
[PR template](.github/pull_request_template.md)**:
|
||||||
|
|
||||||
|
- which tool(s) you used (or `none`),
|
||||||
|
- how you used them, and
|
||||||
|
- a confirmation that a human has read, understands, and takes responsibility
|
||||||
|
for the change.
|
||||||
|
|
||||||
|
Please don't delete the section. PRs that ignore it may be asked to fill it in
|
||||||
|
before review.
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -89,36 +89,7 @@ install:
|
|||||||
|
|
||||||
# Pre-pull sandbox Docker image (optional but recommended)
|
# Pre-pull sandbox Docker image (optional but recommended)
|
||||||
setup-sandbox:
|
setup-sandbox:
|
||||||
@echo "=========================================="
|
@$(RUN_WITH_GIT_BASH) ./scripts/setup-sandbox.sh
|
||||||
@echo " Pre-pulling Sandbox Container Image"
|
|
||||||
@echo "=========================================="
|
|
||||||
@echo ""
|
|
||||||
@IMAGE=$$(grep -A 20 "# sandbox:" config.yaml 2>/dev/null | grep "image:" | awk '{print $$2}' | head -1); \
|
|
||||||
if [ -z "$$IMAGE" ]; then \
|
|
||||||
IMAGE="enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"; \
|
|
||||||
echo "Using default image: $$IMAGE"; \
|
|
||||||
else \
|
|
||||||
echo "Using configured image: $$IMAGE"; \
|
|
||||||
fi; \
|
|
||||||
echo ""; \
|
|
||||||
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
|
|
||||||
echo "Detected Apple Container on macOS, pulling image..."; \
|
|
||||||
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
|
|
||||||
fi; \
|
|
||||||
if command -v docker >/dev/null 2>&1; then \
|
|
||||||
echo "Pulling image using Docker..."; \
|
|
||||||
if docker pull "$$IMAGE"; then \
|
|
||||||
echo ""; \
|
|
||||||
echo "✓ Sandbox image pulled successfully"; \
|
|
||||||
else \
|
|
||||||
echo ""; \
|
|
||||||
echo "⚠ Failed to pull sandbox image (this is OK for local sandbox mode)"; \
|
|
||||||
fi; \
|
|
||||||
else \
|
|
||||||
echo "✗ Neither Docker nor Apple Container is available"; \
|
|
||||||
echo " Please install Docker: https://docs.docker.com/get-docker/"; \
|
|
||||||
exit 1; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Start all services in development mode (with hot-reloading)
|
# Start all services in development mode (with hot-reloading)
|
||||||
dev:
|
dev:
|
||||||
@@ -148,7 +119,6 @@ stop:
|
|||||||
clean: stop
|
clean: stop
|
||||||
@echo "Cleaning up..."
|
@echo "Cleaning up..."
|
||||||
@-rm -rf backend/.deer-flow 2>/dev/null || true
|
@-rm -rf backend/.deer-flow 2>/dev/null || true
|
||||||
@-rm -rf backend/.langgraph_api 2>/dev/null || true
|
|
||||||
@-rm -rf logs/*.log 2>/dev/null || true
|
@-rm -rf logs/*.log 2>/dev/null || true
|
||||||
@echo "✓ Cleanup complete"
|
@echo "✓ Cleanup complete"
|
||||||
|
|
||||||
|
|||||||
@@ -247,6 +247,9 @@ Access: http://localhost:2026
|
|||||||
|
|
||||||
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
|
||||||
|
|
||||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||||
|
|
||||||
#### Option 2: Local Development
|
#### Option 2: Local Development
|
||||||
@@ -340,6 +343,8 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
|
|||||||
|
|
||||||
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
|
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
|
||||||
|
|
||||||
|
DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
|
||||||
|
|
||||||
| Channel | Transport | Difficulty |
|
| Channel | Transport | Difficulty |
|
||||||
|---------|-----------|------------|
|
|---------|-----------|------------|
|
||||||
| Telegram | Bot API (long-polling) | Easy |
|
| Telegram | Bot API (long-polling) | Easy |
|
||||||
@@ -585,6 +590,8 @@ A standard Agent Skill is a structured capability module — a Markdown file tha
|
|||||||
|
|
||||||
Skills are loaded progressively — only when the task needs them, not all at once. This keeps the context window lean and makes DeerFlow work well even with token-sensitive models.
|
Skills are loaded progressively — only when the task needs them, not all at once. This keeps the context window lean and makes DeerFlow work well even with token-sensitive models.
|
||||||
|
|
||||||
|
Users can explicitly activate an enabled skill for a single turn by starting the request with `/skill-name`, for example `/data-analysis analyze uploads/foo.csv`. DeerFlow loads that skill's `SKILL.md` as hidden current-turn context while leaving the base prompt limited to skill metadata. Slash activation respects disabled skills, custom-agent skill whitelists, and existing channel commands such as `/new` and `/help`.
|
||||||
|
|
||||||
When you install `.skill` archives through the Gateway, DeerFlow accepts standard optional frontmatter metadata such as `version`, `author`, and `compatibility` instead of rejecting otherwise valid external skills.
|
When you install `.skill` archives through the Gateway, DeerFlow accepts standard optional frontmatter metadata such as `version`, `author`, and `compatibility` instead of rejecting otherwise valid external skills.
|
||||||
|
|
||||||
Tools follow the same philosophy. DeerFlow comes with a core toolset — web search, web fetch, file operations, bash execution — and supports custom tools via MCP servers and Python functions. Swap anything. Add anything.
|
Tools follow the same philosophy. DeerFlow comes with a core toolset — web search, web fetch, file operations, bash execution — and supports custom tools via MCP servers and Python functions. Swap anything. Add anything.
|
||||||
|
|||||||
+68
@@ -10,3 +10,71 @@ Currently, we have two branches to maintain:
|
|||||||
## Reporting a Vulnerability
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
Please go to https://github.com/bytedance/deer-flow/security to report the vulnerability you find.
|
Please go to https://github.com/bytedance/deer-flow/security to report the vulnerability you find.
|
||||||
|
|
||||||
|
## Sandbox Isolation and the Docker Socket (DooD)
|
||||||
|
|
||||||
|
DeerFlow executes agent-generated shell/code through a configurable sandbox
|
||||||
|
(`sandbox.use` in `config.yaml`). The isolation guarantees differ by mode, and
|
||||||
|
one mode requires mounting the host Docker socket. Understand the trade-offs
|
||||||
|
before exposing an instance to untrusted input.
|
||||||
|
|
||||||
|
| Mode | `config.yaml` | Host Docker socket | Isolation |
|
||||||
|
|------|---------------|--------------------|-----------|
|
||||||
|
| `local` (default) | `deerflow.sandbox.local:LocalSandboxProvider` | Not mounted | Commands run **inside the gateway container** on its filesystem. Not a strong boundary — `allow_host_bash` is `false` by default and should stay off for untrusted workloads. |
|
||||||
|
| `aio` (pure DooD) | `deerflow.community.aio_sandbox:AioSandboxProvider` (no `provisioner_url`) | **Mounted** (opt-in overlay) | Sandbox containers are started via the host Docker daemon. |
|
||||||
|
| `provisioner` (Kubernetes) | `AioSandboxProvider` + `provisioner_url` | Not mounted | Sandbox pods are created through the provisioner's K8s API over HTTP. Strongest isolation. |
|
||||||
|
|
||||||
|
### The Docker socket is host root
|
||||||
|
|
||||||
|
Mounting `/var/run/docker.sock` into a container grants that container
|
||||||
|
**root-equivalent control of the host**: anything able to reach the socket can
|
||||||
|
start a new container that bind-mounts the host filesystem and escape. This
|
||||||
|
matters for DeerFlow because the gateway executes model-generated commands, so a
|
||||||
|
prompt injection or any in-container code-execution primitive could pivot to the
|
||||||
|
host through the socket.
|
||||||
|
|
||||||
|
To keep this off the default attack surface:
|
||||||
|
|
||||||
|
- The host Docker socket is **not** mounted by the default Compose stack. It is
|
||||||
|
added only for `aio` mode through the opt-in `docker/docker-compose.dood.yaml`
|
||||||
|
overlay, which `scripts/deploy.sh` and `scripts/docker.sh` append
|
||||||
|
automatically when `detect_sandbox_mode()` returns `aio`.
|
||||||
|
- Prefer **provisioner/Kubernetes mode** for multi-tenant or internet-exposed
|
||||||
|
deployments — it isolates sandboxes without handing the gateway the host
|
||||||
|
daemon.
|
||||||
|
- If you must use `aio`/DooD, treat the host as part of the gateway's trust
|
||||||
|
boundary: run it on a dedicated host, and consider a scoped Docker API proxy
|
||||||
|
instead of the raw socket.
|
||||||
|
|
||||||
|
> Note: the gateway bind-mounts `$HOME/.claude` and `$HOME/.codex` (read-only)
|
||||||
|
> for CLI auto-auth in **all** modes. These hold long-lived CLI credentials;
|
||||||
|
> scope or omit them when the gateway runs untrusted workloads.
|
||||||
|
|
||||||
|
## CLI Credential Mounts (Claude Code / Codex)
|
||||||
|
|
||||||
|
DeerFlow can reuse your Claude Code / Codex CLI subscription login as a model
|
||||||
|
provider (`ClaudeChatModel`, the Codex provider) or for ACP agents that run the
|
||||||
|
CLI in-container. The Compose stack used to bind-mount the **entire** `~/.claude`
|
||||||
|
and `~/.codex` directories (read-only) into the gateway container in **every**
|
||||||
|
configuration — exposing not just credentials but full conversation history,
|
||||||
|
per-project session data, and global CLI config. A gateway compromise (prompt
|
||||||
|
injection, tool/MCP misuse, RCE) would leak all of it.
|
||||||
|
|
||||||
|
These directories are **no longer mounted by default**. Supply CLI credentials
|
||||||
|
with the least exposure that fits your setup:
|
||||||
|
|
||||||
|
| Need | How | Exposure |
|
||||||
|
|------|-----|----------|
|
||||||
|
| Claude model provider | env `CLAUDE_CODE_OAUTH_TOKEN` / `ANTHROPIC_AUTH_TOKEN` (via `.env`), or `CLAUDE_CODE_CREDENTIALS_PATH` → a single mounted `.credentials.json` | none / one file |
|
||||||
|
| Codex model provider | env `CODEX_AUTH_PATH` pointing at a single mounted `auth.json` | one file |
|
||||||
|
| ACP agent | the adapter's own auth — many ACP adapters take an env API key (e.g. `ANTHROPIC_API_KEY` / `OPENAI_API_KEY`) and need no mount; use the opt-in `docker/docker-compose.cli-auth.yaml` overlay only if your adapter reads the full CLI config dir | none / full dir |
|
||||||
|
|
||||||
|
The Gateway credential loader checks environment variables **before** the
|
||||||
|
default credential files, so the env-token paths need no bind mount at all. ACP
|
||||||
|
adapters authenticate independently of DeerFlow via their own documented env —
|
||||||
|
for example the common `claude-code-acp` adapter starts as
|
||||||
|
`ANTHROPIC_API_KEY=… claude-code-acp` and honors `CLAUDE_CONFIG_DIR` to redirect
|
||||||
|
its config directory, so it needs no `~/.claude` mount at all. Prefer the
|
||||||
|
adapter's documented env auth, and reach for the
|
||||||
|
`docker-compose.cli-auth.yaml` overlay only as a fallback for an adapter that
|
||||||
|
genuinely reads the full CLI config directory.
|
||||||
|
|||||||
@@ -24,5 +24,10 @@ config.yaml
|
|||||||
# Langgraph
|
# Langgraph
|
||||||
.langgraph_api
|
.langgraph_api
|
||||||
|
|
||||||
|
# Sandbox runtime working dir — pre-created and excluded from uvicorn reload
|
||||||
|
# (scripts/serve.sh, docker/dev-entrypoint.sh). Anchored so it does not match
|
||||||
|
# the source package backend/packages/harness/deerflow/sandbox/.
|
||||||
|
/sandbox/
|
||||||
|
|
||||||
# Claude Code settings
|
# Claude Code settings
|
||||||
.claude/settings.local.json
|
.claude/settings.local.json
|
||||||
|
|||||||
+60
-35
@@ -112,6 +112,14 @@ calls are resolved by function name, so duplicate helper names in one file can
|
|||||||
conservatively over-report async reachability. It is intentionally
|
conservatively over-report async reachability. It is intentionally
|
||||||
informational and is not run from CI in this round.
|
informational and is not run from CI in this round.
|
||||||
|
|
||||||
|
For a diff-scoped view of the same findings, `scripts/scan_changed_blocking_io.py`
|
||||||
|
(repo root) reports findings on the added lines of `git diff <base>...HEAD`
|
||||||
|
plus findings new versus the merge base (so a new async caller exposing an
|
||||||
|
untouched sync helper in the same file is still reported) — used by the
|
||||||
|
`blocking-io-guard` skill (`.agent/skills/blocking-io-guard/`) as the
|
||||||
|
deterministic scope step before routing each candidate to a fix and/or a
|
||||||
|
`tests/blocking_io/` runtime anchor.
|
||||||
|
|
||||||
Regression tests related to Docker/provisioner behavior:
|
Regression tests related to Docker/provisioner behavior:
|
||||||
- `tests/test_docker_sandbox_mode_detection.py` (mode detection from `config.yaml`)
|
- `tests/test_docker_sandbox_mode_detection.py` (mode detection from `config.yaml`)
|
||||||
- `tests/test_provisioner_kubeconfig.py` (kubeconfig file/directory handling)
|
- `tests/test_provisioner_kubeconfig.py` (kubeconfig file/directory handling)
|
||||||
@@ -192,7 +200,7 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
### Middleware Chain
|
### Middleware Chain
|
||||||
|
|
||||||
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
|
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`build_middlewares`):
|
||||||
|
|
||||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
||||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||||
@@ -202,16 +210,17 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
|
|||||||
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
||||||
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
|
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
|
||||||
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
||||||
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
9. **SkillActivationMiddleware** - Detects strict `/skill-name task` syntax on the latest real user message, resolves only enabled and runtime-allowed skills, reads `SKILL.md` from trusted skill storage, injects the skill body as hidden current-turn model context, and records a `middleware:skill_activation` audit event with skill name, category, path, and content hash
|
||||||
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
10. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||||
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
|
11. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||||
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
12. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
|
||||||
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
13. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
||||||
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
14. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||||
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
|
15. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||||
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
|
16. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`)
|
||||||
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
|
17. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
|
||||||
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
18. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
|
||||||
|
19. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
||||||
|
|
||||||
### Configuration System
|
### Configuration System
|
||||||
|
|
||||||
@@ -223,17 +232,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
|
|||||||
|
|
||||||
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
||||||
|
|
||||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
|
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`.
|
||||||
|
|
||||||
| Field | Why a restart is required |
|
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`, `channel_connections`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
|
||||||
|---|---|
|
|
||||||
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
|
|
||||||
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
|
|
||||||
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
|
|
||||||
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
|
|
||||||
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
|
|
||||||
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
|
|
||||||
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
|
|
||||||
|
|
||||||
Configuration priority:
|
Configuration priority:
|
||||||
1. Explicit `config_path` argument
|
1. Explicit `config_path` argument
|
||||||
@@ -271,7 +272,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
|
|||||||
| **Uploads** (`/api/threads/{id}/uploads`) | `POST /` - upload files (auto-converts PDF/PPT/Excel/Word); `GET /list` - list; `DELETE /{filename}` - delete |
|
| **Uploads** (`/api/threads/{id}/uploads`) | `POST /` - upload files (auto-converts PDF/PPT/Excel/Word); `GET /list` - list; `DELETE /{filename}` - delete |
|
||||||
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||||
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
||||||
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized and inline reasoning (`<think>...</think>`, including unclosed/truncated blocks from reasoning models like MiniMax-M3) is stripped before JSON parsing |
|
||||||
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
||||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||||
@@ -291,7 +292,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
||||||
**Implementations**:
|
**Implementations**:
|
||||||
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
||||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation. Active-cache and warm-pool entries are checked with the backend during acquire/reuse; definitively dead containers are dropped from all in-process maps so the thread can discover or create a fresh sandbox instead of reusing a stale client. Backend health-check failures are treated as unknown, not dead; local discovery likewise treats an unverifiable container as not adoptable and falls through to create rather than failing acquire. `get()` remains an in-memory lookup for event-loop-safe tool paths.
|
||||||
|
|
||||||
**Virtual Path System**:
|
**Virtual Path System**:
|
||||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||||
@@ -313,6 +314,8 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` enforced by `SubagentLimitMiddleware` (truncates excess tool calls in `after_model`), 15-minute timeout
|
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` enforced by `SubagentLimitMiddleware` (truncates excess tool calls in `after_model`), 15-minute timeout
|
||||||
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
|
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
|
||||||
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
|
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
|
||||||
|
**Deferred MCP tools** (if `tool_search.enabled`): `SubagentExecutor._build_initial_state` assembles deferral after policy filtering via the shared `assemble_deferred_tools` (fail-closed), appends the `tool_search` tool, injects the `<available-deferred-tools>` section into the subagent's `SystemMessage`, and threads the setup to `_create_agent`, which attaches `DeferredToolFilterMiddleware` through `build_subagent_runtime_middlewares(deferred_setup=...)`. Subagents thus withhold full MCP schemas until promotion, same as the lead agent; each task run gets a fresh `ThreadState` so promotion is isolated per run
|
||||||
|
**Checkpointer isolation**: Subagent graphs are compiled with `checkpointer=False` to avoid inheriting the parent run's checkpointer, since subagents are one-shot and never resume.
|
||||||
|
|
||||||
### Tool System (`packages/harness/deerflow/tools/`)
|
### Tool System (`packages/harness/deerflow/tools/`)
|
||||||
|
|
||||||
@@ -355,6 +358,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
- **Format**: Directory with `SKILL.md` (YAML frontmatter: name, description, license, allowed-tools)
|
- **Format**: Directory with `SKILL.md` (YAML frontmatter: name, description, license, allowed-tools)
|
||||||
- **Loading**: `load_skills()` recursively scans `skills/{public,custom}` for `SKILL.md`, parses metadata, and reads enabled state from extensions_config.json
|
- **Loading**: `load_skills()` recursively scans `skills/{public,custom}` for `SKILL.md`, parses metadata, and reads enabled state from extensions_config.json
|
||||||
- **Injection**: Enabled skills listed in agent system prompt with container paths
|
- **Injection**: Enabled skills listed in agent system prompt with container paths
|
||||||
|
- **Slash activation**: `/skill-name task` loads that enabled skill's `SKILL.md` for the current model call only. The resolver rejects leading whitespace, missing separators, reserved channel commands (`/new`, `/help`, `/bootstrap`, `/status`, `/models`, `/memory`), disabled skills, and skills outside a custom agent's whitelist.
|
||||||
- **Installation**: `POST /api/skills/install` extracts .skill ZIP archive to custom/ directory
|
- **Installation**: `POST /api/skills/install` extracts .skill ZIP archive to custom/ directory
|
||||||
|
|
||||||
### Model Factory (`packages/harness/deerflow/models/factory.py`)
|
### Model Factory (`packages/harness/deerflow/models/factory.py`)
|
||||||
@@ -374,29 +378,32 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
|
|
||||||
### IM Channels System (`app/channels/`)
|
### IM Channels System (`app/channels/`)
|
||||||
|
|
||||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||||
|
|
||||||
|
|
||||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
||||||
|
|
||||||
**Components**:
|
**Components**:
|
||||||
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
||||||
- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` → `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations)
|
- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` → `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations)
|
||||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Discord on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu/Telegram incremental outbound updates
|
||||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||||
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `telegram.py` registers the "Working on it..." placeholder as the stream target and edits it in place via `editMessageText`; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||||
|
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
|
||||||
|
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
|
||||||
|
|
||||||
**Message Flow**:
|
**Message Flow**:
|
||||||
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
||||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||||
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
|
||||||
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
4. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||||
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
5. Feishu/Telegram chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||||
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
6. Slack/Discord chat: `runs.wait()` → extract final response → publish outbound
|
||||||
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||||
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
8. Telegram streaming: the "Working on it..." placeholder message is registered as the stream target; non-final updates `editMessageText` it in place (channel-side throttle: 1s in private chats, 3s in groups due to Telegram's 20 msg/min group cap; 4096-char truncation; rate-limited updates dropped); the final update performs the last edit and splits >4096 texts into follow-up messages
|
||||||
9. Outbound → channel callbacks → platform reply
|
9. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||||
|
10. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||||
|
11. Outbound → channel callbacks → platform reply
|
||||||
|
|
||||||
**Configuration** (`config.yaml` -> `channels`):
|
**Configuration** (`config.yaml` -> `channels`):
|
||||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
||||||
@@ -404,6 +411,17 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
||||||
|
|
||||||
|
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
|
||||||
|
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
|
||||||
|
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
|
||||||
|
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers.
|
||||||
|
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
|
||||||
|
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
|
||||||
|
- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user.
|
||||||
|
- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable.
|
||||||
|
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
|
||||||
|
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
|
||||||
|
|
||||||
|
|
||||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||||
|
|
||||||
@@ -434,6 +452,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||||
|
|
||||||
|
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
|
||||||
|
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
|
||||||
|
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
|
||||||
|
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
|
||||||
|
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
|
||||||
|
|
||||||
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
||||||
|
|
||||||
**Configuration** (`config.yaml` → `memory`):
|
**Configuration** (`config.yaml` → `memory`):
|
||||||
@@ -443,6 +467,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
- `model_name` - LLM for updates (null = default model)
|
- `model_name` - LLM for updates (null = default model)
|
||||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||||
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
||||||
|
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
|
||||||
|
|
||||||
### Reflection System (`packages/harness/deerflow/reflection/`)
|
### Reflection System (`packages/harness/deerflow/reflection/`)
|
||||||
|
|
||||||
@@ -500,7 +525,7 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
|
|||||||
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
|
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
|
||||||
- `"custom"` — forwarded from `StreamWriter`
|
- `"custom"` — forwarded from `StreamWriter`
|
||||||
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
|
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
|
||||||
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
|
- Agent created lazily via `create_agent()` + `build_middlewares()`, same as `make_lead_agent`
|
||||||
- Supports `checkpointer` parameter for state persistence across turns
|
- Supports `checkpointer` parameter for state persistence across turns
|
||||||
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
||||||
- See [docs/STREAMING.md](docs/STREAMING.md) for the full design: why Gateway and DeerFlowClient are parallel paths, LangGraph's `stream_mode` semantics, the per-id dedup invariants, and regression testing strategy
|
- See [docs/STREAMING.md](docs/STREAMING.md) for the full design: why Gateway and DeerFlowClient are parallel paths, LangGraph's `stream_mode` semantics, the per-id dedup invariants, and regression testing strategy
|
||||||
|
|||||||
+3
-3
@@ -64,7 +64,7 @@ FROM builder AS dev
|
|||||||
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
|
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
|
||||||
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
|
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
|
||||||
|
|
||||||
EXPOSE 8001 2024
|
EXPOSE 8001
|
||||||
|
|
||||||
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
||||||
|
|
||||||
@@ -94,8 +94,8 @@ WORKDIR /app
|
|||||||
# Copy backend with pre-built virtualenv from builder
|
# Copy backend with pre-built virtualenv from builder
|
||||||
COPY --from=builder /app/backend ./backend
|
COPY --from=builder /app/backend ./backend
|
||||||
|
|
||||||
# Expose ports (gateway: 8001, langgraph: 2024)
|
# Expose Gateway API port.
|
||||||
EXPOSE 8001 2024
|
EXPOSE 8001
|
||||||
|
|
||||||
# Default command (can be overridden in docker-compose)
|
# Default command (can be overridden in docker-compose)
|
||||||
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run --no-sync uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run --no-sync uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
|
||||||
|
|||||||
+1
-1
@@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern:
|
|||||||
Per-thread isolated execution with virtual path translation:
|
Per-thread isolated execution with virtual path translation:
|
||||||
|
|
||||||
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
|
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. `AioSandboxProvider` validates active-cache and warm-pool containers during acquire/reuse, dropping definitively dead entries so a thread can provision a fresh sandbox after an unexpected container exit while keeping `get()` as an in-memory lookup. Backend health-check failures are treated as unknown, not dead, and a container that cannot be verified during discovery is simply not adopted (acquire falls through to create instead of failing).
|
||||||
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
||||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||||
|
|||||||
@@ -18,3 +18,21 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
|||||||
"/help",
|
"/help",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_connect_code(text: str) -> str | None:
|
||||||
|
"""Extract the one-time channel binding code from a connect command."""
|
||||||
|
parts = text.strip().split()
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
command = parts[0].lower()
|
||||||
|
if command in {"/connect", "connect"}:
|
||||||
|
return parts[1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_known_channel_command(text: str) -> bool:
|
||||||
|
"""Return whether text starts with a registered channel control command."""
|
||||||
|
if not text.startswith("/"):
|
||||||
|
return False
|
||||||
|
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.channels.message_bus import InboundMessage
|
||||||
|
|
||||||
|
|
||||||
|
async def attach_connection_identity(
|
||||||
|
inbound: InboundMessage,
|
||||||
|
*,
|
||||||
|
repo: Any,
|
||||||
|
provider: str,
|
||||||
|
workspace_id: str | None,
|
||||||
|
fallback_without_workspace: bool = False,
|
||||||
|
) -> InboundMessage:
|
||||||
|
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
||||||
|
if repo is None:
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
workspace_candidates: list[str | None] = []
|
||||||
|
if workspace_id:
|
||||||
|
workspace_candidates.append(workspace_id)
|
||||||
|
if fallback_without_workspace:
|
||||||
|
workspace_candidates.append(None)
|
||||||
|
if not workspace_candidates:
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
for candidate in workspace_candidates:
|
||||||
|
connection = await repo.find_connection_by_external_identity(
|
||||||
|
provider=provider,
|
||||||
|
external_account_id=inbound.user_id,
|
||||||
|
workspace_id=candidate,
|
||||||
|
)
|
||||||
|
if connection is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
inbound.connection_id = connection["id"]
|
||||||
|
inbound.owner_user_id = connection["owner_user_id"]
|
||||||
|
inbound.workspace_id = connection.get("workspace_id")
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
return inbound
|
||||||
@@ -14,7 +14,8 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -59,9 +60,7 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _is_dingtalk_command(text: str) -> bool:
|
def _is_dingtalk_command(text: str) -> bool:
|
||||||
if not text.startswith("/"):
|
return is_known_channel_command(text)
|
||||||
return False
|
|
||||||
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_rich_text(rich_text_list: list) -> str:
|
def _extract_text_from_rich_text(rich_text_list: list) -> str:
|
||||||
@@ -138,6 +137,7 @@ class DingTalkChannel(Channel):
|
|||||||
self._incoming_messages: dict[str, Any] = {}
|
self._incoming_messages: dict[str, Any] = {}
|
||||||
self._incoming_messages_lock = threading.Lock()
|
self._incoming_messages_lock = threading.Lock()
|
||||||
self._card_repliers: dict[str, Any] = {}
|
self._card_repliers: dict[str, Any] = {}
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
@@ -397,6 +397,24 @@ class DingTalkChannel(Channel):
|
|||||||
text[:100],
|
text[:100],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
if self._main_loop and self._main_loop.is_running():
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._bind_connection_from_connect_code(
|
||||||
|
conversation_type=conversation_type,
|
||||||
|
sender_staff_id=sender_staff_id,
|
||||||
|
sender_nick=sender_nick,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
code=connect_code,
|
||||||
|
),
|
||||||
|
self._main_loop,
|
||||||
|
)
|
||||||
|
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||||
|
else:
|
||||||
|
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
||||||
|
return
|
||||||
|
|
||||||
if _is_dingtalk_command(text):
|
if _is_dingtalk_command(text):
|
||||||
msg_type = InboundMessageType.COMMAND
|
msg_type = InboundMessageType.COMMAND
|
||||||
else:
|
else:
|
||||||
@@ -452,11 +470,95 @@ class DingTalkChannel(Channel):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
# Running reply must finish before publish_inbound so AI card tracks are
|
# Running reply must finish before publish_inbound so AI card tracks are
|
||||||
# registered before the manager emits streaming outbounds.
|
# registered before the manager emits streaming outbounds.
|
||||||
await self._send_running_reply(chat_id, inbound)
|
await self._send_running_reply(chat_id, inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
|
||||||
|
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
|
||||||
|
return conversation_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
|
||||||
|
conversation_id = str(inbound.metadata.get("conversation_id") or "")
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="dingtalk",
|
||||||
|
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||||
|
fallback_without_workspace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
conversation_type: str,
|
||||||
|
sender_staff_id: str,
|
||||||
|
sender_nick: str,
|
||||||
|
conversation_id: str,
|
||||||
|
code: str,
|
||||||
|
) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(
|
||||||
|
conversation_type,
|
||||||
|
sender_staff_id,
|
||||||
|
conversation_id,
|
||||||
|
"DingTalk connection code is invalid or expired.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not sender_staff_id:
|
||||||
|
await self._send_connection_reply(
|
||||||
|
conversation_type,
|
||||||
|
sender_staff_id,
|
||||||
|
conversation_id,
|
||||||
|
"DingTalk connection could not be completed from this message.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="dingtalk",
|
||||||
|
external_account_id=sender_staff_id,
|
||||||
|
external_account_name=sender_nick or None,
|
||||||
|
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||||
|
metadata={
|
||||||
|
"conversation_type": conversation_type,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(
|
||||||
|
conversation_type,
|
||||||
|
sender_staff_id,
|
||||||
|
conversation_id,
|
||||||
|
"DingTalk connected to DeerFlow.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_connection_reply(
|
||||||
|
self,
|
||||||
|
conversation_type: str,
|
||||||
|
sender_staff_id: str,
|
||||||
|
conversation_id: str,
|
||||||
|
text: str,
|
||||||
|
) -> None:
|
||||||
|
robot_code = self._client_id
|
||||||
|
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
||||||
|
if conversation_id:
|
||||||
|
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
||||||
|
return
|
||||||
|
if sender_staff_id:
|
||||||
|
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
||||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -69,6 +71,7 @@ class DiscordChannel(Channel):
|
|||||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._discord_module = None
|
self._discord_module = None
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -202,10 +205,14 @@ class DiscordChannel(Channel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
|
# Keep the file handle open only for the duration of the upload: discord.py
|
||||||
file = self._discord_module.File(fp, filename=attachment.filename)
|
# reads ``fp`` while ``target.send`` runs on ``_discord_loop``; once that
|
||||||
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
|
# future resolves the bytes are consumed, so closing here is safe and avoids
|
||||||
await asyncio.wrap_future(send_future)
|
# leaking the handle on both the success and failure paths.
|
||||||
|
with open(str(attachment.actual_path), "rb") as fp:
|
||||||
|
file = self._discord_module.File(fp, filename=attachment.filename)
|
||||||
|
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
|
||||||
|
await asyncio.wrap_future(send_future)
|
||||||
logger.info("[Discord] file uploaded: %s", attachment.filename)
|
logger.info("[Discord] file uploaded: %s", attachment.filename)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -286,6 +293,10 @@ class DiscordChannel(Channel):
|
|||||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||||
|
return
|
||||||
|
|
||||||
# --- Determine thread/channel routing and typing target ---
|
# --- Determine thread/channel routing and typing target ---
|
||||||
thread_id = None
|
thread_id = None
|
||||||
chat_id = None
|
chat_id = None
|
||||||
@@ -300,7 +311,7 @@ class DiscordChannel(Channel):
|
|||||||
|
|
||||||
# If this is a known active thread, process normally
|
# If this is a known active thread, process normally
|
||||||
if thread_id in self._active_thread_ids:
|
if thread_id in self._active_thread_ids:
|
||||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=str(message.author.id),
|
user_id=str(message.author.id),
|
||||||
@@ -314,6 +325,7 @@ class DiscordChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||||
self._publish(inbound)
|
self._publish(inbound)
|
||||||
# Start typing indicator in the thread
|
# Start typing indicator in the thread
|
||||||
if typing_target:
|
if typing_target:
|
||||||
@@ -407,7 +419,7 @@ class DiscordChannel(Channel):
|
|||||||
chat_id = channel_id
|
chat_id = channel_id
|
||||||
typing_target = message.channel # Type into the channel
|
typing_target = message.channel # Type into the channel
|
||||||
|
|
||||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=str(message.author.id),
|
user_id=str(message.author.id),
|
||||||
@@ -421,6 +433,7 @@ class DiscordChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||||
|
|
||||||
# Start typing indicator in the correct target (thread or channel)
|
# Start typing indicator in the correct target (thread or channel)
|
||||||
if typing_target:
|
if typing_target:
|
||||||
@@ -435,6 +448,60 @@ class DiscordChannel(Channel):
|
|||||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="discord",
|
||||||
|
workspace_id=guild_id,
|
||||||
|
fallback_without_workspace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
guild = getattr(message, "guild", None)
|
||||||
|
channel = getattr(message, "channel", None)
|
||||||
|
author = getattr(message, "author", None)
|
||||||
|
user_id = str(getattr(author, "id", "") or "")
|
||||||
|
if not user_id:
|
||||||
|
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
guild_id = str(getattr(guild, "id", "") or "") or None
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="discord",
|
||||||
|
external_account_id=user_id,
|
||||||
|
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
|
||||||
|
workspace_id=guild_id,
|
||||||
|
workspace_name=getattr(guild, "name", None) if guild is not None else None,
|
||||||
|
metadata={
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"channel_id": str(getattr(channel, "id", "") or ""),
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_connection_reply(message, text: str) -> None:
|
||||||
|
channel = getattr(message, "channel", None)
|
||||||
|
send = getattr(channel, "send", None)
|
||||||
|
if send is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await send(text)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Discord] failed to send connection reply")
|
||||||
|
|
||||||
def _run_client(self) -> None:
|
def _run_client(self) -> None:
|
||||||
self._discord_loop = asyncio.new_event_loop()
|
self._discord_loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(self._discord_loop)
|
asyncio.set_event_loop(self._discord_loop)
|
||||||
|
|||||||
+267
-14
@@ -7,22 +7,31 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
|
from app.channels.message_bus import (
|
||||||
|
PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
|
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
|
InboundMessage,
|
||||||
|
InboundMessageType,
|
||||||
|
MessageBus,
|
||||||
|
OutboundMessage,
|
||||||
|
ResolvedAttachment,
|
||||||
|
)
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60
|
||||||
|
|
||||||
|
|
||||||
def _is_feishu_command(text: str) -> bool:
|
def _is_feishu_command(text: str) -> bool:
|
||||||
if not text.startswith("/"):
|
return is_known_channel_command(text)
|
||||||
return False
|
|
||||||
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
|
||||||
|
|
||||||
|
|
||||||
class FeishuChannel(Channel):
|
class FeishuChannel(Channel):
|
||||||
@@ -56,17 +65,46 @@ class FeishuChannel(Channel):
|
|||||||
self._background_tasks: set[asyncio.Task] = set()
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
self._running_card_ids: dict[str, str] = {}
|
self._running_card_ids: dict[str, str] = {}
|
||||||
self._running_card_tasks: dict[str, asyncio.Task] = {}
|
self._running_card_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._pending_clarifications: dict[tuple[str, str], list[dict[str, Any]]] = {}
|
||||||
self._CreateFileRequest = None
|
self._CreateFileRequest = None
|
||||||
self._CreateFileRequestBody = None
|
self._CreateFileRequestBody = None
|
||||||
self._CreateImageRequest = None
|
self._CreateImageRequest = None
|
||||||
self._CreateImageRequestBody = None
|
self._CreateImageRequestBody = None
|
||||||
self._GetMessageResourceRequest = None
|
self._GetMessageResourceRequest = None
|
||||||
self._thread_lock = threading.Lock()
|
self._thread_lock = threading.Lock()
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _non_empty_str(value: Any) -> str | None:
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
return value.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pending_key(chat_id: str, user_id: str) -> tuple[str, str]:
|
||||||
|
return (chat_id, user_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
if not self._running:
|
||||||
|
return False
|
||||||
|
return self._thread is not None and self._thread.is_alive()
|
||||||
|
|
||||||
|
def _build_event_handler(self, lark):
|
||||||
|
return (
|
||||||
|
lark.EventDispatcherHandler.builder("", "")
|
||||||
|
.register_p2_im_message_receive_v1(self._on_message)
|
||||||
|
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
|
||||||
|
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
|
||||||
|
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
|
||||||
|
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
@@ -160,7 +198,7 @@ class FeishuChannel(Channel):
|
|||||||
# thread's uvloop.
|
# thread's uvloop.
|
||||||
_ws_client_mod.loop = loop
|
_ws_client_mod.loop = loop
|
||||||
|
|
||||||
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
|
event_handler = self._build_event_handler(lark)
|
||||||
ws_client = lark.ws.Client(
|
ws_client = lark.ws.Client(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
app_secret=app_secret,
|
app_secret=app_secret,
|
||||||
@@ -172,6 +210,10 @@ class FeishuChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.exception("Feishu WebSocket error")
|
logger.exception("Feishu WebSocket error")
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def _on_ignored_message_event(self, event) -> None:
|
||||||
|
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -531,18 +573,25 @@ class FeishuChannel(Channel):
|
|||||||
"[Feishu] failed to patch running card %s, falling back to final reply",
|
"[Feishu] failed to patch running card %s, falling back to final reply",
|
||||||
running_card_id,
|
running_card_id,
|
||||||
)
|
)
|
||||||
await self._reply_card(source_message_id, msg.text)
|
fallback_card_id = await self._reply_card(source_message_id, msg.text)
|
||||||
|
self._remember_thread_mapping(msg, source_message_id, fallback_card_id)
|
||||||
|
self._remember_pending_clarification(msg, fallback_card_id)
|
||||||
else:
|
else:
|
||||||
|
self._remember_thread_mapping(msg, source_message_id, running_card_id)
|
||||||
|
self._remember_pending_clarification(msg, running_card_id)
|
||||||
logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id)
|
logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id)
|
||||||
elif msg.is_final:
|
elif msg.is_final:
|
||||||
await self._reply_card(source_message_id, msg.text)
|
final_card_id = await self._reply_card(source_message_id, msg.text)
|
||||||
|
self._remember_thread_mapping(msg, source_message_id, final_card_id)
|
||||||
|
self._remember_pending_clarification(msg, final_card_id)
|
||||||
elif awaited_running_card_task:
|
elif awaited_running_card_task:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation",
|
"[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation",
|
||||||
source_message_id,
|
source_message_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._ensure_running_card(source_message_id, msg.text)
|
created_card_id = await self._ensure_running_card(source_message_id, msg.text)
|
||||||
|
self._remember_thread_mapping(msg, source_message_id, created_card_id)
|
||||||
|
|
||||||
if msg.is_final:
|
if msg.is_final:
|
||||||
self._running_card_ids.pop(source_message_id, None)
|
self._running_card_ids.pop(source_message_id, None)
|
||||||
@@ -553,6 +602,129 @@ class FeishuChannel(Channel):
|
|||||||
|
|
||||||
# -- internal ----------------------------------------------------------
|
# -- internal ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _remember_thread_mapping(self, msg: OutboundMessage, *topic_ids: str | None) -> None:
|
||||||
|
store = self.config.get("channel_store")
|
||||||
|
if store is None or not msg.thread_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
metadata_topic_ids = [
|
||||||
|
msg.metadata.get("message_id"),
|
||||||
|
msg.metadata.get("root_id"),
|
||||||
|
msg.metadata.get("parent_id"),
|
||||||
|
msg.metadata.get("thread_id"),
|
||||||
|
msg.metadata.get("topic_id"),
|
||||||
|
]
|
||||||
|
user_id = ""
|
||||||
|
raw_user_id = msg.metadata.get("user_id")
|
||||||
|
if isinstance(raw_user_id, str):
|
||||||
|
user_id = raw_user_id
|
||||||
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
for topic_id in [*topic_ids, *metadata_topic_ids]:
|
||||||
|
topic_id = self._non_empty_str(topic_id)
|
||||||
|
if not topic_id or topic_id in seen:
|
||||||
|
continue
|
||||||
|
seen.add(topic_id)
|
||||||
|
try:
|
||||||
|
store.set_thread_id(
|
||||||
|
self.name,
|
||||||
|
msg.chat_id,
|
||||||
|
msg.thread_id,
|
||||||
|
topic_id=topic_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Feishu] failed to remember thread mapping for topic_id=%s", topic_id)
|
||||||
|
|
||||||
|
def _remember_pending_clarification(self, msg: OutboundMessage, card_message_id: str | None) -> None:
|
||||||
|
if not msg.is_final or msg.metadata.get(PENDING_CLARIFICATION_METADATA_KEY) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
user_id = self._non_empty_str(msg.metadata.get("user_id"))
|
||||||
|
topic_id = self._non_empty_str(msg.metadata.get("topic_id"))
|
||||||
|
source_message_id = self._non_empty_str(msg.thread_ts) or self._non_empty_str(msg.metadata.get("message_id"))
|
||||||
|
if not (user_id and topic_id and msg.thread_id and source_message_id and card_message_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
key = self._pending_key(msg.chat_id, user_id)
|
||||||
|
pending = {
|
||||||
|
"thread_id": msg.thread_id,
|
||||||
|
"topic_id": topic_id,
|
||||||
|
"source_message_id": source_message_id,
|
||||||
|
"card_message_id": card_message_id,
|
||||||
|
"created_at": time.time(),
|
||||||
|
}
|
||||||
|
with self._thread_lock:
|
||||||
|
# Plain-message clarification continuity is a short-lived in-memory
|
||||||
|
# hint; explicit Feishu replies are still covered by persisted
|
||||||
|
# message-id mappings.
|
||||||
|
self._pending_clarifications.setdefault(key, []).append(pending)
|
||||||
|
logger.info(
|
||||||
|
"[Feishu] pending clarification remembered: chat_id=%s user_id=%s topic_id=%s thread_id=%s",
|
||||||
|
msg.chat_id,
|
||||||
|
user_id,
|
||||||
|
topic_id,
|
||||||
|
msg.thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _consume_pending_clarification(self, chat_id: str, user_id: str) -> dict[str, Any] | None:
|
||||||
|
key = self._pending_key(chat_id, user_id)
|
||||||
|
with self._thread_lock:
|
||||||
|
pending_items = self._pending_clarifications.get(key)
|
||||||
|
if not pending_items:
|
||||||
|
return None
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
while pending_items:
|
||||||
|
pending = pending_items.pop(0)
|
||||||
|
created_at = pending.get("created_at")
|
||||||
|
if isinstance(created_at, (int, float)) and now - created_at <= PENDING_CLARIFICATION_TTL_SECONDS:
|
||||||
|
if pending_items:
|
||||||
|
self._pending_clarifications[key] = pending_items
|
||||||
|
else:
|
||||||
|
self._pending_clarifications.pop(key, None)
|
||||||
|
return pending
|
||||||
|
logger.info("[Feishu] pending clarification expired: chat_id=%s user_id=%s", chat_id, user_id)
|
||||||
|
|
||||||
|
self._pending_clarifications.pop(key, None)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _ensure_pending_thread_mapping(self, chat_id: str, user_id: str, pending: dict[str, Any]) -> None:
|
||||||
|
store = self.config.get("channel_store")
|
||||||
|
topic_id = self._non_empty_str(pending.get("topic_id"))
|
||||||
|
thread_id = self._non_empty_str(pending.get("thread_id"))
|
||||||
|
if store is None or not topic_id or not thread_id:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
store.set_thread_id(self.name, chat_id, thread_id, topic_id=topic_id, user_id=user_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Feishu] failed to restore pending clarification mapping for topic_id=%s", topic_id)
|
||||||
|
|
||||||
|
def _resolve_topic_id(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
msg_id: str,
|
||||||
|
*,
|
||||||
|
root_id: str | None,
|
||||||
|
parent_id: str | None,
|
||||||
|
thread_id: str | None,
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
store = self.config.get("channel_store")
|
||||||
|
candidates = [root_id, parent_id, thread_id]
|
||||||
|
|
||||||
|
if store is not None:
|
||||||
|
for candidate in candidates:
|
||||||
|
candidate = self._non_empty_str(candidate)
|
||||||
|
if not candidate:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if store.get_thread_id(self.name, chat_id, topic_id=candidate):
|
||||||
|
return candidate, True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Feishu] failed to resolve stored topic mapping for topic_id=%s", candidate)
|
||||||
|
|
||||||
|
return root_id or msg_id, False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _log_future_error(fut, name: str, msg_id: str) -> None:
|
def _log_future_error(fut, name: str, msg_id: str) -> None:
|
||||||
"""Callback for run_coroutine_threadsafe futures to surface errors."""
|
"""Callback for run_coroutine_threadsafe futures to surface errors."""
|
||||||
@@ -577,11 +749,47 @@ class FeishuChannel(Channel):
|
|||||||
|
|
||||||
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
||||||
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
||||||
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
||||||
self._ensure_running_card_started(msg_id)
|
self._ensure_running_card_started(msg_id)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="feishu",
|
||||||
|
workspace_id=inbound.chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not user_id or not chat_id:
|
||||||
|
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="feishu",
|
||||||
|
external_account_id=user_id,
|
||||||
|
workspace_id=chat_id,
|
||||||
|
metadata={
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
def _on_message(self, event) -> None:
|
def _on_message(self, event) -> None:
|
||||||
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
||||||
try:
|
try:
|
||||||
@@ -593,7 +801,9 @@ class FeishuChannel(Channel):
|
|||||||
|
|
||||||
# root_id is set when the message is a reply within a Feishu thread.
|
# root_id is set when the message is a reply within a Feishu thread.
|
||||||
# Use it as topic_id so all replies share the same DeerFlow thread.
|
# Use it as topic_id so all replies share the same DeerFlow thread.
|
||||||
root_id = getattr(message, "root_id", None) or None
|
root_id = self._non_empty_str(getattr(message, "root_id", None))
|
||||||
|
parent_id = self._non_empty_str(getattr(message, "parent_id", None))
|
||||||
|
feishu_thread_id = self._non_empty_str(getattr(message, "thread_id", None))
|
||||||
|
|
||||||
# Parse message content
|
# Parse message content
|
||||||
content = json.loads(message.content)
|
content = json.loads(message.content)
|
||||||
@@ -654,10 +864,12 @@ class FeishuChannel(Channel):
|
|||||||
text = text.strip()
|
text = text.strip()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, sender=%s, text=%r",
|
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r",
|
||||||
chat_id,
|
chat_id,
|
||||||
msg_id,
|
msg_id,
|
||||||
root_id,
|
root_id,
|
||||||
|
parent_id,
|
||||||
|
feishu_thread_id,
|
||||||
sender_id,
|
sender_id,
|
||||||
text[:100] if text else "",
|
text[:100] if text else "",
|
||||||
)
|
)
|
||||||
@@ -666,6 +878,23 @@ class FeishuChannel(Channel):
|
|||||||
logger.info("[Feishu] empty text, ignoring message")
|
logger.info("[Feishu] empty text, ignoring message")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
if self._main_loop and self._main_loop.is_running():
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._bind_connection_from_connect_code(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_id=sender_id,
|
||||||
|
code=connect_code,
|
||||||
|
),
|
||||||
|
self._main_loop,
|
||||||
|
)
|
||||||
|
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||||
|
else:
|
||||||
|
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
|
||||||
|
return
|
||||||
|
|
||||||
# Only treat known slash commands as commands; absolute paths and
|
# Only treat known slash commands as commands; absolute paths and
|
||||||
# other slash-prefixed text should be handled as normal chat.
|
# other slash-prefixed text should be handled as normal chat.
|
||||||
if _is_feishu_command(text):
|
if _is_feishu_command(text):
|
||||||
@@ -673,8 +902,24 @@ class FeishuChannel(Channel):
|
|||||||
else:
|
else:
|
||||||
msg_type = InboundMessageType.CHAT
|
msg_type = InboundMessageType.CHAT
|
||||||
|
|
||||||
# topic_id: use root_id for replies (same topic), msg_id for new messages (new topic)
|
# Prefer any platform message id that already maps to a DeerFlow
|
||||||
topic_id = root_id or msg_id
|
# thread. This keeps replies to bot clarification cards in the
|
||||||
|
# original conversation even when Feishu reports the card as root.
|
||||||
|
topic_id, resolved_from_stored_mapping = self._resolve_topic_id(
|
||||||
|
chat_id,
|
||||||
|
msg_id,
|
||||||
|
root_id=root_id,
|
||||||
|
parent_id=parent_id,
|
||||||
|
thread_id=feishu_thread_id,
|
||||||
|
)
|
||||||
|
resolved_from_pending = False
|
||||||
|
if msg_type == InboundMessageType.CHAT and not resolved_from_stored_mapping:
|
||||||
|
pending = self._consume_pending_clarification(chat_id, sender_id)
|
||||||
|
pending_topic_id = self._non_empty_str(pending.get("topic_id")) if pending else None
|
||||||
|
if pending_topic_id:
|
||||||
|
topic_id = pending_topic_id
|
||||||
|
self._ensure_pending_thread_mapping(chat_id, sender_id, pending)
|
||||||
|
resolved_from_pending = True
|
||||||
|
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
@@ -683,7 +928,15 @@ class FeishuChannel(Channel):
|
|||||||
msg_type=msg_type,
|
msg_type=msg_type,
|
||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
files=files_list,
|
files=files_list,
|
||||||
metadata={"message_id": msg_id, "root_id": root_id},
|
metadata={
|
||||||
|
"message_id": msg_id,
|
||||||
|
"root_id": root_id,
|
||||||
|
"parent_id": parent_id,
|
||||||
|
"thread_id": feishu_thread_id,
|
||||||
|
"topic_id": topic_id,
|
||||||
|
"user_id": sender_id,
|
||||||
|
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY: resolved_from_pending,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
|
|
||||||
|
|||||||
+360
-47
@@ -8,6 +8,7 @@ import mimetypes
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable, Mapping
|
from collections.abc import Awaitable, Callable, Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -15,11 +16,24 @@ import httpx
|
|||||||
from langgraph_sdk.errors import ConflictError
|
from langgraph_sdk.errors import ConflictError
|
||||||
|
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import (
|
||||||
|
PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
|
InboundMessage,
|
||||||
|
InboundMessageType,
|
||||||
|
MessageBus,
|
||||||
|
OutboundMessage,
|
||||||
|
ResolvedAttachment,
|
||||||
|
)
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
||||||
from app.gateway.internal_auth import create_internal_auth_headers
|
from app.gateway.internal_auth import create_internal_auth_headers
|
||||||
|
from deerflow.config.agents_config import load_agent_config
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
from deerflow.skills.slash import parse_slash_skill_reference
|
||||||
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
|
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||||
|
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -35,6 +49,11 @@ DEFAULT_RUN_CONTEXT: dict[str, Any] = {
|
|||||||
"subagent_enabled": False,
|
"subagent_enabled": False,
|
||||||
}
|
}
|
||||||
STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
||||||
|
# Stream modes requested from the runtime, and the SSE event names under which
|
||||||
|
# the message-tuple stream may arrive: the embedded runtime (and LangGraph
|
||||||
|
# Platform) deliver the requested "messages-tuple" mode as event "messages".
|
||||||
|
STREAM_MODES = ["messages-tuple", "values"]
|
||||||
|
MESSAGE_STREAM_EVENTS = ("messages-tuple", "messages")
|
||||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||||
|
|
||||||
CHANNEL_CAPABILITIES = {
|
CHANNEL_CAPABILITIES = {
|
||||||
@@ -42,7 +61,7 @@ CHANNEL_CAPABILITIES = {
|
|||||||
"discord": {"supports_streaming": False},
|
"discord": {"supports_streaming": False},
|
||||||
"feishu": {"supports_streaming": True},
|
"feishu": {"supports_streaming": True},
|
||||||
"slack": {"supports_streaming": False},
|
"slack": {"supports_streaming": False},
|
||||||
"telegram": {"supports_streaming": False},
|
"telegram": {"supports_streaming": True},
|
||||||
"wechat": {"supports_streaming": False},
|
"wechat": {"supports_streaming": False},
|
||||||
"wecom": {"supports_streaming": True},
|
"wecom": {"supports_streaming": True},
|
||||||
}
|
}
|
||||||
@@ -116,6 +135,16 @@ class InvalidChannelSessionConfigError(ValueError):
|
|||||||
"""Raised when IM channel session overrides contain invalid agent config."""
|
"""Raised when IM channel session overrides contain invalid agent config."""
|
||||||
|
|
||||||
|
|
||||||
|
class SlashSkillCommandResolutionError(RuntimeError):
|
||||||
|
"""Raised when IM slash-skill command resolution cannot complete safely."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _SlashSkillCommandResolution:
|
||||||
|
route_to_chat: bool = False
|
||||||
|
failure_message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def _is_thread_busy_error(exc: BaseException | None) -> bool:
|
def _is_thread_busy_error(exc: BaseException | None) -> bool:
|
||||||
if exc is None:
|
if exc is None:
|
||||||
return False
|
return False
|
||||||
@@ -202,6 +231,70 @@ def _extract_response_text(result: dict | list) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _messages_from_result(result: dict | list) -> list[Any]:
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
if isinstance(result, dict):
|
||||||
|
messages = result.get("messages", [])
|
||||||
|
if isinstance(messages, list):
|
||||||
|
return messages
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _current_turn_messages(result: dict | list) -> list[dict[str, Any]]:
|
||||||
|
messages = _messages_from_result(result)
|
||||||
|
current_turn: list[dict[str, Any]] = []
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if not isinstance(msg, dict):
|
||||||
|
continue
|
||||||
|
if msg.get("type") == "human":
|
||||||
|
break
|
||||||
|
current_turn.append(msg)
|
||||||
|
current_turn.reverse()
|
||||||
|
return current_turn
|
||||||
|
|
||||||
|
|
||||||
|
def _has_current_turn_clarification(result: dict | list) -> bool:
|
||||||
|
"""Return True only when the current turn's final result is clarification."""
|
||||||
|
for msg in reversed(_current_turn_messages(result)):
|
||||||
|
msg_type = msg.get("type")
|
||||||
|
if msg_type == "tool":
|
||||||
|
return msg.get("name") == "ask_clarification"
|
||||||
|
if msg_type == "ai":
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
if content:
|
||||||
|
return False
|
||||||
|
elif content:
|
||||||
|
return False
|
||||||
|
if msg.get("tool_calls"):
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification: bool = False) -> dict[str, Any]:
|
||||||
|
metadata = _slim_metadata(base_metadata)
|
||||||
|
if pending_clarification:
|
||||||
|
metadata[PENDING_CLARIFICATION_METADATA_KEY] = True
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
|
||||||
|
channel_source: dict[str, Any] = {
|
||||||
|
"type": "im_channel",
|
||||||
|
"provider": msg.channel_name,
|
||||||
|
"chat_id": msg.chat_id,
|
||||||
|
}
|
||||||
|
if msg.topic_id:
|
||||||
|
channel_source["topic_id"] = msg.topic_id
|
||||||
|
if msg.thread_ts:
|
||||||
|
channel_source["thread_ts"] = msg.thread_ts
|
||||||
|
if msg.connection_id:
|
||||||
|
channel_source["connection_id"] = msg.connection_id
|
||||||
|
|
||||||
|
return {"channel_source": channel_source}
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_content(content: Any) -> str:
|
def _extract_text_content(content: Any) -> str:
|
||||||
"""Extract text from a streaming payload content field."""
|
"""Extract text from a streaming payload content field."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -354,6 +447,83 @@ def _format_artifact_text(artifacts: list[str]) -> str:
|
|||||||
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
||||||
|
|
||||||
|
|
||||||
|
def _unknown_command_reply(command: str | None = None) -> str:
|
||||||
|
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
|
||||||
|
if command:
|
||||||
|
return f"Unknown command: /{command}. Available commands: {available}"
|
||||||
|
return f"Unknown command. Available commands: {available}"
|
||||||
|
|
||||||
|
|
||||||
|
def _human_input_message(content: str, *, original_content: str | None = None) -> dict[str, Any]:
|
||||||
|
message: dict[str, Any] = {"role": "human", "content": content}
|
||||||
|
if original_content is not None and original_content != content:
|
||||||
|
message["additional_kwargs"] = {ORIGINAL_USER_CONTENT_KEY: original_content}
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_disabled_owner_user_id() -> str | None:
|
||||||
|
try:
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
||||||
|
return None
|
||||||
|
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
||||||
|
|
||||||
|
|
||||||
|
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
||||||
|
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if owner_user_id:
|
||||||
|
msg.owner_user_id = owner_user_id
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if not owner_user_id:
|
||||||
|
return None
|
||||||
|
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_user_id_for_run(raw_user_id: str) -> str:
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
|
||||||
|
try:
|
||||||
|
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to prepare channel run user directory")
|
||||||
|
return make_safe_user_id(raw_user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_slash_skill_command(
|
||||||
|
text: str,
|
||||||
|
available_skills: set[str] | None = None,
|
||||||
|
storage: SkillStorage | Callable[[], SkillStorage] | None = None,
|
||||||
|
) -> _SlashSkillCommandResolution | None:
|
||||||
|
reference = parse_slash_skill_reference(text)
|
||||||
|
if reference is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
resolved_storage = storage() if callable(storage) else storage or get_or_new_skill_storage()
|
||||||
|
skills = resolved_storage.load_skills(enabled_only=False)
|
||||||
|
|
||||||
|
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
|
||||||
|
if skill is None:
|
||||||
|
return None
|
||||||
|
if not skill.enabled:
|
||||||
|
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
|
||||||
|
if available_skills is not None and reference.name not in available_skills:
|
||||||
|
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
|
||||||
|
|
||||||
|
return _SlashSkillCommandResolution(route_to_chat=True)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("[Manager] failed to resolve slash skill command")
|
||||||
|
raise SlashSkillCommandResolutionError("Failed to resolve slash skill command. Please check the skill configuration.") from exc
|
||||||
|
|
||||||
|
|
||||||
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
||||||
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
||||||
|
|
||||||
@@ -443,8 +613,14 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
write_upload_file_no_symlink,
|
write_upload_file_no_symlink,
|
||||||
)
|
)
|
||||||
|
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
def _prepare_uploads_dir() -> tuple[Path, set[str]]:
|
||||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
# Worker thread: ensure_uploads_dir's mkdir and the iterdir enumeration are
|
||||||
|
# blocking filesystem IO that must stay off the event loop.
|
||||||
|
target = ensure_uploads_dir(thread_id)
|
||||||
|
existing = {entry.name for entry in target.iterdir() if entry.is_file()}
|
||||||
|
return target, existing
|
||||||
|
|
||||||
|
uploads_dir, seen_names = await asyncio.to_thread(_prepare_uploads_dir)
|
||||||
|
|
||||||
created: list[dict[str, Any]] = []
|
created: list[dict[str, Any]] = []
|
||||||
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
|
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
|
||||||
@@ -492,7 +668,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
|
|
||||||
dest = uploads_dir / safe_name
|
dest = uploads_dir / safe_name
|
||||||
try:
|
try:
|
||||||
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
dest = await asyncio.to_thread(write_upload_file_no_symlink, uploads_dir, safe_name, data)
|
||||||
except UnsafeUploadPathError:
|
except UnsafeUploadPathError:
|
||||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
||||||
continue
|
continue
|
||||||
@@ -558,6 +734,7 @@ class ChannelManager:
|
|||||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||||
default_session: dict[str, Any] | None = None,
|
default_session: dict[str, Any] | None = None,
|
||||||
channel_sessions: dict[str, Any] | None = None,
|
channel_sessions: dict[str, Any] | None = None,
|
||||||
|
connection_repo: Any | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.store = store
|
self.store = store
|
||||||
@@ -567,7 +744,10 @@ class ChannelManager:
|
|||||||
self._assistant_id = assistant_id
|
self._assistant_id = assistant_id
|
||||||
self._default_session = _as_dict(default_session)
|
self._default_session = _as_dict(default_session)
|
||||||
self._channel_sessions = dict(channel_sessions or {})
|
self._channel_sessions = dict(channel_sessions or {})
|
||||||
|
self._connection_repo = connection_repo
|
||||||
self._client = None # lazy init — langgraph_sdk async client
|
self._client = None # lazy init — langgraph_sdk async client
|
||||||
|
self._channel_metadata_synced: set[str] = set()
|
||||||
|
self._skill_storage: SkillStorage | None = None
|
||||||
self._csrf_token = generate_csrf_token()
|
self._csrf_token = generate_csrf_token()
|
||||||
self._semaphore: asyncio.Semaphore | None = None
|
self._semaphore: asyncio.Semaphore | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -615,12 +795,25 @@ class ChannelManager:
|
|||||||
configurable["checkpoint_ns"] = ""
|
configurable["checkpoint_ns"] = ""
|
||||||
configurable["thread_id"] = thread_id
|
configurable["thread_id"] = thread_id
|
||||||
|
|
||||||
|
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
||||||
|
# For browser-connected IM channels, prefer the DeerFlow account that
|
||||||
|
# owns the connection. Preserve the raw platform user under
|
||||||
|
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||||
|
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if owner_user_id:
|
||||||
|
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
|
||||||
|
elif msg.user_id:
|
||||||
|
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
|
||||||
|
if msg.user_id:
|
||||||
|
run_context_identity["channel_user_id"] = msg.user_id
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
DEFAULT_RUN_CONTEXT,
|
DEFAULT_RUN_CONTEXT,
|
||||||
self._default_session.get("context"),
|
self._default_session.get("context"),
|
||||||
channel_layer.get("context"),
|
channel_layer.get("context"),
|
||||||
user_layer.get("context"),
|
user_layer.get("context"),
|
||||||
{"thread_id": thread_id},
|
run_context_identity,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Custom agents are implemented as lead_agent + agent_name context.
|
# Custom agents are implemented as lead_agent + agent_name context.
|
||||||
@@ -632,6 +825,21 @@ class ChannelManager:
|
|||||||
|
|
||||||
return assistant_id, run_config, run_context
|
return assistant_id, run_config, run_context
|
||||||
|
|
||||||
|
def _resolve_available_skill_names(self, msg: InboundMessage) -> set[str] | None:
|
||||||
|
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or ""
|
||||||
|
_, _, run_context = self._resolve_run_params(msg, thread_id)
|
||||||
|
if run_context.get("is_bootstrap"):
|
||||||
|
return {"bootstrap"}
|
||||||
|
|
||||||
|
agent_name = run_context.get("agent_name")
|
||||||
|
if not isinstance(agent_name, str) or not agent_name.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent_config = load_agent_config(_normalize_custom_agent_name(agent_name))
|
||||||
|
if agent_config and agent_config.skills is not None:
|
||||||
|
return set(agent_config.skills)
|
||||||
|
return None
|
||||||
|
|
||||||
# -- LangGraph SDK client (lazy) ----------------------------------------
|
# -- LangGraph SDK client (lazy) ----------------------------------------
|
||||||
|
|
||||||
def _get_client(self):
|
def _get_client(self):
|
||||||
@@ -649,6 +857,11 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
def _get_skill_storage(self) -> SkillStorage:
|
||||||
|
if self._skill_storage is None:
|
||||||
|
self._skill_storage = get_or_new_skill_storage()
|
||||||
|
return self._skill_storage
|
||||||
|
|
||||||
# -- lifecycle ---------------------------------------------------------
|
# -- lifecycle ---------------------------------------------------------
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -704,6 +917,7 @@ class ChannelManager:
|
|||||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||||
|
|
||||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||||
|
msg = _apply_effective_owner(msg)
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
if msg.msg_type == InboundMessageType.COMMAND:
|
if msg.msg_type == InboundMessageType.COMMAND:
|
||||||
@@ -718,6 +932,14 @@ class ChannelManager:
|
|||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
await self._send_error(msg, str(exc))
|
await self._send_error(msg, str(exc))
|
||||||
|
except SlashSkillCommandResolutionError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Slash skill command resolution failed for %s (chat=%s): %s",
|
||||||
|
msg.channel_name,
|
||||||
|
msg.chat_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await self._send_error(msg, str(exc))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Error handling message from %s (chat=%s)",
|
"Error handling message from %s (chat=%s)",
|
||||||
@@ -728,10 +950,27 @@ class ChannelManager:
|
|||||||
|
|
||||||
# -- chat handling -----------------------------------------------------
|
# -- chat handling -----------------------------------------------------
|
||||||
|
|
||||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
||||||
"""Create a new thread through Gateway and store the mapping."""
|
if msg.connection_id and self._connection_repo is not None:
|
||||||
thread = await client.threads.create()
|
return await self._connection_repo.get_thread_id(
|
||||||
thread_id = thread["thread_id"]
|
msg.connection_id,
|
||||||
|
msg.chat_id,
|
||||||
|
msg.topic_id,
|
||||||
|
)
|
||||||
|
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||||
|
|
||||||
|
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
||||||
|
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
||||||
|
await self._connection_repo.set_thread_id(
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
|
provider=msg.channel_name,
|
||||||
|
external_conversation_id=msg.chat_id,
|
||||||
|
external_topic_id=msg.topic_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
self.store.set_thread_id(
|
self.store.set_thread_id(
|
||||||
msg.channel_name,
|
msg.channel_name,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
@@ -739,18 +978,49 @@ class ChannelManager:
|
|||||||
topic_id=msg.topic_id,
|
topic_id=msg.topic_id,
|
||||||
user_id=msg.user_id,
|
user_id=msg.user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||||
|
"""Create a new thread through Gateway and store the mapping."""
|
||||||
|
metadata = _thread_channel_metadata(msg)
|
||||||
|
owner_headers = _owner_headers(msg)
|
||||||
|
if owner_headers:
|
||||||
|
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
|
||||||
|
else:
|
||||||
|
thread = await client.threads.create(metadata=metadata)
|
||||||
|
thread_id = thread["thread_id"]
|
||||||
|
await self._store_thread_id(msg, thread_id)
|
||||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
|
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
|
||||||
|
"""Best-effort source metadata backfill for existing IM-created threads."""
|
||||||
|
# The metadata (provider/chat/topic) is constant for a thread, so one
|
||||||
|
# successful backfill per manager lifetime is enough — skip the
|
||||||
|
# redundant PATCH on every subsequent inbound message.
|
||||||
|
if thread_id in self._channel_metadata_synced:
|
||||||
|
return
|
||||||
|
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
update_kwargs["headers"] = owner_headers
|
||||||
|
try:
|
||||||
|
await client.threads.update(thread_id, **update_kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
|
||||||
|
return
|
||||||
|
if len(self._channel_metadata_synced) > 4096:
|
||||||
|
self._channel_metadata_synced.clear()
|
||||||
|
self._channel_metadata_synced.add(thread_id)
|
||||||
|
|
||||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
||||||
# Look up existing DeerFlow thread.
|
# Look up existing DeerFlow thread.
|
||||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
thread_id = await self._lookup_thread_id(msg)
|
||||||
if thread_id:
|
if thread_id:
|
||||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||||
|
await self._update_thread_channel_metadata(client, msg, thread_id)
|
||||||
|
|
||||||
# No existing thread found — create a new one
|
# No existing thread found — create a new one
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
@@ -772,9 +1042,11 @@ class ChannelManager:
|
|||||||
if extra_context:
|
if extra_context:
|
||||||
run_context.update(extra_context)
|
run_context.update(extra_context)
|
||||||
|
|
||||||
|
original_text = msg.text
|
||||||
uploaded = await _ingest_inbound_files(thread_id, msg)
|
uploaded = await _ingest_inbound_files(thread_id, msg)
|
||||||
if uploaded:
|
if uploaded:
|
||||||
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
|
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
|
||||||
|
human_message = _human_input_message(msg.text, original_content=original_text)
|
||||||
|
|
||||||
if self._channel_supports_streaming(msg.channel_name):
|
if self._channel_supports_streaming(msg.channel_name):
|
||||||
await self._handle_streaming_chat(
|
await self._handle_streaming_chat(
|
||||||
@@ -784,18 +1056,24 @@ class ChannelManager:
|
|||||||
assistant_id,
|
assistant_id,
|
||||||
run_config,
|
run_config,
|
||||||
run_context,
|
run_context,
|
||||||
|
human_message,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
|
run_kwargs: dict[str, Any] = {
|
||||||
|
"input": {"messages": [human_message]},
|
||||||
|
"config": run_config,
|
||||||
|
"context": run_context,
|
||||||
|
"multitask_strategy": "reject",
|
||||||
|
}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
run_kwargs["headers"] = owner_headers
|
||||||
try:
|
try:
|
||||||
result = await client.runs.wait(
|
result = await client.runs.wait(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
**run_kwargs,
|
||||||
config=run_config,
|
|
||||||
context=run_context,
|
|
||||||
multitask_strategy="reject",
|
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if _is_thread_busy_error(exc):
|
if _is_thread_busy_error(exc):
|
||||||
@@ -806,6 +1084,7 @@ class ChannelManager:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
|
pending_clarification = _has_current_turn_clarification(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -831,7 +1110,9 @@ class ChannelManager:
|
|||||||
artifacts=artifacts,
|
artifacts=artifacts,
|
||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
|
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||||
)
|
)
|
||||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
@@ -844,6 +1125,7 @@ class ChannelManager:
|
|||||||
assistant_id: str,
|
assistant_id: str,
|
||||||
run_config: dict[str, Any],
|
run_config: dict[str, Any],
|
||||||
run_context: dict[str, Any],
|
run_context: dict[str, Any],
|
||||||
|
human_message: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
|
|
||||||
@@ -854,21 +1136,26 @@ class ChannelManager:
|
|||||||
last_published_text = ""
|
last_published_text = ""
|
||||||
last_publish_at = 0.0
|
last_publish_at = 0.0
|
||||||
stream_error: BaseException | None = None
|
stream_error: BaseException | None = None
|
||||||
|
stream_kwargs: dict[str, Any] = {
|
||||||
|
"input": {"messages": [human_message]},
|
||||||
|
"config": run_config,
|
||||||
|
"context": run_context,
|
||||||
|
"stream_mode": list(STREAM_MODES),
|
||||||
|
"multitask_strategy": "reject",
|
||||||
|
}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
stream_kwargs["headers"] = owner_headers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in client.runs.stream(
|
async for chunk in client.runs.stream(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
**stream_kwargs,
|
||||||
config=run_config,
|
|
||||||
context=run_context,
|
|
||||||
stream_mode=["messages-tuple", "values"],
|
|
||||||
multitask_strategy="reject",
|
|
||||||
):
|
):
|
||||||
event = getattr(chunk, "event", "")
|
event = getattr(chunk, "event", "")
|
||||||
data = getattr(chunk, "data", None)
|
data = getattr(chunk, "data", None)
|
||||||
|
|
||||||
if event == "messages-tuple":
|
if event in MESSAGE_STREAM_EVENTS:
|
||||||
accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data)
|
accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data)
|
||||||
if accumulated_text:
|
if accumulated_text:
|
||||||
latest_text = accumulated_text
|
latest_text = accumulated_text
|
||||||
@@ -893,7 +1180,9 @@ class ChannelManager:
|
|||||||
text=latest_text,
|
text=latest_text,
|
||||||
is_final=False,
|
is_final=False,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
|
metadata=_response_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
last_published_text = latest_text
|
last_published_text = latest_text
|
||||||
@@ -907,6 +1196,7 @@ class ChannelManager:
|
|||||||
finally:
|
finally:
|
||||||
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
|
pending_clarification = _has_current_turn_clarification(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
||||||
|
|
||||||
@@ -938,18 +1228,29 @@ class ChannelManager:
|
|||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
is_final=True,
|
is_final=True,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
|
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# -- command handling --------------------------------------------------
|
# -- command handling --------------------------------------------------
|
||||||
|
|
||||||
async def _handle_command(self, msg: InboundMessage) -> None:
|
async def _handle_command(self, msg: InboundMessage) -> None:
|
||||||
text = msg.text.strip()
|
raw_text = msg.text
|
||||||
|
text = raw_text.strip()
|
||||||
parts = text.split(maxsplit=1)
|
parts = text.split(maxsplit=1)
|
||||||
command = parts[0].lower().lstrip("/")
|
reply: str | None = None
|
||||||
|
if not parts:
|
||||||
|
command = None
|
||||||
|
reply = _unknown_command_reply()
|
||||||
|
else:
|
||||||
|
command = parts[0].lower().removeprefix("/")
|
||||||
|
|
||||||
if command == "bootstrap":
|
if reply is None and not raw_text.startswith("/"):
|
||||||
|
reply = _unknown_command_reply(command)
|
||||||
|
|
||||||
|
if reply is None and command == "bootstrap":
|
||||||
from dataclasses import replace as _dc_replace
|
from dataclasses import replace as _dc_replace
|
||||||
|
|
||||||
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
|
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
|
||||||
@@ -957,27 +1258,19 @@ class ChannelManager:
|
|||||||
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
|
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
|
||||||
return
|
return
|
||||||
|
|
||||||
if command == "new":
|
if reply is None and command == "new":
|
||||||
# Create a new thread through Gateway
|
# Create a new thread through Gateway
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
thread = await client.threads.create()
|
await self._create_thread(client, msg)
|
||||||
new_thread_id = thread["thread_id"]
|
|
||||||
self.store.set_thread_id(
|
|
||||||
msg.channel_name,
|
|
||||||
msg.chat_id,
|
|
||||||
new_thread_id,
|
|
||||||
topic_id=msg.topic_id,
|
|
||||||
user_id=msg.user_id,
|
|
||||||
)
|
|
||||||
reply = "New conversation started."
|
reply = "New conversation started."
|
||||||
elif command == "status":
|
elif reply is None and command == "status":
|
||||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
thread_id = await self._lookup_thread_id(msg)
|
||||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||||
elif command == "models":
|
elif reply is None and command == "models":
|
||||||
reply = await self._fetch_gateway("/api/models", "models")
|
reply = await self._fetch_gateway("/api/models", "models")
|
||||||
elif command == "memory":
|
elif reply is None and command == "memory":
|
||||||
reply = await self._fetch_gateway("/api/memory", "memory")
|
reply = await self._fetch_gateway("/api/memory", "memory")
|
||||||
elif command == "help":
|
elif reply is None and command == "help":
|
||||||
reply = (
|
reply = (
|
||||||
"Available commands:\n"
|
"Available commands:\n"
|
||||||
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
|
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
|
||||||
@@ -985,18 +1278,36 @@ class ChannelManager:
|
|||||||
"/status — Show current thread info\n"
|
"/status — Show current thread info\n"
|
||||||
"/models — List available models\n"
|
"/models — List available models\n"
|
||||||
"/memory — Show memory status\n"
|
"/memory — Show memory status\n"
|
||||||
|
"/<skill-name> <task> — Activate an enabled skill for one turn\n"
|
||||||
"/help — Show this help"
|
"/help — Show this help"
|
||||||
)
|
)
|
||||||
else:
|
elif reply is None:
|
||||||
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
|
slash_resolution = await asyncio.to_thread(
|
||||||
reply = f"Unknown command: /{command}. Available commands: {available}"
|
lambda: _resolve_slash_skill_command(
|
||||||
|
raw_text,
|
||||||
|
self._resolve_available_skill_names(msg),
|
||||||
|
self._get_skill_storage,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if slash_resolution and slash_resolution.failure_message:
|
||||||
|
reply = slash_resolution.failure_message
|
||||||
|
elif slash_resolution and slash_resolution.route_to_chat:
|
||||||
|
from dataclasses import replace as _dc_replace
|
||||||
|
|
||||||
|
chat_msg = _dc_replace(msg, msg_type=InboundMessageType.CHAT)
|
||||||
|
await self._handle_chat(chat_msg)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
reply = _unknown_command_reply(command)
|
||||||
|
|
||||||
outbound = OutboundMessage(
|
outbound = OutboundMessage(
|
||||||
channel_name=msg.channel_name,
|
channel_name=msg.channel_name,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
thread_id=await self._lookup_thread_id(msg) or "",
|
||||||
text=reply,
|
text=reply,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
metadata=_slim_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
@@ -1032,9 +1343,11 @@ class ChannelManager:
|
|||||||
outbound = OutboundMessage(
|
outbound = OutboundMessage(
|
||||||
channel_name=msg.channel_name,
|
channel_name=msg.channel_name,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
thread_id=await self._lookup_thread_id(msg) or "",
|
||||||
text=error_text,
|
text=error_text,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
metadata=_slim_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ from typing import Any
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PENDING_CLARIFICATION_METADATA_KEY = "pending_clarification"
|
||||||
|
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY = "resolved_from_pending_clarification"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Message types
|
# Message types
|
||||||
@@ -41,6 +44,12 @@ class InboundMessage:
|
|||||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||||
reuse the same DeerFlow thread. When ``None``, each message
|
reuse the same DeerFlow thread. When ``None``, each message
|
||||||
creates a new thread (one-shot Q&A).
|
creates a new thread (one-shot Q&A).
|
||||||
|
connection_id: Optional DeerFlow channel connection id. When present,
|
||||||
|
conversation mapping is scoped by the connection instead of the
|
||||||
|
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
||||||
|
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||||
|
Platform user ids stay in ``user_id``.
|
||||||
|
workspace_id: Optional external workspace/guild/team id.
|
||||||
files: Optional list of file attachments (platform-specific dicts).
|
files: Optional list of file attachments (platform-specific dicts).
|
||||||
metadata: Arbitrary extra data from the channel.
|
metadata: Arbitrary extra data from the channel.
|
||||||
created_at: Unix timestamp when the message was created.
|
created_at: Unix timestamp when the message was created.
|
||||||
@@ -53,6 +62,9 @@ class InboundMessage:
|
|||||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||||
thread_ts: str | None = None
|
thread_ts: str | None = None
|
||||||
topic_id: str | None = None
|
topic_id: str | None = None
|
||||||
|
connection_id: str | None = None
|
||||||
|
owner_user_id: str | None = None
|
||||||
|
workspace_id: str | None = None
|
||||||
files: list[dict[str, Any]] = field(default_factory=list)
|
files: list[dict[str, Any]] = field(default_factory=list)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
@@ -92,6 +104,9 @@ class OutboundMessage:
|
|||||||
is_final: Whether this is the final message in the response stream.
|
is_final: Whether this is the final message in the response stream.
|
||||||
thread_ts: Optional platform thread identifier for threaded replies.
|
thread_ts: Optional platform thread identifier for threaded replies.
|
||||||
metadata: Arbitrary extra data.
|
metadata: Arbitrary extra data.
|
||||||
|
connection_id: Optional DeerFlow channel connection id used for
|
||||||
|
connection-specific outbound credentials.
|
||||||
|
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||||
created_at: Unix timestamp.
|
created_at: Unix timestamp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -103,6 +118,8 @@ class OutboundMessage:
|
|||||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||||
is_final: bool = True
|
is_final: bool = True
|
||||||
thread_ts: str | None = None
|
thread_ts: str | None = None
|
||||||
|
connection_id: str | None = None
|
||||||
|
owner_user_id: str | None = None
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
"""Local persistence for runtime IM channel configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelRuntimeConfigStore:
|
||||||
|
"""JSON-backed store for channel credentials entered from the UI.
|
||||||
|
|
||||||
|
This intentionally mirrors ``ChannelStore``: local/private deployments get
|
||||||
|
durable runtime configuration without needing a public callback URL or a
|
||||||
|
config.yaml edit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: str | Path | None = None) -> None:
|
||||||
|
if path is None:
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
|
||||||
|
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
|
||||||
|
self._path = Path(path)
|
||||||
|
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._data: dict[str, dict[str, Any]] = self._load()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, dict[str, Any]]:
|
||||||
|
if self._path.exists():
|
||||||
|
try:
|
||||||
|
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
|
||||||
|
return {}
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save(self) -> None:
|
||||||
|
fd = tempfile.NamedTemporaryFile(
|
||||||
|
mode="w",
|
||||||
|
dir=self._path.parent,
|
||||||
|
suffix=".tmp",
|
||||||
|
delete=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
json.dump(self._data, fd, indent=2, ensure_ascii=False)
|
||||||
|
fd.close()
|
||||||
|
Path(fd.name).replace(self._path)
|
||||||
|
try:
|
||||||
|
self._path.chmod(0o600)
|
||||||
|
except OSError:
|
||||||
|
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
|
||||||
|
except BaseException:
|
||||||
|
fd.close()
|
||||||
|
Path(fd.name).unlink(missing_ok=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def load_all(self) -> dict[str, dict[str, Any]]:
|
||||||
|
with self._lock:
|
||||||
|
return {name: dict(config) for name, config in self._data.items()}
|
||||||
|
|
||||||
|
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
|
||||||
|
with self._lock:
|
||||||
|
config = self._data.get(provider)
|
||||||
|
return dict(config) if isinstance(config, dict) else None
|
||||||
|
|
||||||
|
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._data[provider] = dict(config)
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
def set_provider_disconnected(self, provider: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._data[provider] = {
|
||||||
|
"enabled": False,
|
||||||
|
RUNTIME_CHANNEL_DISABLED_FLAG: True,
|
||||||
|
}
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
def remove_provider_config(self, provider: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
if provider not in self._data:
|
||||||
|
return False
|
||||||
|
del self._data[provider]
|
||||||
|
self._save()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||||
|
provider_config = getattr(channel_connections_config, provider, None)
|
||||||
|
return bool(getattr(provider_config, "enabled", False))
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
|
||||||
|
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
|
||||||
|
|
||||||
|
|
||||||
|
def merge_runtime_channel_configs(
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
channel_connections_config: Any,
|
||||||
|
*,
|
||||||
|
store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
|
||||||
|
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
runtime_store = store or ChannelRuntimeConfigStore()
|
||||||
|
for provider, runtime_config in runtime_store.load_all().items():
|
||||||
|
if not _provider_enabled(channel_connections_config, provider):
|
||||||
|
continue
|
||||||
|
if _runtime_channel_disconnected(runtime_config):
|
||||||
|
channels_config.pop(provider, None)
|
||||||
|
continue
|
||||||
|
existing = channels_config.get(provider)
|
||||||
|
merged = dict(runtime_config)
|
||||||
|
if isinstance(existing, dict):
|
||||||
|
merged.update(existing)
|
||||||
|
channels_config[provider] = merged
|
||||||
|
|
||||||
|
|
||||||
|
def apply_runtime_connection_config(
|
||||||
|
channel_connections_config: Any,
|
||||||
|
*,
|
||||||
|
store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Apply persisted connection metadata that lives outside ``channels``.
|
||||||
|
|
||||||
|
Telegram uses a bot username for deep links; UI-entered values are stored
|
||||||
|
with the runtime channel config so local restarts keep the provider
|
||||||
|
configured.
|
||||||
|
"""
|
||||||
|
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||||
|
return channel_connections_config
|
||||||
|
|
||||||
|
runtime_store = store or ChannelRuntimeConfigStore()
|
||||||
|
telegram_runtime_config = runtime_store.get_provider_config("telegram")
|
||||||
|
bot_username = ""
|
||||||
|
if isinstance(telegram_runtime_config, dict):
|
||||||
|
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
|
||||||
|
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
|
||||||
|
return channel_connections_config
|
||||||
|
|
||||||
|
config = channel_connections_config.model_copy(deep=True)
|
||||||
|
config.telegram.bot_username = bot_username
|
||||||
|
return config
|
||||||
+170
-25
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||||
from app.channels.message_bus import MessageBus
|
from app.channels.message_bus import MessageBus
|
||||||
|
from app.channels.runtime_config_store import merge_runtime_channel_configs
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
|||||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||||
|
|
||||||
|
|
||||||
|
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
|
||||||
|
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||||
|
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
||||||
value = config.pop(config_key, None)
|
value = config.pop(config_key, None)
|
||||||
if isinstance(value, str) and value.strip():
|
if isinstance(value, str) and value.strip():
|
||||||
@@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||||
|
connection_config = getattr(app_config, "channel_connections", None)
|
||||||
|
merge_runtime_channel_configs(channels_config, connection_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connection_repo(app_config: AppConfig):
|
||||||
|
connection_config = getattr(app_config, "channel_connections", None)
|
||||||
|
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
|
from deerflow.persistence.engine import get_session_factory
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel connection repository")
|
||||||
|
return None
|
||||||
|
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
if session_factory is None:
|
||||||
|
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||||
|
return None
|
||||||
|
return ChannelConnectionRepository(session_factory)
|
||||||
|
|
||||||
|
|
||||||
class ChannelService:
|
class ChannelService:
|
||||||
"""Manages the lifecycle of all configured IM channels.
|
"""Manages the lifecycle of all configured IM channels.
|
||||||
|
|
||||||
@@ -59,9 +90,10 @@ class ChannelService:
|
|||||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||||
self.bus = MessageBus()
|
self.bus = MessageBus()
|
||||||
self.store = ChannelStore()
|
self.store = ChannelStore()
|
||||||
|
self._connection_repo = connection_repo
|
||||||
config = dict(channels_config or {})
|
config = dict(channels_config or {})
|
||||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||||
@@ -74,10 +106,12 @@ class ChannelService:
|
|||||||
gateway_url=gateway_url,
|
gateway_url=gateway_url,
|
||||||
default_session=default_session if isinstance(default_session, dict) else None,
|
default_session=default_session if isinstance(default_session, dict) else None,
|
||||||
channel_sessions=channel_sessions,
|
channel_sessions=channel_sessions,
|
||||||
|
connection_repo=connection_repo,
|
||||||
)
|
)
|
||||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||||
self._config = config
|
self._config = config
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._readiness_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
||||||
@@ -90,8 +124,9 @@ class ChannelService:
|
|||||||
# extra fields are allowed by AppConfig (extra="allow")
|
# extra fields are allowed by AppConfig (extra="allow")
|
||||||
extra = app_config.model_extra or {}
|
extra = app_config.model_extra or {}
|
||||||
if "channels" in extra:
|
if "channels" in extra:
|
||||||
channels_config = extra["channels"]
|
channels_config = dict(extra["channels"] or {})
|
||||||
return cls(channels_config=channels_config)
|
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||||
|
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the manager and all enabled channels."""
|
"""Start the manager and all enabled channels."""
|
||||||
@@ -99,63 +134,169 @@ class ChannelService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
await self.manager.start()
|
await self.manager.start()
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
ready_status = await self.ensure_ready_channels(attempts=2)
|
||||||
|
ready_count = sum(1 for ready in ready_status.values() if ready)
|
||||||
|
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
|
||||||
|
|
||||||
|
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
|
||||||
|
"""Start or restart enabled configured channels that are not ready."""
|
||||||
|
ready_status: dict[str, bool] = {}
|
||||||
for name, channel_config in self._config.items():
|
for name, channel_config in self._config.items():
|
||||||
if not isinstance(channel_config, dict):
|
if not isinstance(channel_config, dict):
|
||||||
continue
|
continue
|
||||||
if not channel_config.get("enabled", False):
|
if not channel_config.get("enabled", False):
|
||||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
if _channel_has_credentials(name, channel_config):
|
||||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
|
||||||
if has_creds:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
|
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
||||||
name,
|
|
||||||
name,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Channel %s is disabled, skipping", name)
|
logger.info("A configured channel is disabled, skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self._start_channel(name, channel_config)
|
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
|
||||||
|
return ready_status
|
||||||
|
|
||||||
self._running = True
|
async def ensure_channel_ready(
|
||||||
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
|
self,
|
||||||
|
name: str,
|
||||||
|
config: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
attempts: int = 1,
|
||||||
|
) -> bool:
|
||||||
|
"""Ensure a single enabled channel is running using its current config."""
|
||||||
|
if not self._running:
|
||||||
|
logger.warning("ChannelService is not running; cannot ensure channel readiness")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if config is not None:
|
||||||
|
self._config[name] = dict(config)
|
||||||
|
|
||||||
|
# Serialize per channel: readiness is polled from request handlers, so
|
||||||
|
# concurrent calls must not stop/start the same channel worker twice.
|
||||||
|
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
|
||||||
|
async with lock:
|
||||||
|
channel_config = self._config.get(name)
|
||||||
|
if not channel_config or not isinstance(channel_config, dict):
|
||||||
|
logger.warning("No config for requested channel")
|
||||||
|
return False
|
||||||
|
if not channel_config.get("enabled", False):
|
||||||
|
return False
|
||||||
|
|
||||||
|
channel = self._channels.get(name)
|
||||||
|
if channel is not None and channel.is_running:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if channel is not None:
|
||||||
|
try:
|
||||||
|
await channel.stop()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error stopping non-running channel before readiness retry")
|
||||||
|
self._channels.pop(name, None)
|
||||||
|
|
||||||
|
max_attempts = max(1, attempts)
|
||||||
|
for attempt in range(max_attempts):
|
||||||
|
if attempt > 0:
|
||||||
|
logger.info("Retrying channel startup after readiness check")
|
||||||
|
if await self._start_channel(name, channel_config):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop all channels and the manager."""
|
"""Stop all channels and the manager."""
|
||||||
for name, channel in list(self._channels.items()):
|
for name, channel in list(self._channels.items()):
|
||||||
try:
|
try:
|
||||||
await channel.stop()
|
await channel.stop()
|
||||||
logger.info("Channel %s stopped", name)
|
logger.info("Channel stopped")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error stopping channel %s", name)
|
logger.exception("Error stopping channel")
|
||||||
self._channels.clear()
|
self._channels.clear()
|
||||||
|
|
||||||
await self.manager.stop()
|
await self.manager.stop()
|
||||||
self._running = False
|
self._running = False
|
||||||
logger.info("ChannelService stopped")
|
logger.info("ChannelService stopped")
|
||||||
|
|
||||||
async def restart_channel(self, name: str) -> bool:
|
def _load_channel_config(self, name: str) -> dict[str, Any] | None:
|
||||||
|
"""Load the latest config for a specific channel from disk.
|
||||||
|
|
||||||
|
Uses ``get_app_config()`` which detects file changes via mtime,
|
||||||
|
so edits to ``config.yaml`` are picked up without a process restart.
|
||||||
|
The UI runtime-config overlay applied at startup is re-applied here
|
||||||
|
so a file-driven reload neither drops credentials entered from the
|
||||||
|
browser nor resurrects a channel disconnected from it.
|
||||||
|
Falls back to the cached ``self._config`` when config loading fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
|
app_config = get_app_config()
|
||||||
|
extra = app_config.model_extra or {}
|
||||||
|
channels_config = dict(extra.get("channels") or {})
|
||||||
|
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||||
|
channel_config = channels_config.get(name)
|
||||||
|
if isinstance(channel_config, dict):
|
||||||
|
# Update the cached config so get_status() stays consistent.
|
||||||
|
self._config[name] = channel_config
|
||||||
|
return channel_config
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to reload config for channel %s, using cached version", name)
|
||||||
|
return self._config.get(name)
|
||||||
|
|
||||||
|
async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
|
||||||
"""Restart a specific channel. Returns True if successful."""
|
"""Restart a specific channel. Returns True if successful."""
|
||||||
if name in self._channels:
|
if name in self._channels:
|
||||||
try:
|
try:
|
||||||
await self._channels[name].stop()
|
await self._channels[name].stop()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error stopping channel %s for restart", name)
|
logger.exception("Error stopping channel for restart")
|
||||||
del self._channels[name]
|
del self._channels[name]
|
||||||
|
|
||||||
config = self._config.get(name)
|
if reload_config:
|
||||||
|
# Reading config.yaml and the runtime store is disk IO; keep it
|
||||||
|
# off the event loop.
|
||||||
|
config = await asyncio.to_thread(self._load_channel_config, name)
|
||||||
|
else:
|
||||||
|
config = self._config.get(name)
|
||||||
if not config or not isinstance(config, dict):
|
if not config or not isinstance(config, dict):
|
||||||
logger.warning("No config for channel %s", name)
|
logger.warning("No config for requested channel")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if not config.get("enabled", False):
|
||||||
|
logger.info("Channel %s is disabled, skipping restart", name)
|
||||||
|
return True
|
||||||
|
|
||||||
return await self._start_channel(name, config)
|
return await self._start_channel(name, config)
|
||||||
|
|
||||||
|
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||||
|
"""Apply runtime config for a channel and restart it if the service is running."""
|
||||||
|
self._config[name] = dict(config)
|
||||||
|
if not self._running:
|
||||||
|
return True
|
||||||
|
# The caller just supplied the authoritative config (e.g. credentials
|
||||||
|
# entered in the browser that are never written to config.yaml) — a
|
||||||
|
# file reload here would clobber it with the stale on-disk entry.
|
||||||
|
return await self.restart_channel(name, reload_config=False)
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str) -> bool:
|
||||||
|
"""Remove runtime config for a channel and stop it if currently running."""
|
||||||
|
self._config.pop(name, None)
|
||||||
|
channel = self._channels.pop(name, None)
|
||||||
|
if channel is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
await channel.stop()
|
||||||
|
logger.info("Channel stopped and removed")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error stopping channel for removal")
|
||||||
|
return False
|
||||||
|
|
||||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||||
"""Instantiate and start a single channel."""
|
"""Instantiate and start a single channel."""
|
||||||
import_path = _CHANNEL_REGISTRY.get(name)
|
import_path = _CHANNEL_REGISTRY.get(name)
|
||||||
if not import_path:
|
if not import_path:
|
||||||
logger.warning("Unknown channel type: %s", name)
|
logger.warning("Unknown channel type")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -163,24 +304,26 @@ class ChannelService:
|
|||||||
|
|
||||||
channel_cls = resolve_class(import_path, base_class=None)
|
channel_cls = resolve_class(import_path, base_class=None)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to import channel class for %s", name)
|
logger.exception("Failed to import channel class")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = dict(config)
|
config = dict(config)
|
||||||
config["channel_store"] = self.store
|
config["channel_store"] = self.store
|
||||||
|
if self._connection_repo is not None:
|
||||||
|
config["connection_repo"] = self._connection_repo
|
||||||
channel = channel_cls(bus=self.bus, config=config)
|
channel = channel_cls(bus=self.bus, config=config)
|
||||||
self._channels[name] = channel
|
self._channels[name] = channel
|
||||||
await channel.start()
|
await channel.start()
|
||||||
if not channel.is_running:
|
if not channel.is_running:
|
||||||
self._channels.pop(name, None)
|
self._channels.pop(name, None)
|
||||||
logger.error("Channel %s did not enter a running state after start()", name)
|
logger.error("Channel did not enter a running state after start()")
|
||||||
return False
|
return False
|
||||||
logger.info("Channel %s started", name)
|
logger.info("Channel started")
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
self._channels.pop(name, None)
|
self._channels.pop(name, None)
|
||||||
logger.exception("Failed to start channel %s", name)
|
logger.exception("Failed to start channel")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_status(self) -> dict[str, Any]:
|
def get_status(self) -> dict[str, Any]:
|
||||||
@@ -219,7 +362,9 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
|
|||||||
global _channel_service
|
global _channel_service
|
||||||
if _channel_service is not None:
|
if _channel_service is not None:
|
||||||
return _channel_service
|
return _channel_service
|
||||||
_channel_service = ChannelService.from_app_config(app_config)
|
# from_app_config reads the JSON channel store and runtime config files;
|
||||||
|
# keep that disk IO off the event loop.
|
||||||
|
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
|
||||||
await _channel_service.start()
|
await _channel_service.start()
|
||||||
return _channel_service
|
return _channel_service
|
||||||
|
|
||||||
|
|||||||
+173
-15
@@ -9,6 +9,8 @@ from typing import Any
|
|||||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -32,6 +34,20 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|||||||
return {str(user_id) for user_id in values if str(user_id)}
|
return {str(user_id) for user_id in values if str(user_id)}
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
|
||||||
|
if not bot_user_id:
|
||||||
|
return text
|
||||||
|
if not text.startswith("<@"):
|
||||||
|
return text
|
||||||
|
end = text.find(">")
|
||||||
|
if end <= 2:
|
||||||
|
return text
|
||||||
|
mentioned_user_id = text[2:end].split("|", 1)[0].lstrip("!")
|
||||||
|
if mentioned_user_id != bot_user_id:
|
||||||
|
return text
|
||||||
|
return text[end + 1 :].lstrip()
|
||||||
|
|
||||||
|
|
||||||
class SlackChannel(Channel):
|
class SlackChannel(Channel):
|
||||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||||
|
|
||||||
@@ -49,6 +65,11 @@ class SlackChannel(Channel):
|
|||||||
self._web_client = None
|
self._web_client = None
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
self._web_client_factory = config.get("web_client_factory")
|
||||||
|
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
|
||||||
|
configured_bot_user_id = config.get("bot_user_id")
|
||||||
|
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -63,15 +84,28 @@ class SlackChannel(Channel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self._SocketModeResponse = SocketModeResponse
|
self._SocketModeResponse = SocketModeResponse
|
||||||
|
if self._web_client_factory is None:
|
||||||
|
self._web_client_factory = WebClient
|
||||||
|
|
||||||
bot_token = self.config.get("bot_token", "")
|
bot_token = self.config.get("bot_token", "")
|
||||||
app_token = self.config.get("app_token", "")
|
app_token = self.config.get("app_token", "")
|
||||||
|
|
||||||
|
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||||
|
if not bot_token:
|
||||||
|
logger.error("Slack HTTP Events mode requires bot_token")
|
||||||
|
return
|
||||||
|
await self._initialize_operator_web_client(str(bot_token))
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
self._running = True
|
||||||
|
self.bus.subscribe_outbound(self._on_outbound)
|
||||||
|
logger.info("Slack channel started in HTTP Events mode")
|
||||||
|
return
|
||||||
|
|
||||||
if not bot_token or not app_token:
|
if not bot_token or not app_token:
|
||||||
logger.error("Slack channel requires bot_token and app_token")
|
logger.error("Slack channel requires bot_token and app_token")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._web_client = WebClient(token=bot_token)
|
await self._initialize_operator_web_client(str(bot_token))
|
||||||
self._socket_client = SocketModeClient(
|
self._socket_client = SocketModeClient(
|
||||||
app_token=app_token,
|
app_token=app_token,
|
||||||
web_client=self._web_client,
|
web_client=self._web_client,
|
||||||
@@ -96,7 +130,8 @@ class SlackChannel(Channel):
|
|||||||
logger.info("Slack channel stopped")
|
logger.info("Slack channel stopped")
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||||
if not self._web_client:
|
web_client = await self._get_web_client_for_message(msg)
|
||||||
|
if not web_client:
|
||||||
return
|
return
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
@@ -109,11 +144,12 @@ class SlackChannel(Channel):
|
|||||||
last_exc: Exception | None = None
|
last_exc: Exception | None = None
|
||||||
for attempt in range(_max_retries):
|
for attempt in range(_max_retries):
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||||
# Add a completion reaction to the thread root
|
# Add a completion reaction to the thread root
|
||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self._add_reaction,
|
self._add_reaction_with_client,
|
||||||
|
web_client,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"white_check_mark",
|
"white_check_mark",
|
||||||
@@ -137,7 +173,8 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self._add_reaction,
|
self._add_reaction_with_client,
|
||||||
|
web_client,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"x",
|
"x",
|
||||||
@@ -149,7 +186,8 @@ class SlackChannel(Channel):
|
|||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
if not self._web_client:
|
web_client = await self._get_web_client_for_message(msg)
|
||||||
|
if not web_client:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -162,7 +200,7 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
kwargs["thread_ts"] = msg.thread_ts
|
kwargs["thread_ts"] = msg.thread_ts
|
||||||
|
|
||||||
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
||||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -171,12 +209,45 @@ class SlackChannel(Channel):
|
|||||||
|
|
||||||
# -- internal ----------------------------------------------------------
|
# -- internal ----------------------------------------------------------
|
||||||
|
|
||||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
||||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
self._web_client = self._web_client_factory(token=bot_token)
|
||||||
if not self._web_client:
|
if self._bot_user_id is not None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
self._web_client.reactions_add(
|
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||||
|
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||||
|
if user_id is None:
|
||||||
|
auth_get = getattr(auth_info, "get", None)
|
||||||
|
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||||
|
if isinstance(user_id, str) and user_id:
|
||||||
|
self._bot_user_id = user_id
|
||||||
|
except Exception:
|
||||||
|
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||||
|
|
||||||
|
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||||
|
if msg.connection_id and self._connection_repo is not None:
|
||||||
|
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||||
|
access_token = credentials.get("access_token") if credentials else None
|
||||||
|
if not access_token:
|
||||||
|
return self._web_client
|
||||||
|
# WebClient keeps its own HTTP session and rate-limit state, so
|
||||||
|
# reuse one per connection until its token changes.
|
||||||
|
cached = self._connection_web_clients.get(msg.connection_id)
|
||||||
|
if cached is not None and cached[0] == access_token:
|
||||||
|
return cached[1]
|
||||||
|
if self._web_client_factory is None:
|
||||||
|
from slack_sdk import WebClient
|
||||||
|
|
||||||
|
self._web_client_factory = WebClient
|
||||||
|
web_client = self._web_client_factory(token=access_token)
|
||||||
|
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
|
||||||
|
return web_client
|
||||||
|
return self._web_client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||||
|
try:
|
||||||
|
web_client.reactions_add(
|
||||||
channel=channel_id,
|
channel=channel_id,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
name=emoji,
|
name=emoji,
|
||||||
@@ -185,6 +256,12 @@ class SlackChannel(Channel):
|
|||||||
if "already_reacted" not in str(exc):
|
if "already_reacted" not in str(exc):
|
||||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||||
|
|
||||||
|
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||||
|
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||||
|
if not self._web_client:
|
||||||
|
return
|
||||||
|
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
||||||
|
|
||||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||||
if not self._web_client:
|
if not self._web_client:
|
||||||
@@ -210,17 +287,26 @@ class SlackChannel(Channel):
|
|||||||
if event_type != "events_api":
|
if event_type != "events_api":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._bot_user_id is None:
|
||||||
|
authorization = next((item for item in req.payload.get("authorizations", []) if isinstance(item, dict)), None)
|
||||||
|
user_id = authorization.get("user_id") if authorization else None
|
||||||
|
if isinstance(user_id, str) and user_id:
|
||||||
|
self._bot_user_id = user_id
|
||||||
|
|
||||||
event = req.payload.get("event", {})
|
event = req.payload.get("event", {})
|
||||||
etype = event.get("type", "")
|
etype = event.get("type", "")
|
||||||
|
|
||||||
# Handle message events (DM or @mention)
|
# Handle message events (DM or @mention)
|
||||||
if etype in ("message", "app_mention"):
|
if etype in ("message", "app_mention"):
|
||||||
self._handle_message_event(event)
|
self._handle_message_event(
|
||||||
|
event,
|
||||||
|
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing Slack event")
|
logger.exception("Error processing Slack event")
|
||||||
|
|
||||||
def _handle_message_event(self, event: dict) -> None:
|
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
|
||||||
# Ignore bot messages
|
# Ignore bot messages
|
||||||
if event.get("bot_id") or event.get("subtype"):
|
if event.get("bot_id") or event.get("subtype"):
|
||||||
return
|
return
|
||||||
@@ -233,13 +319,28 @@ class SlackChannel(Channel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
text = event.get("text", "").strip()
|
text = event.get("text", "").strip()
|
||||||
|
if event.get("type") == "app_mention":
|
||||||
|
text = _strip_leading_slack_bot_mention(text, self._bot_user_id)
|
||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code:
|
||||||
|
if self._loop and self._loop.is_running():
|
||||||
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
self._bind_connection_from_connect_code(
|
||||||
|
event=event,
|
||||||
|
team_id=str(team_id or event.get("team") or ""),
|
||||||
|
code=connect_code,
|
||||||
|
),
|
||||||
|
self._loop,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
channel_id = event.get("channel", "")
|
channel_id = event.get("channel", "")
|
||||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||||
|
|
||||||
if text.startswith("/"):
|
if is_known_channel_command(text):
|
||||||
msg_type = InboundMessageType.COMMAND
|
msg_type = InboundMessageType.COMMAND
|
||||||
else:
|
else:
|
||||||
msg_type = InboundMessageType.CHAT
|
msg_type = InboundMessageType.CHAT
|
||||||
@@ -261,4 +362,61 @@ class SlackChannel(Channel):
|
|||||||
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
||||||
# Send "running" reply first (fire-and-forget from SDK thread)
|
# Send "running" reply first (fire-and-forget from SDK thread)
|
||||||
self._send_running_reply(channel_id, thread_ts)
|
self._send_running_reply(channel_id, thread_ts)
|
||||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
if self._connection_repo is None:
|
||||||
|
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||||
|
else:
|
||||||
|
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
|
||||||
|
|
||||||
|
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
|
||||||
|
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
|
||||||
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||||
|
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="slack",
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
channel_id = str(event.get("channel") or "")
|
||||||
|
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
|
||||||
|
return True
|
||||||
|
|
||||||
|
user_id = str(event.get("user") or "")
|
||||||
|
if not user_id or not team_id:
|
||||||
|
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="slack",
|
||||||
|
external_account_id=user_id,
|
||||||
|
workspace_id=team_id,
|
||||||
|
metadata={
|
||||||
|
"team_id": team_id,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
|
||||||
|
if not self._web_client or not channel_id:
|
||||||
|
return
|
||||||
|
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
|
||||||
|
if thread_ts:
|
||||||
|
kwargs["thread_ts"] = thread_ts
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
|
||||||
|
|||||||
@@ -5,13 +5,27 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TELEGRAM_MAX_MESSAGE_LENGTH = 4096
|
||||||
|
STREAM_EDIT_MIN_INTERVAL_SECONDS = 1.0
|
||||||
|
# Groups (negative chat_id) are capped at 20 messages/minute by Telegram,
|
||||||
|
# so stream edits there must pace well below the private-chat 1 msg/s guideline.
|
||||||
|
STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS = 3.0
|
||||||
|
# Bound on tracked in-flight streamed messages; entries normally clear on the
|
||||||
|
# final update, this only guards against leaks when a final never arrives.
|
||||||
|
MAX_TRACKED_STREAM_MESSAGES = 256
|
||||||
|
|
||||||
|
# Indirection so tests can patch the clock without touching the global time module.
|
||||||
|
_monotonic = time.monotonic
|
||||||
|
|
||||||
|
|
||||||
class TelegramChannel(Channel):
|
class TelegramChannel(Channel):
|
||||||
"""Telegram bot channel using long-polling.
|
"""Telegram bot channel using long-polling.
|
||||||
@@ -35,6 +49,14 @@ class TelegramChannel(Channel):
|
|||||||
pass
|
pass
|
||||||
# chat_id -> last sent message_id for threaded replies
|
# chat_id -> last sent message_id for threaded replies
|
||||||
self._last_bot_message: dict[str, int] = {}
|
self._last_bot_message: dict[str, int] = {}
|
||||||
|
# stream_key ("chat_id:thread_ts") -> state of the in-flight streamed
|
||||||
|
# bot message being edited in place: {"message_id", "last_edit_at", "last_text"}
|
||||||
|
self._stream_messages: dict[str, dict[str, Any]] = {}
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_streaming(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -60,12 +82,17 @@ class TelegramChannel(Channel):
|
|||||||
|
|
||||||
# Command handlers
|
# Command handlers
|
||||||
app.add_handler(CommandHandler("start", self._cmd_start))
|
app.add_handler(CommandHandler("start", self._cmd_start))
|
||||||
|
app.add_handler(CommandHandler("bootstrap", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("new", self._cmd_generic))
|
app.add_handler(CommandHandler("new", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("status", self._cmd_generic))
|
app.add_handler(CommandHandler("status", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("models", self._cmd_generic))
|
app.add_handler(CommandHandler("models", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("memory", self._cmd_generic))
|
app.add_handler(CommandHandler("memory", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("help", self._cmd_generic))
|
app.add_handler(CommandHandler("help", self._cmd_generic))
|
||||||
|
|
||||||
|
# Slash skill commands are dynamic and cannot all be pre-registered
|
||||||
|
# with Telegram, so route unknown slash commands through chat handling.
|
||||||
|
app.add_handler(MessageHandler(filters.TEXT & filters.COMMAND, self._on_text))
|
||||||
|
|
||||||
# General message handler
|
# General message handler
|
||||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
|
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
|
||||||
|
|
||||||
@@ -97,10 +124,117 @@ class TelegramChannel(Channel):
|
|||||||
logger.error("Invalid Telegram chat_id: %s", msg.chat_id)
|
logger.error("Invalid Telegram chat_id: %s", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": msg.text}
|
key = self._stream_key(msg.chat_id, msg.thread_ts)
|
||||||
|
|
||||||
|
if not msg.is_final:
|
||||||
|
await self._send_stream_update(chat_id, key, msg.text, reply_to=self._parse_message_id(msg.thread_ts))
|
||||||
|
return
|
||||||
|
|
||||||
|
state = self._stream_messages.pop(key, None)
|
||||||
|
if state is not None:
|
||||||
|
await self._finalize_stream_message(chat_id, msg.chat_id, state, msg.text)
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._send_new_message(chat_id, msg.chat_id, msg.text, _max_retries=_max_retries)
|
||||||
|
|
||||||
|
async def _send_stream_update(self, chat_id: int, key: str, text: str, reply_to: int | None = None) -> None:
|
||||||
|
"""Edit the in-flight streamed message with accumulated text.
|
||||||
|
|
||||||
|
Updates are best-effort: throttled, rate-limit drops are silent. The
|
||||||
|
manager always publishes a final message afterwards, which guarantees
|
||||||
|
delivery of the complete text.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
|
||||||
|
display = text
|
||||||
|
if len(display) > TELEGRAM_MAX_MESSAGE_LENGTH:
|
||||||
|
display = display[: TELEGRAM_MAX_MESSAGE_LENGTH - 1] + "…"
|
||||||
|
|
||||||
|
bot = self._application.bot
|
||||||
|
state = self._stream_messages.get(key)
|
||||||
|
|
||||||
|
send_kwargs: dict[str, Any] = {"chat_id": chat_id, "text": display}
|
||||||
|
if reply_to:
|
||||||
|
send_kwargs["reply_to_message_id"] = reply_to
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
try:
|
||||||
|
sent = await bot.send_message(**send_kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Telegram] failed to start stream message in chat=%s", chat_id)
|
||||||
|
return
|
||||||
|
self._register_stream_message(key, message_id=sent.message_id, last_text=display, last_edit_at=_monotonic())
|
||||||
|
return
|
||||||
|
|
||||||
|
now = _monotonic()
|
||||||
|
min_interval = STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS if chat_id < 0 else STREAM_EDIT_MIN_INTERVAL_SECONDS
|
||||||
|
if now - state["last_edit_at"] < min_interval:
|
||||||
|
return
|
||||||
|
if display == state["last_text"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await bot.edit_message_text(chat_id=chat_id, message_id=state["message_id"], text=display)
|
||||||
|
except Exception as exc:
|
||||||
|
if self._is_not_modified(exc):
|
||||||
|
state["last_text"] = display
|
||||||
|
return
|
||||||
|
if self._is_retry_after(exc):
|
||||||
|
logger.debug("[Telegram] stream edit rate-limited in chat=%s, dropping update", chat_id)
|
||||||
|
return
|
||||||
|
logger.warning("[Telegram] stream edit failed in chat=%s, sending new message: %s", chat_id, exc)
|
||||||
|
try:
|
||||||
|
sent = await bot.send_message(**send_kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Telegram] failed to send fallback stream message in chat=%s", chat_id)
|
||||||
|
return
|
||||||
|
state["message_id"] = sent.message_id
|
||||||
|
|
||||||
|
state["last_edit_at"] = _monotonic()
|
||||||
|
state["last_text"] = display
|
||||||
|
|
||||||
|
async def _finalize_stream_message(self, chat_id: int, chat_key: str, state: dict[str, Any], text: str) -> None:
|
||||||
|
"""Apply the final text: edit the streamed message, splitting overflow into follow-ups."""
|
||||||
|
bot = self._application.bot
|
||||||
|
chunks = self._split_message(text or "")
|
||||||
|
|
||||||
|
edited = True
|
||||||
|
if chunks[0] != state["last_text"]:
|
||||||
|
edited = await self._edit_final_chunk(bot, chat_id, state["message_id"], chunks[0])
|
||||||
|
|
||||||
|
if edited:
|
||||||
|
self._last_bot_message[chat_key] = state["message_id"]
|
||||||
|
else:
|
||||||
|
# Edit could not be applied (e.g. message deleted) — deliver the
|
||||||
|
# first chunk as a fresh message with the standard retry policy.
|
||||||
|
await self._send_new_message(chat_id, chat_key, chunks[0])
|
||||||
|
|
||||||
|
for chunk in chunks[1:]:
|
||||||
|
await self._send_new_message(chat_id, chat_key, chunk)
|
||||||
|
|
||||||
|
async def _edit_final_chunk(self, bot, chat_id: int, message_id: int, text: str) -> bool:
|
||||||
|
"""Edit with one rate-limit retry. Returns False if the edit could not be applied."""
|
||||||
|
for attempt in range(2):
|
||||||
|
try:
|
||||||
|
await bot.edit_message_text(chat_id=chat_id, message_id=message_id, text=text)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
if self._is_not_modified(exc):
|
||||||
|
return True
|
||||||
|
if self._is_retry_after(exc) and attempt == 0:
|
||||||
|
await asyncio.sleep(self._retry_after_seconds(exc))
|
||||||
|
continue
|
||||||
|
logger.warning("[Telegram] final edit failed in chat=%s: %s", chat_id, exc)
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _send_new_message(self, chat_id: int, chat_key: str, text: str, *, _max_retries: int = 3) -> int | None:
|
||||||
|
"""Send a fresh message with retry/backoff. Returns the sent message_id."""
|
||||||
|
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": text}
|
||||||
|
|
||||||
# Reply to the last bot message in this chat for threading
|
# Reply to the last bot message in this chat for threading
|
||||||
reply_to = self._last_bot_message.get(msg.chat_id)
|
reply_to = self._last_bot_message.get(chat_key)
|
||||||
if reply_to:
|
if reply_to:
|
||||||
kwargs["reply_to_message_id"] = reply_to
|
kwargs["reply_to_message_id"] = reply_to
|
||||||
|
|
||||||
@@ -109,8 +243,8 @@ class TelegramChannel(Channel):
|
|||||||
for attempt in range(_max_retries):
|
for attempt in range(_max_retries):
|
||||||
try:
|
try:
|
||||||
sent = await bot.send_message(**kwargs)
|
sent = await bot.send_message(**kwargs)
|
||||||
self._last_bot_message[msg.chat_id] = sent.message_id
|
self._last_bot_message[chat_key] = sent.message_id
|
||||||
return
|
return sent.message_id
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
if attempt < _max_retries - 1:
|
if attempt < _max_retries - 1:
|
||||||
@@ -173,17 +307,63 @@ class TelegramChannel(Channel):
|
|||||||
|
|
||||||
# -- helpers -----------------------------------------------------------
|
# -- helpers -----------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stream_key(chat_id: str, thread_ts: str | None) -> str:
|
||||||
|
return f"{chat_id}:{thread_ts or ''}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_message_id(value: str | None) -> int | None:
|
||||||
|
try:
|
||||||
|
return int(value) if value else None
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _register_stream_message(self, key: str, *, message_id: int, last_text: str, last_edit_at: float) -> None:
|
||||||
|
self._stream_messages.pop(key, None)
|
||||||
|
while len(self._stream_messages) >= MAX_TRACKED_STREAM_MESSAGES:
|
||||||
|
self._stream_messages.pop(next(iter(self._stream_messages)))
|
||||||
|
self._stream_messages[key] = {
|
||||||
|
"message_id": message_id,
|
||||||
|
"last_edit_at": last_edit_at,
|
||||||
|
"last_text": last_text,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_retry_after(exc: Exception) -> bool:
|
||||||
|
return getattr(exc, "retry_after", None) is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _retry_after_seconds(exc: Exception) -> float:
|
||||||
|
value = getattr(exc, "retry_after", 0)
|
||||||
|
if hasattr(value, "total_seconds"):
|
||||||
|
return float(value.total_seconds())
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_not_modified(exc: Exception) -> bool:
|
||||||
|
return "message is not modified" in str(exc).lower()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_message(text: str) -> list[str]:
|
||||||
|
return [text[i : i + TELEGRAM_MAX_MESSAGE_LENGTH] for i in range(0, len(text), TELEGRAM_MAX_MESSAGE_LENGTH)] or [text]
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||||
"""Send a 'Working on it...' reply to the user's message."""
|
"""Send a 'Working on it...' reply and register it as the stream target."""
|
||||||
if not self._application:
|
if not self._application:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
bot = self._application.bot
|
bot = self._application.bot
|
||||||
await bot.send_message(
|
sent = await bot.send_message(
|
||||||
chat_id=int(chat_id),
|
chat_id=int(chat_id),
|
||||||
text="Working on it...",
|
text="Working on it...",
|
||||||
reply_to_message_id=reply_to_message_id,
|
reply_to_message_id=reply_to_message_id,
|
||||||
)
|
)
|
||||||
|
self._register_stream_message(
|
||||||
|
self._stream_key(chat_id, str(reply_to_message_id)),
|
||||||
|
message_id=sent.message_id,
|
||||||
|
last_text="Working on it...",
|
||||||
|
last_edit_at=0.0,
|
||||||
|
)
|
||||||
logger.info("[Telegram] 'Working on it...' reply sent in chat=%s", chat_id)
|
logger.info("[Telegram] 'Working on it...' reply sent in chat=%s", chat_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
||||||
@@ -228,10 +408,90 @@ class TelegramChannel(Channel):
|
|||||||
return True
|
return True
|
||||||
return user_id in self._allowed_users
|
return user_id in self._allowed_users
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _telegram_display_name(user) -> str:
|
||||||
|
full_name = getattr(user, "full_name", None)
|
||||||
|
if isinstance(full_name, str) and full_name:
|
||||||
|
return full_name
|
||||||
|
username = getattr(user, "username", None)
|
||||||
|
if isinstance(username, str) and username:
|
||||||
|
return username
|
||||||
|
return str(getattr(user, "id", ""))
|
||||||
|
|
||||||
|
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
||||||
|
if self._connection_repo is None or not state_token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
||||||
|
if state is None:
|
||||||
|
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
owner_user_id = state["owner_user_id"]
|
||||||
|
user_id = str(update.effective_user.id)
|
||||||
|
chat_id = str(update.effective_chat.id)
|
||||||
|
connection = await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id=user_id,
|
||||||
|
external_account_name=self._telegram_display_name(update.effective_user),
|
||||||
|
workspace_id=chat_id,
|
||||||
|
workspace_name=None,
|
||||||
|
metadata={
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"chat_type": update.effective_chat.type,
|
||||||
|
"telegram_username": getattr(update.effective_user, "username", None),
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
||||||
|
await update.message.reply_text("Telegram connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="telegram",
|
||||||
|
workspace_id=inbound.chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_bot_username(self, context) -> str | None:
|
||||||
|
bot = getattr(context, "bot", None)
|
||||||
|
username = getattr(bot, "username", None)
|
||||||
|
if not username and self._application is not None:
|
||||||
|
username = getattr(getattr(self._application, "bot", None), "username", None)
|
||||||
|
return str(username) if username else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_bot_username_from_leading_command(text: str, bot_username: str | None) -> str:
|
||||||
|
username = (bot_username or "").lstrip("@").lower()
|
||||||
|
if not username or not text.startswith("/"):
|
||||||
|
return text
|
||||||
|
|
||||||
|
parts = text.split(maxsplit=1)
|
||||||
|
command_token = parts[0]
|
||||||
|
if "@" not in command_token:
|
||||||
|
return text
|
||||||
|
|
||||||
|
command_name, addressed_username = command_token[1:].rsplit("@", 1)
|
||||||
|
if not command_name or addressed_username.lower() != username:
|
||||||
|
return text
|
||||||
|
|
||||||
|
normalized = f"/{command_name}"
|
||||||
|
if len(parts) > 1:
|
||||||
|
normalized = f"{normalized} {parts[1]}"
|
||||||
|
return normalized
|
||||||
|
|
||||||
async def _cmd_start(self, update, context) -> None:
|
async def _cmd_start(self, update, context) -> None:
|
||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
|
args = getattr(context, "args", []) if context is not None else []
|
||||||
|
if args:
|
||||||
|
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||||
|
if handled:
|
||||||
|
return
|
||||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||||
|
|
||||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||||
@@ -243,7 +503,7 @@ class TelegramChannel(Channel):
|
|||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
|
|
||||||
text = update.message.text
|
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
|
||||||
chat_id = str(update.effective_chat.id)
|
chat_id = str(update.effective_chat.id)
|
||||||
user_id = str(update.effective_user.id)
|
user_id = str(update.effective_user.id)
|
||||||
msg_id = str(update.message.message_id)
|
msg_id = str(update.message.message_id)
|
||||||
@@ -267,6 +527,7 @@ class TelegramChannel(Channel):
|
|||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||||
@@ -279,7 +540,7 @@ class TelegramChannel(Channel):
|
|||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
|
|
||||||
text = update.message.text.strip()
|
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
|
||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -309,6 +570,7 @@ class TelegramChannel(Channel):
|
|||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ from cryptography.hazmat.primitives import padding
|
|||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -252,6 +254,7 @@ class WechatChannel(Channel):
|
|||||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
self._load_state()
|
self._load_state()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -616,11 +619,21 @@ class WechatChannel(Channel):
|
|||||||
if thread_ts:
|
if thread_ts:
|
||||||
self._context_tokens_by_thread[thread_ts] = context_token
|
self._context_tokens_by_thread[thread_ts] = context_token
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
handled = await self._bind_connection_from_connect_code(
|
||||||
|
chat_id=chat_id,
|
||||||
|
context_token=context_token,
|
||||||
|
code=connect_code,
|
||||||
|
)
|
||||||
|
if handled:
|
||||||
|
return
|
||||||
|
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=chat_id,
|
user_id=chat_id,
|
||||||
text=text,
|
text=text,
|
||||||
msg_type=InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT,
|
msg_type=InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT,
|
||||||
thread_ts=thread_ts,
|
thread_ts=thread_ts,
|
||||||
files=files,
|
files=files,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -631,8 +644,54 @@ class WechatChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = None
|
inbound.topic_id = None
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="wechat",
|
||||||
|
workspace_id=inbound.chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not chat_id:
|
||||||
|
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="wechat",
|
||||||
|
external_account_id=chat_id,
|
||||||
|
workspace_id=chat_id,
|
||||||
|
metadata={
|
||||||
|
"context_token": context_token,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
|
||||||
|
if not context_token:
|
||||||
|
return
|
||||||
|
await self._send_text_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
context_token=context_token,
|
||||||
|
text=text,
|
||||||
|
client_id_prefix="deerflow-connect",
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
|
||||||
async def _ensure_authenticated(self) -> bool:
|
async def _ensure_authenticated(self) -> bool:
|
||||||
async with self._auth_lock:
|
async with self._auth_lock:
|
||||||
if self._bot_token:
|
if self._bot_token:
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ from collections.abc import Awaitable, Callable
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import (
|
from app.channels.message_bus import (
|
||||||
|
InboundMessage,
|
||||||
InboundMessageType,
|
InboundMessageType,
|
||||||
MessageBus,
|
MessageBus,
|
||||||
OutboundMessage,
|
OutboundMessage,
|
||||||
@@ -28,6 +31,7 @@ class WeComChannel(Channel):
|
|||||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||||
self._ws_stream_ids: dict[str, str] = {}
|
self._ws_stream_ids: dict[str, str] = {}
|
||||||
self._working_message = "Working on it..."
|
self._working_message = "Working on it..."
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
@@ -78,12 +82,33 @@ class WeComChannel(Channel):
|
|||||||
self._ws_client.on("message.mixed", self._on_ws_mixed)
|
self._ws_client.on("message.mixed", self._on_ws_mixed)
|
||||||
self._ws_client.on("message.image", self._on_ws_image)
|
self._ws_client.on("message.image", self._on_ws_image)
|
||||||
self._ws_client.on("message.file", self._on_ws_file)
|
self._ws_client.on("message.file", self._on_ws_file)
|
||||||
|
self._ws_client.on("error", self._on_ws_error)
|
||||||
|
self._ws_client.on("disconnected", self._on_ws_disconnected)
|
||||||
self._ws_task = asyncio.create_task(self._ws_client.connect())
|
self._ws_task = asyncio.create_task(self._ws_client.connect())
|
||||||
|
self._ws_task.add_done_callback(self._on_ws_task_done)
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self.bus.subscribe_outbound(self._on_outbound)
|
self.bus.subscribe_outbound(self._on_outbound)
|
||||||
logger.info("WeCom channel started")
|
logger.info("WeCom channel started")
|
||||||
|
|
||||||
|
def _on_ws_task_done(self, task: asyncio.Task) -> None:
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc is None:
|
||||||
|
return
|
||||||
|
logger.error(
|
||||||
|
"WeCom WebSocket connection task failed: %s. Check that the network/proxy allows wss://openws.work.weixin.qq.com and that bot_id/bot_secret are valid.",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_ws_error(self, error: Any) -> None:
|
||||||
|
logger.error("WeCom WebSocket error: %s", error)
|
||||||
|
|
||||||
|
def _on_ws_disconnected(self, *args: Any) -> None:
|
||||||
|
detail = f" ({args[0]})" if args else ""
|
||||||
|
logger.warning("WeCom WebSocket disconnected%s; SDK will attempt to reconnect", detail)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
self.bus.unsubscribe_outbound(self._on_outbound)
|
||||||
@@ -270,7 +295,17 @@ class WeComChannel(Channel):
|
|||||||
|
|
||||||
user_id = (body.get("from") or {}).get("userid")
|
user_id = (body.get("from") or {}).get("userid")
|
||||||
|
|
||||||
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
handled = await self._bind_connection_from_connect_code(
|
||||||
|
frame=frame,
|
||||||
|
user_id=str(user_id or ""),
|
||||||
|
code=connect_code,
|
||||||
|
)
|
||||||
|
if handled:
|
||||||
|
return
|
||||||
|
|
||||||
|
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=user_id, # keep user's conversation in memory
|
chat_id=user_id, # keep user's conversation in memory
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -291,8 +326,52 @@ class WeComChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="wecom",
|
||||||
|
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
||||||
|
fallback_without_workspace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
body = frame.get("body", {}) or {}
|
||||||
|
workspace_id = str(body.get("aibotid") or "") or None
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="wecom",
|
||||||
|
external_account_id=user_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
metadata={
|
||||||
|
"aibotid": workspace_id,
|
||||||
|
"chattype": body.get("chattype"),
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
||||||
|
if not self._ws_client:
|
||||||
|
return
|
||||||
|
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
||||||
|
|
||||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||||
if not self._ws_client:
|
if not self._ws_client:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
from app.gateway.auth_middleware import AuthMiddleware
|
from app.gateway.auth_middleware import AuthMiddleware
|
||||||
from app.gateway.config import get_gateway_config
|
from app.gateway.config import get_gateway_config
|
||||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
||||||
@@ -15,6 +16,7 @@ from app.gateway.routers import (
|
|||||||
artifacts,
|
artifacts,
|
||||||
assistants_compat,
|
assistants_compat,
|
||||||
auth,
|
auth,
|
||||||
|
channel_connections,
|
||||||
channels,
|
channels,
|
||||||
feedback,
|
feedback,
|
||||||
mcp,
|
mcp,
|
||||||
@@ -172,6 +174,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
startup_config = get_app_config()
|
startup_config = get_app_config()
|
||||||
apply_logging_level(startup_config.log_level)
|
apply_logging_level(startup_config.log_level)
|
||||||
logger.info("Configuration loaded successfully")
|
logger.info("Configuration loaded successfully")
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||||
logger.exception(error_msg)
|
logger.exception(error_msg)
|
||||||
@@ -179,6 +182,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
config = get_gateway_config()
|
config = get_gateway_config()
|
||||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||||
|
|
||||||
|
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
||||||
|
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
||||||
|
# that may be unreachable in restricted networks — see issue #3402).
|
||||||
|
# When memory.token_counting is "char", token counting never touches
|
||||||
|
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
|
||||||
|
# network-restricted deployments — see issue #3429).
|
||||||
|
if startup_config.memory.token_counting == "char":
|
||||||
|
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||||
|
|
||||||
|
warmed = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(warm_tiktoken_cache),
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if warmed:
|
||||||
|
logger.info("tiktoken encoding cache warmed successfully")
|
||||||
|
else:
|
||||||
|
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
|
||||||
|
except Exception:
|
||||||
|
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||||
|
|
||||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||||
async with langgraph_runtime(app, startup_config):
|
async with langgraph_runtime(app, startup_config):
|
||||||
logger.info("LangGraph runtime initialised")
|
logger.info("LangGraph runtime initialised")
|
||||||
@@ -357,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
|||||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||||
app.include_router(suggestions.router)
|
app.include_router(suggestions.router)
|
||||||
|
|
||||||
|
# User-facing IM channel connection API is mounted at /api/channels
|
||||||
|
app.include_router(channel_connections.router)
|
||||||
|
|
||||||
# Channels API is mounted at /api/channels
|
# Channels API is mounted at /api/channels
|
||||||
app.include_router(channels.router)
|
app.include_router(channels.router)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Shared helpers for local/E2E auth-disabled mode."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
|
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
|
||||||
|
AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
|
||||||
|
AUTH_DISABLED_USER_EMAIL = "default@test.local"
|
||||||
|
|
||||||
|
AUTH_SOURCE_SESSION = "session"
|
||||||
|
AUTH_SOURCE_INTERNAL = "internal"
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
|
||||||
|
|
||||||
|
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
|
||||||
|
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_explicit_production_environment() -> bool:
|
||||||
|
return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_disabled_requested() -> bool:
|
||||||
|
return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_disabled() -> bool:
|
||||||
|
return is_auth_disabled_requested() and not is_explicit_production_environment()
|
||||||
|
|
||||||
|
|
||||||
|
def warn_if_auth_disabled_enabled() -> None:
|
||||||
|
if not is_auth_disabled():
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
|
||||||
|
AUTH_DISABLED_ENV_VAR,
|
||||||
|
AUTH_DISABLED_USER_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_disabled_user():
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=AUTH_DISABLED_USER_ID,
|
||||||
|
email=AUTH_DISABLED_USER_EMAIL,
|
||||||
|
password_hash=None,
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=False,
|
||||||
|
token_version=0,
|
||||||
|
)
|
||||||
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
|
|||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||||
|
from app.gateway.auth_disabled import (
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED,
|
||||||
|
AUTH_SOURCE_INTERNAL,
|
||||||
|
AUTH_SOURCE_SESSION,
|
||||||
|
get_auth_disabled_user,
|
||||||
|
is_auth_disabled,
|
||||||
|
)
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
||||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
@@ -80,8 +87,38 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
||||||
internal_user = get_internal_user()
|
internal_user = get_internal_user()
|
||||||
|
|
||||||
|
auth_source = AUTH_SOURCE_SESSION
|
||||||
|
access_token = request.cookies.get("access_token")
|
||||||
|
|
||||||
# Non-public path: require session cookie
|
# Non-public path: require session cookie
|
||||||
if internal_user is None and not request.cookies.get("access_token"):
|
if internal_user is not None:
|
||||||
|
user = internal_user
|
||||||
|
auth_source = AUTH_SOURCE_INTERNAL
|
||||||
|
elif access_token:
|
||||||
|
# Strict JWT validation: reject junk/expired tokens with 401
|
||||||
|
# right here instead of silently passing through. This closes
|
||||||
|
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||||
|
# without this, non-isolation routes like /api/models would
|
||||||
|
# accept any cookie-shaped string as authentication.
|
||||||
|
#
|
||||||
|
# We call the *strict* resolver so that fine-grained error
|
||||||
|
# codes (token_expired, token_invalid, user_not_found, …)
|
||||||
|
# propagate from AuthErrorCode, not get flattened into one
|
||||||
|
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
||||||
|
# bubble up, so we catch and render it as JSONResponse here.
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if not is_auth_disabled():
|
||||||
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
|
user = get_auth_disabled_user()
|
||||||
|
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
elif is_auth_disabled():
|
||||||
|
user = get_auth_disabled_user()
|
||||||
|
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
else:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
content={
|
content={
|
||||||
@@ -92,32 +129,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Strict JWT validation: reject junk/expired tokens with 401
|
|
||||||
# right here instead of silently passing through. This closes
|
|
||||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
|
||||||
# without this, non-isolation routes like /api/models would
|
|
||||||
# accept any cookie-shaped string as authentication.
|
|
||||||
#
|
|
||||||
# We call the *strict* resolver so that fine-grained error
|
|
||||||
# codes (token_expired, token_invalid, user_not_found, …)
|
|
||||||
# propagate from AuthErrorCode, not get flattened into one
|
|
||||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
|
||||||
# bubble up, so we catch and render it as JSONResponse here.
|
|
||||||
from app.gateway.deps import get_current_user_from_request
|
|
||||||
|
|
||||||
if internal_user is not None:
|
|
||||||
user = internal_user
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
except HTTPException as exc:
|
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
|
||||||
|
|
||||||
# Stamp both request.state.user (for the contextvar pattern)
|
# Stamp both request.state.user (for the contextvar pattern)
|
||||||
# and request.state.auth (so @require_permission's "auth is
|
# and request.state.auth (so @require_permission's "auth is
|
||||||
# None" branch short-circuits instead of running the entire
|
# None" branch short-circuits instead of running the entire
|
||||||
# JWT-decode + DB-lookup pipeline a second time per request).
|
# JWT-decode + DB-lookup pipeline a second time per request).
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
|
request.state.auth_source = auth_source
|
||||||
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||||
token = set_current_user(user)
|
token = set_current_user(user)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -276,6 +276,8 @@ def require_permission(
|
|||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* user_id triggers 404.
|
# row with a *different* user_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||||
|
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
@@ -288,6 +290,22 @@ def require_permission(
|
|||||||
str(auth.user.id),
|
str(auth.user.id),
|
||||||
require_existing=require_existing,
|
require_existing=require_existing,
|
||||||
)
|
)
|
||||||
|
if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||||
|
# Trusted internal callers (channel workers) also act for
|
||||||
|
# the connection owner carried in X-DeerFlow-Owner-User-Id.
|
||||||
|
# Scope the check to that owner instead of bypassing it; a
|
||||||
|
# leaked internal token must not grant cross-user thread
|
||||||
|
# access. The header is honored only after ``auth`` proved
|
||||||
|
# the caller holds the internal token (mirrors
|
||||||
|
# get_trusted_internal_owner_user_id, which keys off the
|
||||||
|
# middleware-stamped ``request.state.user``).
|
||||||
|
header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip()
|
||||||
|
if header_owner:
|
||||||
|
allowed = await thread_store.check_access(
|
||||||
|
thread_id,
|
||||||
|
header_owner,
|
||||||
|
require_existing=require_existing,
|
||||||
|
)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.gateway.auth_disabled import is_auth_disabled
|
||||||
|
|
||||||
CSRF_COOKIE_NAME = "csrf_token"
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||||
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
|
|||||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return False
|
||||||
|
|
||||||
path = request.url.path.rstrip("/")
|
path = request.url.path.rstrip("/")
|
||||||
# Exempt /api/v1/auth/me endpoint
|
# Exempt /api/v1/auth/me endpoint
|
||||||
if path == "/api/v1/auth/me":
|
if path == "/api/v1/auth/me":
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
@@ -33,6 +34,43 @@ from deerflow.runtime.runs.store.base import RunStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Upper bound (seconds) for draining in-flight runs during shutdown, before the
|
||||||
|
# AsyncExitStack tears down the checkpointer (and its connection pool). Kept
|
||||||
|
# local to avoid an app -> deps -> app import cycle. This is a *separate* budget
|
||||||
|
# from ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s,
|
||||||
|
# which bounds channel-service stop): the two govern independent teardown steps
|
||||||
|
# and may diverge, but both count toward the lifespan shutdown window — revisit
|
||||||
|
# them together if their sum must stay within the server's graceful-shutdown
|
||||||
|
# timeout.
|
||||||
|
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
|
||||||
|
|
||||||
|
|
||||||
|
async def _drain_inflight_runs(run_manager: RunManager) -> None:
|
||||||
|
"""Drain in-flight runs before the checkpointer is torn down (issue #3373).
|
||||||
|
|
||||||
|
Shields the (internally-bounded) drain so that even if the lifespan
|
||||||
|
coroutine is itself cancelled mid-shutdown — a second SIGINT or the server's
|
||||||
|
graceful-shutdown timeout, i.e. the same signal storm behind #3373 — the
|
||||||
|
checkpointer pool is not closed while run tasks are still writing
|
||||||
|
checkpoints. On such a cancellation we let the already-running drain finish
|
||||||
|
(it is bounded by ``RunManager.shutdown``'s own timeout) and then propagate
|
||||||
|
the cancellation.
|
||||||
|
"""
|
||||||
|
drain = asyncio.create_task(run_manager.shutdown(timeout=_RUN_DRAIN_TIMEOUT_SECONDS))
|
||||||
|
try:
|
||||||
|
await asyncio.shield(drain)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Re-shield so this second wait does not abandon the in-flight drain;
|
||||||
|
# it is bounded, so this cannot hang. Then re-raise to honour shutdown.
|
||||||
|
try:
|
||||||
|
await asyncio.shield(drain)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("In-flight run drain failed after shutdown cancellation")
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to drain in-flight runs during shutdown")
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
@@ -81,6 +119,16 @@ def get_config() -> AppConfig:
|
|||||||
split-brain where the worker / lead-agent thread saw a stale startup
|
split-brain where the worker / lead-agent thread saw a stale startup
|
||||||
snapshot.
|
snapshot.
|
||||||
|
|
||||||
|
Hot-reload boundary: fields backed by startup-time singletons
|
||||||
|
(engines, sandbox provider, IM channels, logging handler) require a
|
||||||
|
process restart to change at runtime. The authoritative list lives in
|
||||||
|
:mod:`deerflow.config.reload_boundary` and is mirrored by the
|
||||||
|
standardised ``"startup-only:"`` prefix on the matching
|
||||||
|
``Field(description=...)`` in :class:`AppConfig` — IDE hover on those
|
||||||
|
fields will surface the boundary inline. See
|
||||||
|
``backend/CLAUDE.md`` "Config Hot-Reload Boundary" for the operator
|
||||||
|
summary.
|
||||||
|
|
||||||
Any failure to materialise the config (missing file, permission denied,
|
Any failure to materialise the config (missing file, permission denied,
|
||||||
YAML parse error, validation error) is reported as 503 — semantically
|
YAML parse error, validation error) is reported as 503 — semantically
|
||||||
"the gateway cannot serve requests without a usable configuration" — and
|
"the gateway cannot serve requests without a usable configuration" — and
|
||||||
@@ -177,6 +225,14 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
|
|||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
# Drain in-flight run tasks BEFORE the AsyncExitStack tears down the
|
||||||
|
# checkpointer (and its connection pool). A run still mid-graph would
|
||||||
|
# otherwise leak into asyncio.run() shutdown, where langgraph's
|
||||||
|
# _checkpointer_put_after_previous aput races the closed pool and
|
||||||
|
# raises PoolClosed (issue #3373).
|
||||||
|
run_manager = getattr(app.state, "run_manager", None)
|
||||||
|
if run_manager is not None:
|
||||||
|
await _drain_inflight_runs(run_manager)
|
||||||
await close_engine()
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
@@ -275,6 +331,17 @@ async def get_current_user_from_request(request: Request):
|
|||||||
|
|
||||||
Raises HTTPException 401 if not authenticated.
|
Raises HTTPException 401 if not authenticated.
|
||||||
"""
|
"""
|
||||||
|
state = getattr(request, "state", None)
|
||||||
|
state_user = getattr(state, "user", None)
|
||||||
|
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
|
||||||
|
|
||||||
|
if state_user is not None and getattr(state, "auth_source", None) in {
|
||||||
|
AUTH_SOURCE_SESSION,
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED,
|
||||||
|
AUTH_SOURCE_INTERNAL,
|
||||||
|
}:
|
||||||
|
return state_user
|
||||||
|
|
||||||
from app.gateway.auth import decode_token
|
from app.gateway.auth import decode_token
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,14 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||||
|
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
||||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||||
|
INTERNAL_SYSTEM_ROLE = "internal"
|
||||||
|
|
||||||
|
|
||||||
def _load_internal_auth_token() -> str:
|
def _load_internal_auth_token() -> str:
|
||||||
@@ -22,9 +25,12 @@ def _load_internal_auth_token() -> str:
|
|||||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||||
|
|
||||||
|
|
||||||
def create_internal_auth_headers() -> dict[str, str]:
|
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
||||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||||
|
if owner_user_id:
|
||||||
|
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||||
@@ -34,4 +40,22 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
|||||||
|
|
||||||
def get_internal_user():
|
def get_internal_user():
|
||||||
"""Return the synthetic user used for trusted internal channel calls."""
|
"""Return the synthetic user used for trusted internal channel calls."""
|
||||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
|
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
|
||||||
|
"""Return the owner override for a trusted internal request, if present.
|
||||||
|
|
||||||
|
The header is ignored for normal browser/API callers. It is only honored
|
||||||
|
after ``AuthMiddleware`` has validated the internal auth token and stamped
|
||||||
|
the synthetic internal user onto ``request.state.user``.
|
||||||
|
"""
|
||||||
|
user = getattr(getattr(request, "state", None), "user", None)
|
||||||
|
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
|
||||||
|
if not owner_user_id:
|
||||||
|
return None
|
||||||
|
owner_user_id = owner_user_id.strip()
|
||||||
|
return owner_user_id or None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
|
|||||||
|
|
||||||
from app.gateway.auth.errors import TokenError
|
from app.gateway.auth.errors import TokenError
|
||||||
from app.gateway.auth.jwt import decode_token
|
from app.gateway.auth.jwt import decode_token
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||||
from app.gateway.deps import get_local_provider
|
from app.gateway.deps import get_local_provider
|
||||||
|
|
||||||
auth = Auth()
|
auth = Auth()
|
||||||
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
|
|||||||
if method.upper() not in _CSRF_METHODS:
|
if method.upper() not in _CSRF_METHODS:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return
|
||||||
|
|
||||||
cookie_token = request.cookies.get("csrf_token")
|
cookie_token = request.cookies.get("csrf_token")
|
||||||
header_token = request.headers.get("x-csrf-token")
|
header_token = request.headers.get("x-csrf-token")
|
||||||
|
|
||||||
@@ -66,6 +70,9 @@ async def authenticate(request):
|
|||||||
# are rejected early, even if the cookie carries a valid JWT.
|
# are rejected early, even if the cookie carries a valid JWT.
|
||||||
_check_csrf(request)
|
_check_csrf(request)
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
token = request.cookies.get("access_token")
|
token = request.cookies.get("access_token")
|
||||||
if not token:
|
if not token:
|
||||||
raise Auth.exceptions.HTTPException(
|
raise Auth.exceptions.HTTPException(
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
"""Shared pagination helpers for gateway routers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def trim_run_message_page(rows: list[dict], *, limit: int, after_seq: int | None) -> tuple[list[dict], bool]:
|
||||||
|
"""Trim a ``limit + 1`` run-message page while preserving page boundaries."""
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
if not has_more:
|
||||||
|
return rows, False
|
||||||
|
|
||||||
|
if after_seq is not None:
|
||||||
|
return rows[:limit], True
|
||||||
|
|
||||||
|
return rows[-limit:], True
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
"""CRUD API for custom agents."""
|
"""CRUD API for custom agents."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -213,48 +214,61 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
user_id = get_effective_user_id()
|
user_id = get_effective_user_id()
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
|
|
||||||
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
def _create_agent() -> AgentResponse | None:
|
||||||
legacy_dir = paths.agent_dir(normalized_name)
|
# Worker thread: base-dir resolution, existence checks, directory/file
|
||||||
|
# creation, read-back, and failure cleanup are all blocking filesystem
|
||||||
|
# IO that must stay off the event loop.
|
||||||
|
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
||||||
|
legacy_dir = paths.agent_dir(normalized_name)
|
||||||
|
|
||||||
if agent_dir.exists() or legacy_dir.exists():
|
if legacy_dir.exists():
|
||||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
return None # signals 409 to the caller
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
agent_dir.mkdir(parents=True, exist_ok=False)
|
||||||
|
except FileExistsError:
|
||||||
|
return None # signals 409 to the caller
|
||||||
|
# Write config.yaml
|
||||||
|
config_data: dict = {"name": normalized_name}
|
||||||
|
if request.description:
|
||||||
|
config_data["description"] = request.description
|
||||||
|
if request.model is not None:
|
||||||
|
config_data["model"] = request.model
|
||||||
|
if request.tool_groups is not None:
|
||||||
|
config_data["tool_groups"] = request.tool_groups
|
||||||
|
if request.skills is not None:
|
||||||
|
config_data["skills"] = request.skills
|
||||||
|
|
||||||
|
config_file = agent_dir / "config.yaml"
|
||||||
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
|
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
|
||||||
|
|
||||||
|
# Write SOUL.md
|
||||||
|
soul_file = agent_dir / "SOUL.md"
|
||||||
|
soul_file.write_text(request.soul, encoding="utf-8")
|
||||||
|
|
||||||
|
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
||||||
|
|
||||||
|
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
||||||
|
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
||||||
|
except Exception:
|
||||||
|
# Clean up partial state on failure before surfacing the error.
|
||||||
|
if agent_dir.exists():
|
||||||
|
shutil.rmtree(agent_dir)
|
||||||
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
response = await asyncio.to_thread(_create_agent)
|
||||||
|
|
||||||
# Write config.yaml
|
|
||||||
config_data: dict = {"name": normalized_name}
|
|
||||||
if request.description:
|
|
||||||
config_data["description"] = request.description
|
|
||||||
if request.model is not None:
|
|
||||||
config_data["model"] = request.model
|
|
||||||
if request.tool_groups is not None:
|
|
||||||
config_data["tool_groups"] = request.tool_groups
|
|
||||||
if request.skills is not None:
|
|
||||||
config_data["skills"] = request.skills
|
|
||||||
|
|
||||||
config_file = agent_dir / "config.yaml"
|
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
|
||||||
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
|
|
||||||
|
|
||||||
# Write SOUL.md
|
|
||||||
soul_file = agent_dir / "SOUL.md"
|
|
||||||
soul_file.write_text(request.soul, encoding="utf-8")
|
|
||||||
|
|
||||||
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
|
||||||
|
|
||||||
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
|
||||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clean up on failure
|
|
||||||
if agent_dir.exists():
|
|
||||||
shutil.rmtree(agent_dir)
|
|
||||||
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
|
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/agents/{name}",
|
"/agents/{name}",
|
||||||
@@ -428,19 +442,30 @@ async def delete_agent(name: str) -> None:
|
|||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
user_id = get_effective_user_id()
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
agent_dir = paths.user_agent_dir(user_id, name)
|
|
||||||
|
|
||||||
if not agent_dir.exists():
|
def _remove_agent_dir() -> tuple[str, str]:
|
||||||
if paths.agent_dir(name).exists():
|
# Runs in a worker thread: resolving the base dir, probing the directory
|
||||||
raise HTTPException(
|
# (`exists`), and removing it (`rmtree`) are all blocking filesystem IO
|
||||||
status_code=409,
|
# that must stay off the event loop.
|
||||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
agent_dir = paths.user_agent_dir(user_id, name)
|
||||||
)
|
if not agent_dir.exists():
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
outcome = "legacy" if paths.agent_dir(name).exists() else "missing"
|
||||||
|
return outcome, str(agent_dir)
|
||||||
|
shutil.rmtree(agent_dir)
|
||||||
|
return "deleted", str(agent_dir)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(agent_dir)
|
outcome, agent_dir = await asyncio.to_thread(_remove_agent_dir)
|
||||||
logger.info(f"Deleted agent '{name}' from {agent_dir}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
|
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
|
||||||
|
|
||||||
|
if outcome == "legacy":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
||||||
|
)
|
||||||
|
if outcome == "missing":
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
|
|
||||||
|
logger.info(f"Deleted agent '{name}' from {agent_dir}")
|
||||||
|
|||||||
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
|
|||||||
- Re-issues session cookie with new token_version
|
- Re-issues session cookie with new token_version
|
||||||
"""
|
"""
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||||
|
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
|
||||||
user = await get_current_user_from_request(request)
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
|
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
if user.password_hash is None:
|
if user.password_hash is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,670 @@
|
|||||||
|
"""Browser-facing APIs for user-owned IM channel bindings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request, Response
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.channels.runtime_config_store import (
|
||||||
|
ChannelRuntimeConfigStore,
|
||||||
|
apply_runtime_connection_config,
|
||||||
|
merge_runtime_channel_configs,
|
||||||
|
)
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
|
from deerflow.persistence.engine import get_session_factory
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STATE_TTL_SECONDS = 600
|
||||||
|
_MASKED_CREDENTIAL_VALUE = "********"
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelCredentialFieldResponse(BaseModel):
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
type: str = "text"
|
||||||
|
required: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelProviderResponse(BaseModel):
|
||||||
|
provider: str
|
||||||
|
display_name: str
|
||||||
|
enabled: bool
|
||||||
|
configured: bool
|
||||||
|
connectable: bool
|
||||||
|
unavailable_reason: str | None = None
|
||||||
|
auth_mode: str
|
||||||
|
connection_status: str
|
||||||
|
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
|
||||||
|
credential_values: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelProvidersResponse(BaseModel):
|
||||||
|
enabled: bool
|
||||||
|
providers: list[ChannelProviderResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
provider: str
|
||||||
|
status: str
|
||||||
|
external_account_id: str | None = None
|
||||||
|
external_account_name: str | None = None
|
||||||
|
workspace_id: str | None = None
|
||||||
|
workspace_name: str | None = None
|
||||||
|
scopes: list[str] = Field(default_factory=list)
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionsResponse(BaseModel):
|
||||||
|
connections: list[ChannelConnectionResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectResponse(BaseModel):
|
||||||
|
provider: str
|
||||||
|
mode: str
|
||||||
|
url: str | None = None
|
||||||
|
code: str
|
||||||
|
instruction: str
|
||||||
|
expires_in: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelRuntimeConfigRequest(BaseModel):
|
||||||
|
values: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
_PROVIDER_META: dict[str, dict[str, str]] = {
|
||||||
|
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
|
||||||
|
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
|
||||||
|
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
|
||||||
|
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
|
||||||
|
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
|
||||||
|
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
|
||||||
|
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
|
||||||
|
"telegram": (
|
||||||
|
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||||
|
{"name": "bot_username", "label": "Bot username", "type": "text"},
|
||||||
|
),
|
||||||
|
"slack": (
|
||||||
|
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||||
|
{"name": "app_token", "label": "App token", "type": "password"},
|
||||||
|
),
|
||||||
|
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||||
|
"feishu": (
|
||||||
|
{"name": "app_id", "label": "App ID", "type": "text"},
|
||||||
|
{"name": "app_secret", "label": "App secret", "type": "password"},
|
||||||
|
),
|
||||||
|
"dingtalk": (
|
||||||
|
{"name": "client_id", "label": "Client ID", "type": "text"},
|
||||||
|
{"name": "client_secret", "label": "Client secret", "type": "password"},
|
||||||
|
),
|
||||||
|
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||||
|
"wecom": (
|
||||||
|
{"name": "bot_id", "label": "Bot ID", "type": "text"},
|
||||||
|
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
|
||||||
|
"telegram": ("bot_token",),
|
||||||
|
"slack": ("bot_token", "app_token"),
|
||||||
|
"discord": ("bot_token",),
|
||||||
|
"feishu": ("app_id", "app_secret"),
|
||||||
|
"dingtalk": ("client_id", "client_secret"),
|
||||||
|
"wechat": ("bot_token",),
|
||||||
|
"wecom": ("bot_id", "bot_secret"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id(request: Request) -> str:
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
return str(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _require_admin_user(request: Request) -> None:
|
||||||
|
"""Require an admin caller for instance-wide channel runtime mutations.
|
||||||
|
|
||||||
|
Runtime credentials and the channel workers they start/stop are shared by
|
||||||
|
every user of the deployment, so only admins may change them (same model
|
||||||
|
as the MCP config API). Auth-disabled local mode uses a synthetic admin
|
||||||
|
user and is unaffected.
|
||||||
|
"""
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
if user is None:
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
|
if getattr(user, "system_role", None) != "admin":
|
||||||
|
raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_app_config():
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
|
return get_app_config()
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
|
||||||
|
store = getattr(request.app.state, "channel_runtime_config_store", None)
|
||||||
|
if isinstance(store, ChannelRuntimeConfigStore):
|
||||||
|
return store
|
||||||
|
# Constructing the store reads its JSON file from disk; keep it off the
|
||||||
|
# event loop.
|
||||||
|
store = await asyncio.to_thread(ChannelRuntimeConfigStore)
|
||||||
|
request.app.state.channel_runtime_config_store = store
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||||
|
config = getattr(request.app.state, "channel_connections_config", None)
|
||||||
|
if not isinstance(config, ChannelConnectionsConfig):
|
||||||
|
config = _get_app_config().channel_connections
|
||||||
|
config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request))
|
||||||
|
request.app.state.channel_connections_config = config
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||||
|
state_config = getattr(request.app.state, "channels_config", None)
|
||||||
|
if isinstance(state_config, dict):
|
||||||
|
return state_config
|
||||||
|
|
||||||
|
result = await _load_channels_config(request, await _get_channel_connections_config(request))
|
||||||
|
request.app.state.channels_config = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
|
||||||
|
app_config = _get_app_config()
|
||||||
|
extra = app_config.model_extra or {}
|
||||||
|
channels_config = extra.get("channels")
|
||||||
|
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||||
|
merge_runtime_channel_configs(
|
||||||
|
result,
|
||||||
|
config,
|
||||||
|
store=await _get_runtime_config_store(request),
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||||
|
repo = getattr(request.app.state, "channel_connection_repo", None)
|
||||||
|
if isinstance(repo, ChannelConnectionRepository):
|
||||||
|
return repo
|
||||||
|
|
||||||
|
sf = get_session_factory()
|
||||||
|
if sf is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||||
|
|
||||||
|
repo = ChannelConnectionRepository(sf)
|
||||||
|
request.app.state.channel_connection_repo = repo
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||||
|
provider_config = getattr(config, provider, None)
|
||||||
|
if provider_config is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
return provider_config
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||||
|
return False
|
||||||
|
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_unavailable_reason(provider: str) -> str:
|
||||||
|
meta = _PROVIDER_META.get(provider)
|
||||||
|
display_name = meta["display_name"] if meta else provider
|
||||||
|
return f"Enter the required {display_name} credentials to connect this channel."
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_not_running_reason(provider: str) -> str:
|
||||||
|
meta = _PROVIDER_META.get(provider)
|
||||||
|
display_name = meta["display_name"] if meta else provider
|
||||||
|
return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_channel_running(provider: str) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to inspect channel service status", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
status = service.get_status()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to read channel service status", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not status.get("service_running"):
|
||||||
|
return False
|
||||||
|
channel_status = status.get("channels", {}).get(provider)
|
||||||
|
if not isinstance(channel_status, dict):
|
||||||
|
return None
|
||||||
|
return bool(channel_status.get("running"))
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_runtime_channel_ready_if_available(
|
||||||
|
provider: str,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
) -> bool | None:
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
|
||||||
|
if ensure_channel_ready is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await ensure_channel_ready(provider, runtime_config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to reconcile runtime channel readiness")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_unavailable_reason(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
provider: str,
|
||||||
|
) -> str | None:
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
return None
|
||||||
|
if not provider_config.configured:
|
||||||
|
return _runtime_unavailable_reason(provider)
|
||||||
|
if not _runtime_channel_configured(provider, channels_config):
|
||||||
|
return _runtime_unavailable_reason(provider)
|
||||||
|
if _runtime_channel_running(provider) is False:
|
||||||
|
return _runtime_not_running_reason(provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_status(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
provider: str,
|
||||||
|
) -> tuple[dict[str, bool], str | None]:
|
||||||
|
declared = config.provider_status(provider)
|
||||||
|
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
|
||||||
|
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
|
||||||
|
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
|
||||||
|
|
||||||
|
|
||||||
|
def _new_binding_code() -> str:
|
||||||
|
return secrets.token_urlsafe(16)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_state(
|
||||||
|
repo: ChannelConnectionRepository,
|
||||||
|
*,
|
||||||
|
owner_user_id: str,
|
||||||
|
provider: str,
|
||||||
|
) -> str:
|
||||||
|
state = _new_binding_code()
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider=provider,
|
||||||
|
state=state,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def _connect_instruction(provider: str, code: str) -> str:
|
||||||
|
if provider == "telegram":
|
||||||
|
return f"Send /start {code} to the DeerFlow Telegram bot."
|
||||||
|
meta = _PROVIDER_META.get(provider)
|
||||||
|
if meta is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
|
||||||
|
|
||||||
|
|
||||||
|
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
|
||||||
|
if provider == "telegram":
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
return f"https://t.me/{provider_config.bot_username}?start={code}"
|
||||||
|
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
|
||||||
|
return None
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
|
||||||
|
|
||||||
|
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
||||||
|
value = connection.get("updated_at")
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return datetime.min.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
|
||||||
|
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||||
|
by_provider: dict[str, dict[str, Any]] = {}
|
||||||
|
for item in connections:
|
||||||
|
existing = by_provider.get(item["provider"])
|
||||||
|
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
|
||||||
|
by_provider[item["provider"]] = item
|
||||||
|
return by_provider
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
|
||||||
|
fields = _CREDENTIAL_FIELDS.get(provider)
|
||||||
|
if fields is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
return [ChannelCredentialFieldResponse(**field) for field in fields]
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if not isinstance(runtime_config, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
values: dict[str, str] = {}
|
||||||
|
for field in _credential_fields(provider):
|
||||||
|
value = str(runtime_config.get(field.name) or "").strip()
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_response(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
provider: str,
|
||||||
|
meta: dict[str, str],
|
||||||
|
connection: dict[str, Any] | None = None,
|
||||||
|
) -> ChannelProviderResponse:
|
||||||
|
from app.gateway.auth_disabled import is_auth_disabled
|
||||||
|
|
||||||
|
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||||
|
if connection:
|
||||||
|
connection_status = connection["status"]
|
||||||
|
elif is_auth_disabled() and status["configured"] and unavailable_reason is None:
|
||||||
|
# Auth-disabled local mode routes every channel message to the default
|
||||||
|
# user, so a configured running channel needs no per-user binding.
|
||||||
|
connection_status = "connected"
|
||||||
|
else:
|
||||||
|
connection_status = "not_connected"
|
||||||
|
credential_values = _credential_values(provider, channels_config)
|
||||||
|
if provider == "telegram" and not credential_values.get("bot_username"):
|
||||||
|
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
|
||||||
|
if bot_username:
|
||||||
|
credential_values["bot_username"] = bot_username
|
||||||
|
return ChannelProviderResponse(
|
||||||
|
provider=provider,
|
||||||
|
display_name=meta["display_name"],
|
||||||
|
enabled=status["enabled"],
|
||||||
|
configured=status["configured"],
|
||||||
|
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
|
||||||
|
unavailable_reason=unavailable_reason,
|
||||||
|
auth_mode=meta["auth_mode"],
|
||||||
|
connection_status=connection_status,
|
||||||
|
credential_fields=_credential_fields(provider),
|
||||||
|
credential_values=credential_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _required_runtime_values(
|
||||||
|
provider: str,
|
||||||
|
values: dict[str, str],
|
||||||
|
existing_config: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
fields = _credential_fields(provider)
|
||||||
|
cleaned: dict[str, str] = {}
|
||||||
|
missing: list[str] = []
|
||||||
|
existing_config = existing_config or {}
|
||||||
|
for field in fields:
|
||||||
|
raw_value = values.get(field.name, "")
|
||||||
|
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
|
||||||
|
existing_value = str(existing_config.get(field.name) or "").strip()
|
||||||
|
if existing_value:
|
||||||
|
cleaned[field.name] = existing_value
|
||||||
|
continue
|
||||||
|
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
|
||||||
|
if field.required and not value:
|
||||||
|
missing.append(field.label)
|
||||||
|
cleaned[field.name] = value
|
||||||
|
if missing:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel service while configuring a runtime channel")
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
return await service.configure_channel(provider, runtime_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel service while disconnecting a runtime channel")
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
|
||||||
|
return await service.configure_channel(provider, runtime_config)
|
||||||
|
return await service.remove_channel(provider)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||||
|
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||||
|
config = await _get_channel_connections_config(request)
|
||||||
|
channels_config = await _get_channels_config(request)
|
||||||
|
repo = None
|
||||||
|
if config.enabled:
|
||||||
|
try:
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code != 503:
|
||||||
|
raise
|
||||||
|
owner_user_id = _get_user_id(request)
|
||||||
|
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||||
|
by_provider = _newest_connection_by_provider(connections)
|
||||||
|
|
||||||
|
enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]]
|
||||||
|
# Readiness reconciliation is independent per provider; run it
|
||||||
|
# concurrently so one slow channel restart does not serialize the
|
||||||
|
# whole /providers response.
|
||||||
|
await asyncio.gather(
|
||||||
|
*(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)),
|
||||||
|
)
|
||||||
|
|
||||||
|
providers: list[ChannelProviderResponse] = []
|
||||||
|
for provider in enabled_providers:
|
||||||
|
connection = by_provider.get(provider)
|
||||||
|
providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection))
|
||||||
|
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connections", response_model=ChannelConnectionsResponse)
|
||||||
|
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
|
||||||
|
config = await _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
return ChannelConnectionsResponse(connections=[])
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
rows = await repo.list_connections(_get_user_id(request))
|
||||||
|
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/connections/{connection_id}", status_code=204)
|
||||||
|
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
|
||||||
|
config = await _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
disconnected = await repo.disconnect_connection(
|
||||||
|
connection_id=connection_id,
|
||||||
|
owner_user_id=_get_user_id(request),
|
||||||
|
)
|
||||||
|
if not disconnected:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel connection not found")
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||||
|
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
|
||||||
|
await _require_admin_user(request)
|
||||||
|
config = await _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
|
||||||
|
owner_user_id = _get_user_id(request)
|
||||||
|
try:
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code != 503:
|
||||||
|
raise
|
||||||
|
repo = None
|
||||||
|
|
||||||
|
if repo is not None:
|
||||||
|
for connection in await repo.list_connections(owner_user_id):
|
||||||
|
if connection["provider"] == provider and connection["status"] != "revoked":
|
||||||
|
await repo.disconnect_connection(
|
||||||
|
connection_id=connection["id"],
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
store = await _get_runtime_config_store(request)
|
||||||
|
await asyncio.to_thread(store.set_provider_disconnected, provider)
|
||||||
|
channels_config = await _load_channels_config(request, config)
|
||||||
|
request.app.state.channels_config = channels_config
|
||||||
|
|
||||||
|
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
|
||||||
|
if stopped is False:
|
||||||
|
display_name = _PROVIDER_META[provider]["display_name"]
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
|
||||||
|
|
||||||
|
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||||
|
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||||
|
config = await _get_channel_connections_config(request)
|
||||||
|
channels_config = await _get_channels_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
|
||||||
|
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||||
|
|
||||||
|
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||||
|
if not status["enabled"]:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
if unavailable_reason:
|
||||||
|
raise HTTPException(status_code=400, detail=unavailable_reason)
|
||||||
|
if not status["configured"]:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||||
|
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
code = await _create_state(
|
||||||
|
repo,
|
||||||
|
owner_user_id=_get_user_id(request),
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
return ChannelConnectResponse(
|
||||||
|
provider=provider,
|
||||||
|
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||||
|
url=_connect_url(config, provider, code),
|
||||||
|
code=code,
|
||||||
|
instruction=_connect_instruction(provider, code),
|
||||||
|
expires_in=_STATE_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||||
|
async def configure_channel_provider_runtime(
|
||||||
|
provider: str,
|
||||||
|
body: ChannelRuntimeConfigRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> ChannelProviderResponse:
|
||||||
|
await _require_admin_user(request)
|
||||||
|
config = await _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
|
||||||
|
channels_config = await _get_channels_config(request)
|
||||||
|
existing = channels_config.get(provider)
|
||||||
|
runtime_config = dict(existing) if isinstance(existing, dict) else {}
|
||||||
|
values = _required_runtime_values(provider, body.values, runtime_config)
|
||||||
|
runtime_config["enabled"] = True
|
||||||
|
|
||||||
|
for key in _RUNTIME_REQUIREMENTS[provider]:
|
||||||
|
runtime_config[key] = values[key]
|
||||||
|
|
||||||
|
if provider == "telegram":
|
||||||
|
# The deep-link username is persisted with the runtime channel config
|
||||||
|
# (set_provider_config below) and applied to future requests via
|
||||||
|
# apply_runtime_connection_config; never mutate the config instance
|
||||||
|
# cached by get_app_config().
|
||||||
|
runtime_config["bot_username"] = values["bot_username"]
|
||||||
|
|
||||||
|
channels_config[provider] = runtime_config
|
||||||
|
request.app.state.channels_config = channels_config
|
||||||
|
|
||||||
|
started = await _restart_runtime_channel_if_available(provider, runtime_config)
|
||||||
|
if started is False:
|
||||||
|
display_name = _PROVIDER_META[provider]["display_name"]
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
||||||
|
|
||||||
|
store = await _get_runtime_config_store(request)
|
||||||
|
await asyncio.to_thread(store.set_provider_config, provider, runtime_config)
|
||||||
|
|
||||||
|
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request, status
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
|
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
|
||||||
@@ -12,6 +13,11 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(prefix="/api", tags=["mcp"])
|
router = APIRouter(prefix="/api", tags=["mcp"])
|
||||||
|
|
||||||
|
|
||||||
|
_MCP_STDIO_COMMAND_ALLOWLIST_ENV = "DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST"
|
||||||
|
_DEFAULT_MCP_STDIO_COMMAND_ALLOWLIST = frozenset({"npx", "uvx"})
|
||||||
|
_SHELL_METACHARS = frozenset(";|&`$<>\n\r")
|
||||||
|
|
||||||
|
|
||||||
class McpOAuthConfigResponse(BaseModel):
|
class McpOAuthConfigResponse(BaseModel):
|
||||||
"""OAuth configuration for an MCP server."""
|
"""OAuth configuration for an MCP server."""
|
||||||
|
|
||||||
@@ -66,6 +72,78 @@ class McpConfigUpdateRequest(BaseModel):
|
|||||||
_MASKED_VALUE = "***"
|
_MASKED_VALUE = "***"
|
||||||
|
|
||||||
|
|
||||||
|
async def _require_admin_user(request: Request) -> None:
|
||||||
|
"""Require the authenticated caller to be an admin user.
|
||||||
|
|
||||||
|
``AuthMiddleware`` normally stamps ``request.state.user`` before the
|
||||||
|
request reaches this router. Falling back to the strict dependency keeps
|
||||||
|
this route safe even in tests or alternative ASGI compositions that mount
|
||||||
|
the router without the global middleware.
|
||||||
|
"""
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
if user is None:
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
|
if getattr(user, "system_role", None) != "admin":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin privileges required to manage MCP configuration.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _allowed_stdio_commands() -> set[str]:
|
||||||
|
"""Return executable names allowed for API-managed stdio MCP servers."""
|
||||||
|
raw = os.environ.get(_MCP_STDIO_COMMAND_ALLOWLIST_ENV)
|
||||||
|
base = set(_DEFAULT_MCP_STDIO_COMMAND_ALLOWLIST)
|
||||||
|
if raw is None:
|
||||||
|
return base
|
||||||
|
extra = {item.strip() for item in raw.split(",") if item.strip()}
|
||||||
|
return base | extra
|
||||||
|
|
||||||
|
|
||||||
|
def _stdio_command_name(command: str | None, *, server_name: str) -> str:
|
||||||
|
"""Normalize and validate a stdio command field from the API boundary."""
|
||||||
|
if command is None or not command.strip():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"MCP server '{server_name}' with stdio transport requires a command.",
|
||||||
|
)
|
||||||
|
|
||||||
|
stripped = command.strip()
|
||||||
|
has_path_separator = "/" in stripped or "\\" in stripped
|
||||||
|
if stripped != command or has_path_separator or any(ch.isspace() for ch in stripped) or any(ch in stripped for ch in _SHELL_METACHARS):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=(f"MCP server '{server_name}' command must be a single executable name; put parameters in args instead."),
|
||||||
|
)
|
||||||
|
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_mcp_update_request(request: McpConfigUpdateRequest) -> None:
|
||||||
|
"""Validate API-submitted MCP config before it is persisted.
|
||||||
|
|
||||||
|
Local config files can still express arbitrary advanced setups, but the
|
||||||
|
HTTP API is an untrusted boundary. Restricting stdio commands here reduces
|
||||||
|
the blast radius of a compromised authenticated browser session.
|
||||||
|
"""
|
||||||
|
allowed_commands = _allowed_stdio_commands()
|
||||||
|
for name, server in request.mcp_servers.items():
|
||||||
|
transport_type = (server.type or "stdio").lower()
|
||||||
|
if transport_type != "stdio":
|
||||||
|
continue
|
||||||
|
|
||||||
|
command_name = _stdio_command_name(server.command, server_name=name)
|
||||||
|
if command_name not in allowed_commands:
|
||||||
|
allowed = ", ".join(sorted(allowed_commands)) or "<none>"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=(f"MCP server '{name}' uses disallowed stdio command '{command_name}'. Allowed commands: {allowed}. Configure {_MCP_STDIO_COMMAND_ALLOWLIST_ENV} to extend this list."),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _mask_server_config(server: McpServerConfigResponse) -> McpServerConfigResponse:
|
def _mask_server_config(server: McpServerConfigResponse) -> McpServerConfigResponse:
|
||||||
"""Return a copy of server config with sensitive fields masked.
|
"""Return a copy of server config with sensitive fields masked.
|
||||||
|
|
||||||
@@ -162,7 +240,7 @@ def _merge_preserving_secrets(
|
|||||||
summary="Get MCP Configuration",
|
summary="Get MCP Configuration",
|
||||||
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
|
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
|
||||||
)
|
)
|
||||||
async def get_mcp_configuration() -> McpConfigResponse:
|
async def get_mcp_configuration(request: Request) -> McpConfigResponse:
|
||||||
"""Get the current MCP configuration.
|
"""Get the current MCP configuration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -183,6 +261,8 @@ async def get_mcp_configuration() -> McpConfigResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
await _require_admin_user(request)
|
||||||
|
|
||||||
config = get_extensions_config()
|
config = get_extensions_config()
|
||||||
|
|
||||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()}
|
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()}
|
||||||
@@ -195,7 +275,7 @@ async def get_mcp_configuration() -> McpConfigResponse:
|
|||||||
summary="Update MCP Configuration",
|
summary="Update MCP Configuration",
|
||||||
description="Update Model Context Protocol (MCP) server configurations and save to file.",
|
description="Update Model Context Protocol (MCP) server configurations and save to file.",
|
||||||
)
|
)
|
||||||
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
|
async def update_mcp_configuration(request: Request, body: McpConfigUpdateRequest) -> McpConfigResponse:
|
||||||
"""Update the MCP configuration.
|
"""Update the MCP configuration.
|
||||||
|
|
||||||
This will:
|
This will:
|
||||||
@@ -228,6 +308,9 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
await _require_admin_user(request)
|
||||||
|
_validate_mcp_update_request(body)
|
||||||
|
|
||||||
# Get the current config path (or determine where to save it)
|
# Get the current config path (or determine where to save it)
|
||||||
config_path = ExtensionsConfig.resolve_config_path()
|
config_path = ExtensionsConfig.resolve_config_path()
|
||||||
|
|
||||||
@@ -255,7 +338,7 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
|||||||
|
|
||||||
# Merge incoming server configs with raw on-disk secrets
|
# Merge incoming server configs with raw on-disk secrets
|
||||||
merged_servers: dict[str, McpServerConfigResponse] = {}
|
merged_servers: dict[str, McpServerConfigResponse] = {}
|
||||||
for name, incoming in request.mcp_servers.items():
|
for name, incoming in body.mcp_servers.items():
|
||||||
raw_server = raw_servers.get(name)
|
raw_server = raw_servers.get(name)
|
||||||
if raw_server is not None:
|
if raw_server is not None:
|
||||||
merged_servers[name] = _merge_preserving_secrets(
|
merged_servers[name] = _merge_preserving_secrets(
|
||||||
@@ -283,6 +366,8 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
|||||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
|
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
|
||||||
return McpConfigResponse(mcp_servers=servers)
|
return McpConfigResponse(mcp_servers=servers)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
|
|||||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||||
|
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
|
||||||
|
|
||||||
|
|
||||||
class MemoryStatusResponse(BaseModel):
|
class MemoryStatusResponse(BaseModel):
|
||||||
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
"max_facts": 100,
|
"max_facts": 100,
|
||||||
"fact_confidence_threshold": 0.7,
|
"fact_confidence_threshold": 0.7,
|
||||||
"injection_enabled": true,
|
"injection_enabled": true,
|
||||||
"max_injection_tokens": 2000
|
"max_injection_tokens": 2000,
|
||||||
|
"token_counting": "tiktoken"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
|
token_counting=config.token_counting,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
|
token_counting=config.token_counting,
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,9 +15,10 @@ from fastapi.responses import StreamingResponse
|
|||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
|
from app.gateway.pagination import trim_run_message_page
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values_for_api
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
||||||
@@ -81,7 +82,7 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|||||||
if checkpoint_tuple is not None:
|
if checkpoint_tuple is not None:
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
return serialize_channel_values(channel_values)
|
return serialize_channel_values_for_api(channel_values)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||||
|
|
||||||
@@ -129,8 +130,7 @@ async def run_messages(
|
|||||||
before_seq=before_seq,
|
before_seq=before_seq,
|
||||||
after_seq=after_seq,
|
after_seq=after_seq,
|
||||||
)
|
)
|
||||||
has_more = len(rows) > limit
|
data, has_more = trim_run_message_page(rows, limit=limit, after_seq=after_seq)
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
return {"data": data, "has_more": has_more}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
@@ -30,6 +31,31 @@ class SuggestionsResponse(BaseModel):
|
|||||||
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||||
|
|
||||||
|
|
||||||
|
# Matches a complete <think>...</think> block (case-insensitive, spans newlines).
|
||||||
|
_THINK_BLOCK_RE = re.compile(r"<think\b[^>]*>.*?</think\s*>", re.IGNORECASE | re.DOTALL)
|
||||||
|
# Matches a dangling, unclosed <think> (model truncated at max_tokens mid-thought).
|
||||||
|
_OPEN_THINK_RE = re.compile(r"<think\b[^>]*>", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_think_blocks(text: str) -> str:
|
||||||
|
"""Remove reasoning-model ``<think>...</think>`` blocks from the response.
|
||||||
|
|
||||||
|
Reasoning models such as MiniMax-M3 inline their chain-of-thought into the
|
||||||
|
message ``content`` wrapped in ``<think>...</think>`` (``reasoning_split``
|
||||||
|
defaults to false), rather than exposing a separate ``reasoning_content``
|
||||||
|
field. The thinking text frequently contains ``[`` / ``]`` characters, which
|
||||||
|
corrupted the downstream ``find('[')`` / ``rfind(']')`` JSON extraction and
|
||||||
|
produced empty suggestions. We strip the reasoning before parsing so only
|
||||||
|
the actual answer remains.
|
||||||
|
"""
|
||||||
|
text = _THINK_BLOCK_RE.sub("", text)
|
||||||
|
# Drop any unclosed <think> (and everything after it) left by truncation.
|
||||||
|
open_match = _OPEN_THINK_RE.search(text)
|
||||||
|
if open_match:
|
||||||
|
text = text[: open_match.start()]
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def _strip_markdown_code_fence(text: str) -> str:
|
def _strip_markdown_code_fence(text: str) -> str:
|
||||||
stripped = text.strip()
|
stripped = text.strip()
|
||||||
if not stripped.startswith("```"):
|
if not stripped.startswith("```"):
|
||||||
@@ -41,7 +67,8 @@ def _strip_markdown_code_fence(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_json_string_list(text: str) -> list[str] | None:
|
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||||
candidate = _strip_markdown_code_fence(text)
|
candidate = _strip_think_blocks(text)
|
||||||
|
candidate = _strip_markdown_code_fence(candidate)
|
||||||
start = candidate.find("[")
|
start = candidate.find("[")
|
||||||
end = candidate.rfind("]")
|
end = candidate.rfind("]")
|
||||||
if start == -1 or end == -1 or end <= start:
|
if start == -1 or end == -1 or end <= start:
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
|
from app.gateway.pagination import trim_run_message_page
|
||||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values_for_api
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
||||||
@@ -191,7 +192,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
|
|||||||
if checkpoint_tuple is not None:
|
if checkpoint_tuple is not None:
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
return serialize_channel_values(channel_values)
|
return serialize_channel_values_for_api(channel_values)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||||
|
|
||||||
@@ -402,8 +403,7 @@ async def list_run_messages(
|
|||||||
before_seq=before_seq,
|
before_seq=before_seq,
|
||||||
after_seq=after_seq,
|
after_seq=after_seq,
|
||||||
)
|
)
|
||||||
has_more = len(rows) > limit
|
data, has_more = trim_run_message_page(rows, limit=limit, after_seq=after_seq)
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
return {"data": data, "has_more": has_more}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,14 +17,15 @@ import uuid
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
from langgraph.checkpoint.base import empty_checkpoint, uuid6
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer
|
from app.gateway.deps import get_checkpointer
|
||||||
|
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values_for_api
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.utils.time import coerce_iso, now_iso
|
from deerflow.utils.time import coerce_iso, now_iso
|
||||||
|
|
||||||
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
thread_store = get_thread_store(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = now_iso()
|
now = now_iso()
|
||||||
|
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||||
|
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_store.get(thread_id)
|
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||||
|
if existing_record is None and thread_owner_user_id:
|
||||||
|
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
||||||
|
if unscoped_record is not None:
|
||||||
|
if unscoped_record.get("user_id") != thread_owner_user_id:
|
||||||
|
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
||||||
|
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
await thread_store.create(
|
await thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
|
**thread_owner_kwargs,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -427,7 +437,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
created_at=coerce_iso(record.get("created_at", "")),
|
created_at=coerce_iso(record.get("created_at", "")),
|
||||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||||
metadata=record.get("metadata", {}),
|
metadata=record.get("metadata", {}),
|
||||||
values=serialize_channel_values(channel_values),
|
values=serialize_channel_values_for_api(channel_values),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -470,7 +480,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
|||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||||
|
|
||||||
values = serialize_channel_values(channel_values)
|
values = serialize_channel_values_for_api(channel_values)
|
||||||
|
|
||||||
return ThreadStateResponse(
|
return ThreadStateResponse(
|
||||||
values=values,
|
values=values,
|
||||||
@@ -536,9 +546,21 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
metadata["step"] = metadata.get("step", 0) + 1
|
metadata["step"] = metadata.get("step", 0) + 1
|
||||||
metadata["writes"] = {body.as_node: body.values}
|
metadata["writes"] = {body.as_node: body.values}
|
||||||
|
|
||||||
|
# Assign a new checkpoint ID so aput performs an INSERT rather than an
|
||||||
|
# in-place REPLACE of the existing row. Use uuid6 (time-ordered) rather
|
||||||
|
# than uuid4 (random) so the new ID is always lexicographically greater
|
||||||
|
# than the previous one — LangGraph's checkpointers determine the "latest"
|
||||||
|
# checkpoint by max(checkpoint_ids) string order, matching the uuid6 epoch.
|
||||||
|
checkpoint["id"] = str(uuid6())
|
||||||
|
|
||||||
# aput requires checkpoint_ns in the config — use the same config used for the
|
# aput requires checkpoint_ns in the config — use the same config used for the
|
||||||
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
|
# read (which always includes checkpoint_ns=""). The fresh checkpoint ID is
|
||||||
# so that aput generates a fresh checkpoint ID for the new snapshot.
|
# assigned above via checkpoint["id"]; keep checkpoint_id out of the config so
|
||||||
|
# the write is keyed by the new checkpoint payload rather than the prior read.
|
||||||
|
# All supported savers (InMemorySaver, AsyncSqliteSaver, AsyncPostgresSaver)
|
||||||
|
# persist and echo back checkpoint["id"] verbatim — none mint their own — so
|
||||||
|
# the new_config below carries the uuid6 we assigned here. (Regression-locked
|
||||||
|
# by test_update_thread_state_inserts_new_checkpoint_each_call.)
|
||||||
write_config: dict[str, Any] = {
|
write_config: dict[str, Any] = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
@@ -557,7 +579,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
|
|
||||||
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
|
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
|
||||||
# reflects them immediately in both sqlite and memory backends.
|
# reflects them immediately in both sqlite and memory backends.
|
||||||
if body.values and "title" in body.values:
|
if thread_store and body.values and "title" in body.values:
|
||||||
new_title = body.values["title"]
|
new_title = body.values["title"]
|
||||||
if new_title: # Skip empty strings and None
|
if new_title: # Skip empty strings and None
|
||||||
try:
|
try:
|
||||||
@@ -566,7 +588,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
return ThreadStateResponse(
|
return ThreadStateResponse(
|
||||||
values=serialize_channel_values(channel_values),
|
values=serialize_channel_values_for_api(channel_values),
|
||||||
next=[],
|
next=[],
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
checkpoint_id=new_checkpoint_id,
|
checkpoint_id=new_checkpoint_id,
|
||||||
@@ -618,7 +640,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
if is_latest_checkpoint:
|
if is_latest_checkpoint:
|
||||||
messages = channel_values.get("messages")
|
messages = channel_values.get("messages")
|
||||||
if messages:
|
if messages:
|
||||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
values["messages"] = serialize_channel_values_for_api({"messages": messages}).get("messages", [])
|
||||||
is_latest_checkpoint = False
|
is_latest_checkpoint = False
|
||||||
|
|
||||||
# Derive next tasks
|
# Derive next tasks
|
||||||
|
|||||||
@@ -39,15 +39,39 @@ DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
|
|||||||
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
|
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class UploadedFileInfo(BaseModel):
|
||||||
|
"""Uploaded file metadata exposed by upload and list APIs."""
|
||||||
|
|
||||||
|
filename: str
|
||||||
|
size: int
|
||||||
|
path: str
|
||||||
|
virtual_path: str
|
||||||
|
artifact_url: str
|
||||||
|
extension: str | None = None
|
||||||
|
modified: float | None = None
|
||||||
|
original_filename: str | None = None
|
||||||
|
markdown_file: str | None = None
|
||||||
|
markdown_path: str | None = None
|
||||||
|
markdown_virtual_path: str | None = None
|
||||||
|
markdown_artifact_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class UploadResponse(BaseModel):
|
class UploadResponse(BaseModel):
|
||||||
"""Response model for file upload."""
|
"""Response model for file upload."""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
files: list[dict[str, str]]
|
files: list[UploadedFileInfo]
|
||||||
message: str
|
message: str
|
||||||
skipped_files: list[str] = Field(default_factory=list)
|
skipped_files: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadListResponse(BaseModel):
|
||||||
|
"""Response model for uploaded file listing."""
|
||||||
|
|
||||||
|
files: list[UploadedFileInfo]
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
class UploadLimits(BaseModel):
|
class UploadLimits(BaseModel):
|
||||||
"""Application-level upload limits exposed to clients."""
|
"""Application-level upload limits exposed to clients."""
|
||||||
|
|
||||||
@@ -256,7 +280,7 @@ async def upload_files(
|
|||||||
|
|
||||||
file_info = {
|
file_info = {
|
||||||
"filename": safe_filename,
|
"filename": safe_filename,
|
||||||
"size": str(file_size),
|
"size": file_size,
|
||||||
"path": str(sandbox_uploads / safe_filename),
|
"path": str(sandbox_uploads / safe_filename),
|
||||||
"virtual_path": virtual_path,
|
"virtual_path": virtual_path,
|
||||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||||
@@ -333,9 +357,9 @@ async def get_upload_limits(
|
|||||||
return _get_upload_limits(config)
|
return _get_upload_limits(config)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=dict)
|
@router.get("/list", response_model=UploadListResponse)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
async def list_uploaded_files(thread_id: str, request: Request) -> UploadListResponse:
|
||||||
"""List all files in a thread's uploads directory."""
|
"""List all files in a thread's uploads directory."""
|
||||||
try:
|
try:
|
||||||
uploads_dir = get_uploads_dir(thread_id)
|
uploads_dir = get_uploads_dir(thread_id)
|
||||||
@@ -349,7 +373,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
|||||||
for f in result["files"]:
|
for f in result["files"]:
|
||||||
f["path"] = str(sandbox_uploads / f["filename"])
|
f["path"] = str(sandbox_uploads / f["filename"])
|
||||||
|
|
||||||
return result
|
return UploadListResponse(**result)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{filename}")
|
@router.delete("/{filename}")
|
||||||
|
|||||||
+128
-70
@@ -12,6 +12,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -19,6 +20,7 @@ from langchain_core.messages import BaseMessage
|
|||||||
from langchain_core.messages.utils import convert_to_messages
|
from langchain_core.messages.utils import convert_to_messages
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||||
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
@@ -34,6 +36,7 @@ from deerflow.runtime import (
|
|||||||
run_agent,
|
run_agent,
|
||||||
)
|
)
|
||||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||||
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -140,7 +143,14 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
|||||||
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
||||||
and ``config['context']`` so they are visible to legacy configurable readers and
|
and ``config['context']`` so they are visible to legacy configurable readers and
|
||||||
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
||||||
see issue #2677)."""
|
see issue #2677).
|
||||||
|
|
||||||
|
``user_id`` is intentionally propagated into ``config['context']`` in addition to
|
||||||
|
the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in
|
||||||
|
``body.context`` keep it on ``ToolRuntime.context``. It is merged with
|
||||||
|
``setdefault`` so a server-authenticated id stamped by
|
||||||
|
:func:`inject_authenticated_user_context` always wins over the client-supplied one.
|
||||||
|
"""
|
||||||
if not context:
|
if not context:
|
||||||
return
|
return
|
||||||
configurable = config.setdefault("configurable", {})
|
configurable = config.setdefault("configurable", {})
|
||||||
@@ -151,6 +161,8 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
|||||||
configurable.setdefault(key, context[key])
|
configurable.setdefault(key, context[key])
|
||||||
if isinstance(runtime_context, dict):
|
if isinstance(runtime_context, dict):
|
||||||
runtime_context.setdefault(key, context[key])
|
runtime_context.setdefault(key, context[key])
|
||||||
|
if "user_id" in context and isinstance(runtime_context, dict):
|
||||||
|
runtime_context.setdefault("user_id", context["user_id"])
|
||||||
|
|
||||||
|
|
||||||
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
||||||
@@ -166,6 +178,9 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request)
|
|||||||
if user_id is None:
|
if user_id is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||||
|
return
|
||||||
|
|
||||||
runtime_context = config.setdefault("context", {})
|
runtime_context = config.setdefault("context", {})
|
||||||
if isinstance(runtime_context, dict):
|
if isinstance(runtime_context, dict):
|
||||||
runtime_context["user_id"] = str(user_id)
|
runtime_context["user_id"] = str(user_id)
|
||||||
@@ -196,11 +211,14 @@ def build_run_config(
|
|||||||
|
|
||||||
When *assistant_id* refers to a custom agent (anything other than
|
When *assistant_id* refers to a custom agent (anything other than
|
||||||
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
|
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
|
||||||
whichever runtime options container is active: ``context`` for
|
both ``configurable`` and ``context`` so it is visible to legacy
|
||||||
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
|
configurable readers and to LangGraph ``ToolRuntime.context`` consumers
|
||||||
``make_lead_agent`` reads this key to load the matching
|
(e.g. the ``setup_agent`` tool, which since LangGraph >=1.1.9 no longer
|
||||||
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
|
falls back from ``context`` to ``configurable``). An explicit
|
||||||
silently runs as the default lead agent.
|
``agent_name`` in either container takes precedence over the value
|
||||||
|
derived from ``assistant_id``. ``make_lead_agent`` reads this key to
|
||||||
|
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
||||||
|
without it the agent silently runs as the default lead agent.
|
||||||
|
|
||||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||||
@@ -238,19 +256,23 @@ def build_run_config(
|
|||||||
config["configurable"] = {"thread_id": thread_id}
|
config["configurable"] = {"thread_id": thread_id}
|
||||||
|
|
||||||
# Inject custom agent name when the caller specified a non-default assistant.
|
# Inject custom agent name when the caller specified a non-default assistant.
|
||||||
# Honour an explicit agent_name in the active runtime options container.
|
# Honour an explicit agent_name in either runtime options container.
|
||||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
|
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
|
||||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||||
if "configurable" in config:
|
configurable = config.setdefault("configurable", {})
|
||||||
target = config["configurable"]
|
runtime_context = config.setdefault("context", {})
|
||||||
elif "context" in config:
|
explicit_agent_name: str | None = None
|
||||||
target = config["context"]
|
if isinstance(configurable, dict) and isinstance(configurable.get("agent_name"), str):
|
||||||
else:
|
explicit_agent_name = configurable["agent_name"]
|
||||||
target = config.setdefault("configurable", {})
|
elif isinstance(runtime_context, dict) and isinstance(runtime_context.get("agent_name"), str):
|
||||||
if target is not None and "agent_name" not in target:
|
explicit_agent_name = runtime_context["agent_name"]
|
||||||
target["agent_name"] = normalized
|
effective_agent_name = explicit_agent_name or normalized
|
||||||
|
if isinstance(configurable, dict):
|
||||||
|
configurable["agent_name"] = effective_agent_name
|
||||||
|
if isinstance(runtime_context, dict):
|
||||||
|
runtime_context["agent_name"] = effective_agent_name
|
||||||
config.setdefault("run_name", resolve_root_run_name(config, normalized))
|
config.setdefault("run_name", resolve_root_run_name(config, normalized))
|
||||||
if metadata:
|
if metadata:
|
||||||
config.setdefault("metadata", {}).update(metadata)
|
config.setdefault("metadata", {}).update(metadata)
|
||||||
@@ -302,72 +324,108 @@ async def start_run(
|
|||||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||||
record = await run_mgr.create_or_reject(
|
# Stateless run endpoints carry thread_id in the request *body*, so the
|
||||||
thread_id,
|
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
||||||
body.assistant_id,
|
# from the path param -- cannot protect them. Enforce thread ownership here,
|
||||||
on_disconnect=disconnect,
|
# before any run is created, so one user cannot start runs on (or read /wait
|
||||||
metadata=body.metadata or {},
|
# checkpoint state from) another user's thread. Missing rows (auto-created
|
||||||
kwargs={"input": body.input, "config": body.config},
|
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
||||||
multitask_strategy=body.multitask_strategy,
|
# via check_access; only a thread already owned by another user is rejected
|
||||||
model_name=model_name,
|
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
||||||
)
|
# channel runs act on behalf of the connection owner carried in
|
||||||
except ConflictError as exc:
|
# X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
# bypassing the check -- a leaked internal token must not grant cross-user
|
||||||
except UnsupportedStrategyError as exc:
|
# thread access.
|
||||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
user = getattr(request.state, "user", None)
|
||||||
|
if user is not None:
|
||||||
|
allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
|
||||||
|
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||||
|
# Channel workers may also act for the connection owner named in
|
||||||
|
# the trusted header (e.g. claiming a legacy default-owned channel
|
||||||
|
# thread for its real owner).
|
||||||
|
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
# Upsert thread metadata so the thread appears in /threads/search,
|
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
||||||
# even for threads that were never explicitly created via POST /threads
|
|
||||||
# (e.g. stateless runs).
|
|
||||||
try:
|
try:
|
||||||
existing = await run_ctx.thread_store.get(thread_id)
|
try:
|
||||||
if existing is None:
|
record = await run_mgr.create_or_reject(
|
||||||
await run_ctx.thread_store.create(
|
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
body.assistant_id,
|
||||||
metadata=body.metadata,
|
on_disconnect=disconnect,
|
||||||
|
metadata=body.metadata or {},
|
||||||
|
kwargs={"input": body.input, "config": body.config},
|
||||||
|
multitask_strategy=body.multitask_strategy,
|
||||||
|
model_name=model_name,
|
||||||
|
user_id=owner_user_id,
|
||||||
)
|
)
|
||||||
else:
|
except ConflictError as exc:
|
||||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
except Exception:
|
except UnsupportedStrategyError as exc:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||||
|
|
||||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
# Upsert thread metadata so the thread appears in /threads/search,
|
||||||
graph_input = normalize_input(body.input)
|
# even for threads that were never explicitly created via POST /threads
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
# (e.g. stateless runs).
|
||||||
|
try:
|
||||||
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
|
if existing is None and owner_user_id:
|
||||||
|
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
||||||
|
if unscoped_existing is not None:
|
||||||
|
if unscoped_existing.get("user_id") != owner_user_id:
|
||||||
|
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
||||||
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
|
if existing is None:
|
||||||
|
await run_ctx.thread_store.create(
|
||||||
|
thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
graph_input = normalize_input(body.input)
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
|
||||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
|
||||||
inject_authenticated_user_context(config, request)
|
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||||
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
|
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||||
|
inject_authenticated_user_context(config, request)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
run_agent(
|
|
||||||
bridge,
|
task = asyncio.create_task(
|
||||||
run_mgr,
|
run_agent(
|
||||||
record,
|
bridge,
|
||||||
ctx=run_ctx,
|
run_mgr,
|
||||||
agent_factory=agent_factory,
|
record,
|
||||||
graph_input=graph_input,
|
ctx=run_ctx,
|
||||||
config=config,
|
agent_factory=agent_factory,
|
||||||
stream_modes=stream_modes,
|
graph_input=graph_input,
|
||||||
stream_subgraphs=body.stream_subgraphs,
|
config=config,
|
||||||
interrupt_before=body.interrupt_before,
|
stream_modes=stream_modes,
|
||||||
interrupt_after=body.interrupt_after,
|
stream_subgraphs=body.stream_subgraphs,
|
||||||
|
interrupt_before=body.interrupt_before,
|
||||||
|
interrupt_after=body.interrupt_after,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
record.task = task
|
||||||
record.task = task
|
|
||||||
|
|
||||||
# Title sync is handled by worker.py's finally block which reads the
|
# Title sync is handled by worker.py's finally block which reads the
|
||||||
# title from the checkpoint and calls thread_store.update_display_name
|
# title from the checkpoint and calls thread_store.update_display_name
|
||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
finally:
|
||||||
|
if owner_context_token is not None:
|
||||||
|
reset_current_user(owner_context_token)
|
||||||
|
|
||||||
|
|
||||||
async def sse_consumer(
|
async def sse_consumer(
|
||||||
|
|||||||
+22
-4
@@ -228,10 +228,13 @@ Get current MCP server configurations.
|
|||||||
GET /api/mcp/config
|
GET /api/mcp/config
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Requires an authenticated admin session. Sensitive env/header/OAuth secret
|
||||||
|
values are masked in the response.
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcpServers": {
|
"mcp_servers": {
|
||||||
"github": {
|
"github": {
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"type": "stdio",
|
"type": "stdio",
|
||||||
@@ -255,10 +258,15 @@ PUT /api/mcp/config
|
|||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Requires an authenticated admin session. API-managed `stdio` MCP servers may
|
||||||
|
only use allowed executable names for `command` (default: `npx`, `uvx`). Set
|
||||||
|
`DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST` to a comma-separated list when a
|
||||||
|
deployment needs additional trusted launchers.
|
||||||
|
|
||||||
**Request Body:**
|
**Request Body:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcpServers": {
|
"mcp_servers": {
|
||||||
"github": {
|
"github": {
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"type": "stdio",
|
"type": "stdio",
|
||||||
@@ -276,8 +284,18 @@ Content-Type: application/json
|
|||||||
**Response:**
|
**Response:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"success": true,
|
"mcp_servers": {
|
||||||
"message": "MCP configuration updated"
|
"github": {
|
||||||
|
"enabled": true,
|
||||||
|
"type": "stdio",
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||||
|
"env": {
|
||||||
|
"GITHUB_TOKEN": "***"
|
||||||
|
},
|
||||||
|
"description": "GitHub operations"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ All other test plan sections were executed against either:
|
|||||||
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
|
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
|
||||||
| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` |
|
| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` |
|
||||||
| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
|
| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
|
||||||
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
|
| TC-DOCKER-06 | Docker deploy uses Gateway embedded runtime | `./scripts/deploy.sh` produces a Gateway + frontend + nginx topology (no `langgraph` container); same auth flow as local `make dev` | needs `docker compose up` |
|
||||||
|
|
||||||
## Coverage already provided by non-Docker tests
|
## Coverage already provided by non-Docker tests
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ the test cases that ran on sg_dev or local:
|
|||||||
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
|
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
|
||||||
| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies |
|
| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies |
|
||||||
| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
|
| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
|
||||||
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
|
| TC-DOCKER-06 (Gateway embedded runtime container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (Gateway auth flow on sg_dev) — same Gateway code, container is just a packaging change |
|
||||||
|
|
||||||
## Reproduction steps when Docker becomes available
|
## Reproduction steps when Docker becomes available
|
||||||
|
|
||||||
|
|||||||
@@ -124,8 +124,8 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
|
|||||||
|
|
||||||
## 兼容性
|
## 兼容性
|
||||||
|
|
||||||
- **标准模式**(`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化
|
- **本地开发**(`make dev`):Gateway embedded runtime 完全兼容;无 admin 时访问 `/setup` 初始化
|
||||||
- **Gateway 模式**(`make dev-pro`):完全兼容
|
- **Gateway embedded runtime**:标准脚本、Docker dev 和生产部署均通过 Gateway 提供认证与 LangGraph-compatible API
|
||||||
- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载
|
- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载
|
||||||
- **IM 渠道**(Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶
|
- **IM 渠道**(Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶
|
||||||
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
|
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
|
||||||
|
|||||||
@@ -67,6 +67,11 @@ The normal workflow is:
|
|||||||
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
|
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
|
||||||
4. Let CI prevent that path from regressing.
|
4. Let CI prevent that path from regressing.
|
||||||
|
|
||||||
|
Contributors changing backend async code can run the `blocking-io-guard` skill
|
||||||
|
(`.agent/skills/blocking-io-guard/`) to execute steps 1–3 for their own diff: it
|
||||||
|
scans the change for blocking-IO candidates, drafts or extends a runtime anchor,
|
||||||
|
and verifies the anchor fails when the blocking IO regresses.
|
||||||
|
|
||||||
Runtime detection has two maintenance paths.
|
Runtime detection has two maintenance paths.
|
||||||
|
|
||||||
### Add a runtime rule
|
### Add a runtime rule
|
||||||
|
|||||||
@@ -95,25 +95,35 @@ models:
|
|||||||
thinking:
|
thinking:
|
||||||
type: enabled
|
type: enabled
|
||||||
|
|
||||||
- name: minimax-m2.5
|
- name: minimax-m3
|
||||||
display_name: MiniMax M2.5
|
display_name: MiniMax M3
|
||||||
use: langchain_openai:ChatOpenAI
|
use: langchain_openai:ChatOpenAI
|
||||||
model: MiniMax-M2.5
|
model: MiniMax-M3
|
||||||
api_key: $MINIMAX_API_KEY
|
api_key: $MINIMAX_API_KEY
|
||||||
base_url: https://api.minimax.io/v1
|
base_url: https://api.minimax.io/v1
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
||||||
supports_vision: true
|
supports_vision: true
|
||||||
|
|
||||||
- name: minimax-m2.5-highspeed
|
- name: minimax-m2.7
|
||||||
display_name: MiniMax M2.5 Highspeed
|
display_name: MiniMax M2.7
|
||||||
use: langchain_openai:ChatOpenAI
|
use: langchain_openai:ChatOpenAI
|
||||||
model: MiniMax-M2.5-highspeed
|
model: MiniMax-M2.7
|
||||||
api_key: $MINIMAX_API_KEY
|
api_key: $MINIMAX_API_KEY
|
||||||
base_url: https://api.minimax.io/v1
|
base_url: https://api.minimax.io/v1
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
||||||
supports_vision: true
|
supports_vision: false # M2.7 is text-only; M3 supports vision
|
||||||
|
|
||||||
|
- name: minimax-m2.7-highspeed
|
||||||
|
display_name: MiniMax M2.7 Highspeed
|
||||||
|
use: langchain_openai:ChatOpenAI
|
||||||
|
model: MiniMax-M2.7-highspeed
|
||||||
|
api_key: $MINIMAX_API_KEY
|
||||||
|
base_url: https://api.minimax.io/v1
|
||||||
|
max_tokens: 4096
|
||||||
|
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
||||||
|
supports_vision: false # M2.7 is text-only; M3 supports vision
|
||||||
- name: openrouter-gemini-2.5-flash
|
- name: openrouter-gemini-2.5-flash
|
||||||
display_name: Gemini 2.5 Flash (OpenRouter)
|
display_name: Gemini 2.5 Flash (OpenRouter)
|
||||||
use: langchain_openai:ChatOpenAI
|
use: langchain_openai:ChatOpenAI
|
||||||
@@ -224,7 +234,7 @@ tools:
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Built-in Tools**:
|
**Built-in Tools**:
|
||||||
- `web_search` - Search the web (DuckDuckGo, Tavily, Exa, InfoQuest, Firecrawl)
|
- `web_search` - Search the web (DuckDuckGo, Tavily, Brave, Exa, InfoQuest, Firecrawl)
|
||||||
- `web_fetch` - Fetch web pages (Jina AI, Exa, InfoQuest, Firecrawl)
|
- `web_fetch` - Fetch web pages (Jina AI, Exa, InfoQuest, Firecrawl)
|
||||||
- `ls` - List directory contents
|
- `ls` - List directory contents
|
||||||
- `read_file` - Read file contents
|
- `read_file` - Read file contents
|
||||||
@@ -293,6 +303,55 @@ When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` val
|
|||||||
|
|
||||||
For bare-metal Docker sandbox runs that use localhost, DeerFlow binds the sandbox HTTP port to `127.0.0.1` by default so it is not exposed on every host interface. Docker-outside-of-Docker deployments that connect through `host.docker.internal` keep the broad legacy bind for compatibility. Set `DEER_FLOW_SANDBOX_BIND_HOST` explicitly if your deployment needs a different bind address.
|
For bare-metal Docker sandbox runs that use localhost, DeerFlow binds the sandbox HTTP port to `127.0.0.1` by default so it is not exposed on every host interface. Docker-outside-of-Docker deployments that connect through `host.docker.internal` keep the broad legacy bind for compatibility. Set `DEER_FLOW_SANDBOX_BIND_HOST` explicitly if your deployment needs a different bind address.
|
||||||
|
|
||||||
|
### Building a Custom AIO Sandbox Image
|
||||||
|
|
||||||
|
`AioSandboxProvider` talks to the sandbox container through the `agent-sandbox` SDK. The Dockerfile for the default `enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest` image is not part of this repository; DeerFlow treats that image as an upstream AIO sandbox runtime.
|
||||||
|
|
||||||
|
For persistent system or language dependencies, extend the published image and keep its startup command intact:
|
||||||
|
|
||||||
|
```dockerfile
|
||||||
|
FROM enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest
|
||||||
|
|
||||||
|
USER root
|
||||||
|
# Example user dependency; not required by DeerFlow itself.
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y --no-install-recommends graphviz \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Example Python dependency for work done inside the sandbox.
|
||||||
|
RUN python -m pip install --no-cache-dir pandas
|
||||||
|
|
||||||
|
# Do not override ENTRYPOINT or CMD; keep the upstream sandbox server startup.
|
||||||
|
```
|
||||||
|
|
||||||
|
Use the custom image in local Docker or Apple Container mode with `sandbox.image`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sandbox:
|
||||||
|
use: deerflow.community.aio_sandbox:AioSandboxProvider
|
||||||
|
image: your-registry/your-aio-sandbox:tag
|
||||||
|
```
|
||||||
|
|
||||||
|
In provisioner mode, sandbox Pods are created by the provisioner service, so configure the provisioner `SANDBOX_IMAGE` environment variable instead of `sandbox.image`. See the [Provisioner Setup Guide](../../docker/provisioner/README.md#custom-sandbox-image).
|
||||||
|
|
||||||
|
If you rebuild the runtime from scratch instead of extending the published image, it must expose the same HTTP API used by `agent-sandbox`. DeerFlow currently depends on:
|
||||||
|
|
||||||
|
- `sandbox.get_context()`, including `home_dir`
|
||||||
|
- `shell.exec_command(...)`
|
||||||
|
- `file.read_file(...)`
|
||||||
|
- `file.write_file(...)`, including base64 writes for binary content
|
||||||
|
- streamed `file.download_file(...)`
|
||||||
|
- `file.find_files(...)`
|
||||||
|
- `file.list_path(...)`
|
||||||
|
- `file.search_in_file(...)`
|
||||||
|
|
||||||
|
Custom images must also keep these compatibility constraints:
|
||||||
|
|
||||||
|
- The container should listen on the configured sandbox port, `8080` by default.
|
||||||
|
- `/mnt/user-data` must remain writable because DeerFlow mounts thread workspace, uploads, and outputs there.
|
||||||
|
- `home_dir` comes from the sandbox context endpoint; do not assume DeerFlow hardcodes it.
|
||||||
|
- Shell command handling must remain compatible with serialized `exec_command` calls. DeerFlow serializes shell access on the host side to avoid corrupting the sandbox's persistent shell session.
|
||||||
|
|
||||||
### Skills
|
### Skills
|
||||||
|
|
||||||
Configure the skills directory for specialized workflows:
|
Configure the skills directory for specialized workflows:
|
||||||
@@ -354,6 +413,7 @@ models:
|
|||||||
- `MIMO_API_KEY` - Xiaomi MiMo API key
|
- `MIMO_API_KEY` - Xiaomi MiMo API key
|
||||||
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
|
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
|
||||||
- `TAVILY_API_KEY` - Tavily search API key
|
- `TAVILY_API_KEY` - Tavily search API key
|
||||||
|
- `BRAVE_SEARCH_API_KEY` - Brave Search API key
|
||||||
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
|
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
|
||||||
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
|
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
|
||||||
- `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path
|
- `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
# IM Channel Connections
|
||||||
|
|
||||||
|
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
|
||||||
|
|
||||||
|
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the actual IM bots under the existing `channels` block:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
channels:
|
||||||
|
telegram:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $TELEGRAM_BOT_TOKEN
|
||||||
|
|
||||||
|
slack:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $SLACK_BOT_TOKEN
|
||||||
|
app_token: $SLACK_APP_TOKEN
|
||||||
|
|
||||||
|
discord:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $DISCORD_BOT_TOKEN
|
||||||
|
|
||||||
|
feishu:
|
||||||
|
enabled: true
|
||||||
|
app_id: $FEISHU_APP_ID
|
||||||
|
app_secret: $FEISHU_APP_SECRET
|
||||||
|
|
||||||
|
dingtalk:
|
||||||
|
enabled: true
|
||||||
|
client_id: $DINGTALK_CLIENT_ID
|
||||||
|
client_secret: $DINGTALK_CLIENT_SECRET
|
||||||
|
|
||||||
|
wechat:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $WECHAT_BOT_TOKEN
|
||||||
|
|
||||||
|
wecom:
|
||||||
|
enabled: true
|
||||||
|
bot_id: $WECOM_BOT_ID
|
||||||
|
bot_secret: $WECOM_BOT_SECRET
|
||||||
|
```
|
||||||
|
|
||||||
|
Then enable user bindings in `channel_connections`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
channel_connections:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
telegram:
|
||||||
|
enabled: true
|
||||||
|
bot_username: $TELEGRAM_BOT_USERNAME
|
||||||
|
|
||||||
|
slack:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
discord:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
feishu:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
dingtalk:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
wechat:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
wecom:
|
||||||
|
enabled: true
|
||||||
|
```
|
||||||
|
|
||||||
|
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
|
||||||
|
|
||||||
|
## Connect Flow
|
||||||
|
|
||||||
|
Telegram:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
|
||||||
|
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
|
||||||
|
|
||||||
|
Slack:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
|
||||||
|
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
|
||||||
|
|
||||||
|
Discord:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
|
||||||
|
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
|
||||||
|
|
||||||
|
Feishu/Lark, DingTalk, WeChat, and WeCom:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The UI shows `Send /connect <code> to the DeerFlow <Provider> bot.`
|
||||||
|
- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user.
|
||||||
|
|
||||||
|
Codes use 128 bits of randomness, expire after 10 minutes, and are single-use.
|
||||||
|
|
||||||
|
## Runtime Model
|
||||||
|
|
||||||
|
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
|
||||||
|
|
||||||
|
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
|
||||||
|
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
|
||||||
|
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
|
||||||
|
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
|
||||||
|
|
||||||
|
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- Browser APIs remain authenticated and CSRF-protected.
|
||||||
|
- Connect codes are 128-bit random, short-lived, and single-use.
|
||||||
|
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
|
||||||
|
- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets.
|
||||||
|
- This implementation does not add public provider callback or webhook routes.
|
||||||
@@ -31,7 +31,8 @@ Current injection format:
|
|||||||
|
|
||||||
Token counting:
|
Token counting:
|
||||||
- Uses `tiktoken` (`cl100k_base`) when available
|
- Uses `tiktoken` (`cl100k_base`) when available
|
||||||
- Falls back to `len(text) // 4` if tokenizer import fails
|
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
|
||||||
|
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
|
||||||
|
|
||||||
## Known Gap
|
## Known Gap
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ This directory contains detailed documentation for the DeerFlow backend.
|
|||||||
| [STREAMING.md](STREAMING.md) | Token-level streaming design: Gateway vs DeerFlowClient paths, `stream_mode` semantics, per-id dedup |
|
| [STREAMING.md](STREAMING.md) | Token-level streaming design: Gateway vs DeerFlowClient paths, `stream_mode` semantics, per-id dedup |
|
||||||
| [FILE_UPLOAD.md](FILE_UPLOAD.md) | File upload functionality |
|
| [FILE_UPLOAD.md](FILE_UPLOAD.md) | File upload functionality |
|
||||||
| [PATH_EXAMPLES.md](PATH_EXAMPLES.md) | Path types and usage examples |
|
| [PATH_EXAMPLES.md](PATH_EXAMPLES.md) | Path types and usage examples |
|
||||||
|
| [SANDBOX_MEMORY_PROFILING.md](SANDBOX_MEMORY_PROFILING.md) | Sandbox memory baseline and runtime comparison guide |
|
||||||
| [summarization.md](summarization.md) | Context summarization feature |
|
| [summarization.md](summarization.md) | Context summarization feature |
|
||||||
| [plan_mode_usage.md](plan_mode_usage.md) | Plan mode with TodoList |
|
| [plan_mode_usage.md](plan_mode_usage.md) | Plan mode with TodoList |
|
||||||
| [AUTO_TITLE_GENERATION.md](AUTO_TITLE_GENERATION.md) | Automatic title generation |
|
| [AUTO_TITLE_GENERATION.md](AUTO_TITLE_GENERATION.md) | Automatic title generation |
|
||||||
|
|||||||
@@ -0,0 +1,120 @@
|
|||||||
|
# Record/Replay E2E — front-back contract verification
|
||||||
|
|
||||||
|
Deterministic, **key-free** end-to-end checks that a backend change can't
|
||||||
|
silently break the frontend (and vice-versa). Two complementary layers, fed by a
|
||||||
|
single recording.
|
||||||
|
|
||||||
|
## Why
|
||||||
|
|
||||||
|
The mock-based frontend e2e hand-writes the backend's JSON/SSE, so a backend
|
||||||
|
schema or SSE change passes green ("fake green"). These layers replay a recorded
|
||||||
|
**real** run against the **real** backend (and, for Layer 2, the real frontend),
|
||||||
|
so contract drift turns the build red instead.
|
||||||
|
|
||||||
|
## The two layers
|
||||||
|
|
||||||
|
- **Layer 1 — backend golden** (`tests/test_replay_golden.py`): replays a fixture
|
||||||
|
through the real FastAPI gateway with `ReplayChatModel` and asserts the streamed
|
||||||
|
SSE event sequence equals a committed golden. Fast, no browser. Guards protocol
|
||||||
|
*shape*.
|
||||||
|
- **Layer 2 — full-stack render** (`frontend/tests/e2e-real-backend/`): real
|
||||||
|
Next.js + real gateway (replay model) + Chromium; asserts the replayed
|
||||||
|
auto-title and a follow-up suggestion render in the browser. Guards semantic
|
||||||
|
*render*. (Complementary to Layer 1 — neither subsumes the other.)
|
||||||
|
|
||||||
|
Layer 2 also hosts **cross-stack contract scenarios** — the dangerous class
|
||||||
|
where a backend change silently breaks a frontend assumption and *both sides'
|
||||||
|
unit tests stay green*. See below.
|
||||||
|
|
||||||
|
## Cross-stack scenario: multi-run render order (`multi-run-order.spec.ts`)
|
||||||
|
|
||||||
|
Regression guard for issue **#3352** (after context compression, refreshing a
|
||||||
|
thread rendered history out of order). Root cause was a front-back desync:
|
||||||
|
backend `RunManager.list_by_thread` returns runs **newest-first** (PR #2932),
|
||||||
|
while the frontend (`core/threads/hooks.ts`) iterated runs and **prepended** each
|
||||||
|
loaded page — inverting chronological order once the checkpoint no longer held
|
||||||
|
the older messages. The backend ordering test was green throughout, and the
|
||||||
|
frontend regression unit test hardcodes "backend returns newest-first" in a mock,
|
||||||
|
so only a *real frontend against a real backend* catches the desync.
|
||||||
|
|
||||||
|
This scenario does **not** record a conversation. It uses a **test-only seeder**
|
||||||
|
(`tests/seed_runs_router.py`, mounted on the replay gateway only when
|
||||||
|
`DEERFLOW_ENABLE_TEST_SEED=1`) to stand up a thread with ≥2 runs and per-run
|
||||||
|
message events — and deliberately **no checkpoint**, which is the #3352
|
||||||
|
precondition: it forces the frontend's per-run reload path to be the sole source
|
||||||
|
of truth so the ordering bug becomes observable. The seeder writes through the
|
||||||
|
gateway's own run/event stores using the request's auth context, so the real
|
||||||
|
`list_by_thread` → `/runs/{id}/messages` → prepend path runs live. Reverting the
|
||||||
|
#3354 frontend fix turns this spec red.
|
||||||
|
|
||||||
|
## How replay works
|
||||||
|
|
||||||
|
`tests/replay_provider.py::ReplayChatModel` returns recorded assistant turns keyed
|
||||||
|
by a **normalized hash of the model caller + conversation**. The conversation is
|
||||||
|
human / ai / tool messages — role, text, tool-call name+args; with
|
||||||
|
`<system-reminder>`, dates, UUIDs, tmp paths stripped. The caller is the stable
|
||||||
|
source of the model call (`lead_agent`, `middleware:title`, `suggest_agent`,
|
||||||
|
`subagent:*`, etc.). A miss raises loudly rather than passing silently.
|
||||||
|
|
||||||
|
**The system prompt is excluded from the match key.** The lead-agent system
|
||||||
|
prompt is a living, frequently-edited implementation detail — its wording changes
|
||||||
|
across PRs (e.g. #3195 added a "File Editing Workflow" section). Hashing it would
|
||||||
|
make every fixture go stale and red-fail unrelated PRs the moment anyone edits the
|
||||||
|
prompt. The conversation flow (user input → tool calls → results → answer) is the
|
||||||
|
stable contract that identifies a recorded turn. The caller still stays in the
|
||||||
|
key so two different model users with identical conversation text do not compete
|
||||||
|
for the same replay bucket. (This mirrors how open-design's mock picker keys on
|
||||||
|
the user prompt, not the system internals.) Combined with pinning skills +
|
||||||
|
extensions empty and disabling memory/summarization
|
||||||
|
(`tests/_replay_fixture.py::build_config_yaml`), a fixture replays the same across
|
||||||
|
machines, days, prompt edits, and CI. Replaying needs **no API key**.
|
||||||
|
|
||||||
|
A swallowed hash-miss keeps the SSE *event shapes* identical (the gateway wraps it
|
||||||
|
into a normal assistant error message), so the Layer-1 golden can't catch a miss
|
||||||
|
by shape alone — it inspects `replay_provider.replay_misses()` and fails loud
|
||||||
|
instead. Layer-2 already fails on a miss (the recorded turns never render).
|
||||||
|
|
||||||
|
## Record a new scenario (needs a real key — dev machine only)
|
||||||
|
|
||||||
|
Recording drives the **real frontend** so captured inputs match exactly what the
|
||||||
|
browser sends; fixtures contain no API key.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. drive the real frontend against a real-model gateway, capturing model calls
|
||||||
|
OPENAI_API_KEY=... OPENAI_API_BASE=<openai-compatible-endpoint>/v1 \
|
||||||
|
DEERFLOW_RECORD_OUT=/tmp/rec/turns.jsonl RECORD_MODEL=<model> \
|
||||||
|
bash -c 'cd frontend && pnpm exec playwright test -c playwright.record.config.ts'
|
||||||
|
|
||||||
|
# 2. stitch the capture into a fixture
|
||||||
|
cd backend && uv run python scripts/build_fixture_from_jsonl.py \
|
||||||
|
--jsonl /tmp/rec/turns.jsonl --meta /tmp/rec/turns.jsonl.meta.json \
|
||||||
|
--out tests/fixtures/replay/<scenario>.<mode>.json --model <model>
|
||||||
|
|
||||||
|
# 3. regenerate the committed golden
|
||||||
|
DEERFLOW_WRITE_GOLDEN=1 PYTHONPATH=. uv run pytest tests/test_replay_golden.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run (no key)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend && PYTHONPATH=. uv run pytest tests/test_replay_golden.py # Layer 1
|
||||||
|
cd frontend && pnpm exec playwright test -c playwright.real-backend.config.ts # Layer 2
|
||||||
|
```
|
||||||
|
|
||||||
|
## CI
|
||||||
|
|
||||||
|
`.github/workflows/replay-e2e.yml` runs both layers on changes to **either** side
|
||||||
|
of the contract (`frontend/**`, `backend/app/gateway/**`,
|
||||||
|
`backend/packages/harness/**`, fixtures). DOM assertions are the gate; the rendered
|
||||||
|
screenshot + Playwright HTML report are uploaded as a CI artifact.
|
||||||
|
|
||||||
|
## Known limitations
|
||||||
|
|
||||||
|
- Visual regression baselines are OS-specific, so they are a **local dev gate
|
||||||
|
only** (gitignored); CI uploads the render as an artifact for human review
|
||||||
|
instead of hard-asserting a cross-OS baseline.
|
||||||
|
- Fixtures are coupled to the recording-time prompt; if new
|
||||||
|
environment-dependent content enters the system prompt, extend the
|
||||||
|
normalization in `replay_provider.py` (or pin it in `build_config_yaml`).
|
||||||
|
- Re-record a scenario if the agent graph changes how many model calls it makes
|
||||||
|
— the replay raises loudly on a hash miss pointing at the divergence.
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
# Sandbox Memory Profiling
|
||||||
|
|
||||||
|
This guide records a repeatable baseline before changing the sandbox runtime.
|
||||||
|
Issue #3213 reports per-sandbox memory near 1 GiB in Kubernetes. Before adding
|
||||||
|
or recommending a new provider, capture the current AIO sandbox baseline and
|
||||||
|
compare candidates with the same DeerFlow workload.
|
||||||
|
|
||||||
|
## What to Measure
|
||||||
|
|
||||||
|
Measure at least these samples:
|
||||||
|
|
||||||
|
1. Empty sandbox after it becomes ready.
|
||||||
|
2. After a simple bash command.
|
||||||
|
3. After a Python task that imports common packages.
|
||||||
|
4. After a Node task when Node-based workloads are expected.
|
||||||
|
5. After generating files under `/mnt/user-data/outputs`.
|
||||||
|
6. After release and warm reuse.
|
||||||
|
7. At the target concurrency level, for example 10, 50, or 100 sandboxes.
|
||||||
|
|
||||||
|
`kubectl top` reports Kubernetes/container working set memory. Treat it as a
|
||||||
|
capacity signal, not exclusive RSS/PSS. Pod-level memory includes every
|
||||||
|
container in the Pod and may include cache charged to the cgroup. If a result
|
||||||
|
looks surprising, inspect the sandbox processes and cgroup metrics on the node
|
||||||
|
before drawing conclusions.
|
||||||
|
|
||||||
|
## Capture a Snapshot
|
||||||
|
|
||||||
|
Run this from the repository root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/sandbox_memory_profile.py \
|
||||||
|
--namespace deer-flow \
|
||||||
|
--selector app=deer-flow-sandbox \
|
||||||
|
--sample empty \
|
||||||
|
--include-processes \
|
||||||
|
--format markdown
|
||||||
|
```
|
||||||
|
|
||||||
|
Use a descriptive `--sample` value for each phase:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/sandbox_memory_profile.py --sample after-bash --format json
|
||||||
|
python scripts/sandbox_memory_profile.py --sample after-python --format json
|
||||||
|
python scripts/sandbox_memory_profile.py --sample after-artifact --format json
|
||||||
|
```
|
||||||
|
|
||||||
|
`--include-processes` runs `kubectl exec ... ps` in each sandbox Pod and adds
|
||||||
|
the highest-RSS processes to the report. This helps distinguish Pod-level cgroup
|
||||||
|
memory from process RSS. The two numbers will not match exactly because cgroup
|
||||||
|
memory can include cache and other kernel-accounted memory.
|
||||||
|
|
||||||
|
Save the raw JSON when comparing backends so totals, pod names, images,
|
||||||
|
requests, limits, and timestamps can be audited later.
|
||||||
|
|
||||||
|
## Candidate Runtime Matrix
|
||||||
|
|
||||||
|
For AIO, CubeSandbox, OpenSandbox, gVisor, Kata, or another candidate, compare
|
||||||
|
the same workload and record:
|
||||||
|
|
||||||
|
| Area | Required Evidence |
|
||||||
|
| --- | --- |
|
||||||
|
| Capacity | Pod or instance count, total memory, average memory, max memory |
|
||||||
|
| Startup | Ready latency at 1, 10, 50, and 100 concurrent sandboxes |
|
||||||
|
| Commands | Bash output, timeout behavior, failure shape |
|
||||||
|
| Files | `read_file`, `write_file`, binary `update_file`, `list_dir`, `glob`, `grep` |
|
||||||
|
| Uploads | Files uploaded by the gateway are visible inside the sandbox |
|
||||||
|
| Artifacts | Files written to `/mnt/user-data/outputs` are readable by the backend artifact API |
|
||||||
|
| Paths | `/mnt/user-data/workspace`, `/mnt/user-data/uploads`, `/mnt/user-data/outputs`, `/mnt/acp-workspace`, and skills paths keep their expected semantics |
|
||||||
|
| Isolation | Different users and threads cannot read each other's data |
|
||||||
|
| Cleanup | Release, idle timeout, process restart, and orphan cleanup free resources |
|
||||||
|
| Operations | Deployment prerequisites, privileged components, networking, storage, and upgrade path |
|
||||||
|
|
||||||
|
## PR Guidance
|
||||||
|
|
||||||
|
Do not claim that a new provider fixes high-concurrency memory usage until the
|
||||||
|
same DeerFlow workload has been measured on both the current AIO sandbox and the
|
||||||
|
candidate backend.
|
||||||
|
|
||||||
|
For an experimental provider PR, prefer `Related to #3213` unless the PR also
|
||||||
|
includes reproducible DeerFlow workload data that demonstrates the target memory
|
||||||
|
reduction and preserves uploads, outputs, artifacts, and isolation behavior.
|
||||||
@@ -127,8 +127,8 @@ complex_agent = create_agent_for_task("high")
|
|||||||
## How It Works
|
## How It Works
|
||||||
|
|
||||||
1. When `make_lead_agent(config)` is called, it extracts `is_plan_mode` from `config.configurable`
|
1. When `make_lead_agent(config)` is called, it extracts `is_plan_mode` from `config.configurable`
|
||||||
2. The config is passed to `_build_middlewares(config)`
|
2. The config is passed to `build_middlewares(config)`
|
||||||
3. `_build_middlewares()` reads `is_plan_mode` and calls `_create_todo_list_middleware(is_plan_mode)`
|
3. `build_middlewares()` reads `is_plan_mode` and calls `_create_todo_list_middleware(is_plan_mode)`
|
||||||
4. If `is_plan_mode=True`, a `TodoListMiddleware` instance is created and added to the middleware chain
|
4. If `is_plan_mode=True`, a `TodoListMiddleware` instance is created and added to the middleware chain
|
||||||
5. The middleware automatically adds a `write_todos` tool to the agent's toolset
|
5. The middleware automatically adds a `write_todos` tool to the agent's toolset
|
||||||
6. The agent can use this tool to manage tasks during execution
|
6. The agent can use this tool to manage tasks during execution
|
||||||
@@ -141,7 +141,7 @@ make_lead_agent(config)
|
|||||||
│
|
│
|
||||||
├─> Extracts: is_plan_mode = config.configurable.get("is_plan_mode", False)
|
├─> Extracts: is_plan_mode = config.configurable.get("is_plan_mode", False)
|
||||||
│
|
│
|
||||||
└─> _build_middlewares(config)
|
└─> build_middlewares(config)
|
||||||
│
|
│
|
||||||
├─> ThreadDataMiddleware
|
├─> ThreadDataMiddleware
|
||||||
├─> SandboxMiddleware
|
├─> SandboxMiddleware
|
||||||
@@ -156,7 +156,7 @@ make_lead_agent(config)
|
|||||||
### Agent Module
|
### Agent Module
|
||||||
- **Location**: `packages/harness/deerflow/agents/lead_agent/agent.py`
|
- **Location**: `packages/harness/deerflow/agents/lead_agent/agent.py`
|
||||||
- **Function**: `_create_todo_list_middleware(is_plan_mode: bool)` - Creates TodoListMiddleware if plan mode is enabled
|
- **Function**: `_create_todo_list_middleware(is_plan_mode: bool)` - Creates TodoListMiddleware if plan mode is enabled
|
||||||
- **Function**: `_build_middlewares(config: RunnableConfig)` - Builds middleware chain based on runtime config
|
- **Function**: `build_middlewares(config: RunnableConfig)` - Builds middleware chain based on runtime config
|
||||||
- **Function**: `make_lead_agent(config: RunnableConfig)` - Creates agent with appropriate middlewares
|
- **Function**: `make_lead_agent(config: RunnableConfig)` - Creates agent with appropriate middlewares
|
||||||
|
|
||||||
### Runtime Configuration
|
### Runtime Configuration
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
|
|||||||
``create_chat_model`` call must add to this list and pass the flag.
|
``create_chat_model`` call must add to this list and pass the flag.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
@@ -47,6 +49,8 @@ from deerflow.tracing import build_tracing_callbacks
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_BOOTSTRAP_SKILL_NAMES = {"bootstrap"}
|
||||||
|
|
||||||
|
|
||||||
def _get_runtime_config(config: RunnableConfig) -> dict:
|
def _get_runtime_config(config: RunnableConfig) -> dict:
|
||||||
"""Merge legacy configurable options with LangGraph runtime context."""
|
"""Merge legacy configurable options with LangGraph runtime context."""
|
||||||
@@ -263,20 +267,31 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
|||||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||||
def _build_middlewares(
|
def build_middlewares(
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
model_name: str | None,
|
model_name: str | None,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
custom_middlewares: list[AgentMiddleware] | None = None,
|
custom_middlewares: list[AgentMiddleware] | None = None,
|
||||||
*,
|
*,
|
||||||
|
available_skills: set[str] | None = None,
|
||||||
app_config: AppConfig | None = None,
|
app_config: AppConfig | None = None,
|
||||||
|
deferred_setup=None,
|
||||||
):
|
):
|
||||||
"""Build middleware chain based on runtime configuration.
|
"""Build the lead-agent middleware chain based on runtime configuration.
|
||||||
|
|
||||||
|
Public entry point for the lead agent's full middleware composition. Used by
|
||||||
|
``make_lead_agent`` and by the embedded ``DeerFlowClient`` (a lead-agent variant
|
||||||
|
that needs the identical chain). Keep this name stable: it is imported across a
|
||||||
|
module boundary, so renames/signature changes ripple into ``client.py``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||||
|
model_name: Resolved runtime model name; gates vision-only middleware.
|
||||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||||
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
||||||
|
app_config: Explicit AppConfig; falls back to ``get_app_config()`` when omitted.
|
||||||
|
deferred_setup: Optional deferred-MCP-tool setup that attaches
|
||||||
|
``DeferredToolFilterMiddleware`` when ``tool_search`` is enabled.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of middleware instances.
|
List of middleware instances.
|
||||||
@@ -290,6 +305,13 @@ def _build_middlewares(
|
|||||||
|
|
||||||
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
||||||
|
|
||||||
|
# Deterministically load a full SKILL.md when the user starts the turn with
|
||||||
|
# /skill-name. This keeps the base system prompt metadata-only while giving
|
||||||
|
# explicit user activation priority over model-side relevance guessing.
|
||||||
|
from deerflow.agents.middlewares.skill_activation_middleware import SkillActivationMiddleware
|
||||||
|
|
||||||
|
middlewares.append(SkillActivationMiddleware(available_skills=available_skills, app_config=resolved_app_config))
|
||||||
|
|
||||||
# Add summarization middleware if enabled
|
# Add summarization middleware if enabled
|
||||||
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
||||||
if summarization_middleware is not None:
|
if summarization_middleware is not None:
|
||||||
@@ -318,11 +340,13 @@ def _build_middlewares(
|
|||||||
if model_config is not None and model_config.supports_vision:
|
if model_config is not None and model_config.supports_vision:
|
||||||
middlewares.append(ViewImageMiddleware())
|
middlewares.append(ViewImageMiddleware())
|
||||||
|
|
||||||
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
|
# Hide deferred tool schemas from model binding until tool_search promotes them.
|
||||||
if resolved_app_config.tool_search.enabled:
|
# The deferred set + catalog hash come from the build-time setup (assembled
|
||||||
|
# after tool-policy filtering); promotion is read from graph state.
|
||||||
|
if deferred_setup is not None and deferred_setup.deferred_names:
|
||||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||||
|
|
||||||
middlewares.append(DeferredToolFilterMiddleware())
|
middlewares.append(DeferredToolFilterMiddleware(deferred_setup.deferred_names, deferred_setup.catalog_hash))
|
||||||
|
|
||||||
# Add SubagentLimitMiddleware to truncate excess parallel task calls
|
# Add SubagentLimitMiddleware to truncate excess parallel task calls
|
||||||
subagent_enabled = cfg.get("subagent_enabled", False)
|
subagent_enabled = cfg.get("subagent_enabled", False)
|
||||||
@@ -355,7 +379,7 @@ def _build_middlewares(
|
|||||||
|
|
||||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
||||||
if is_bootstrap:
|
if is_bootstrap:
|
||||||
return {"bootstrap"}
|
return set(_BOOTSTRAP_SKILL_NAMES)
|
||||||
if agent_config and agent_config.skills is not None:
|
if agent_config and agent_config.skills is not None:
|
||||||
return set(agent_config.skills)
|
return set(agent_config.skills)
|
||||||
return None
|
return None
|
||||||
@@ -386,6 +410,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
# Lazy import to avoid circular dependency
|
# Lazy import to avoid circular dependency
|
||||||
from deerflow.tools import get_available_tools
|
from deerflow.tools import get_available_tools
|
||||||
from deerflow.tools.builtins import setup_agent, update_agent
|
from deerflow.tools.builtins import setup_agent, update_agent
|
||||||
|
from deerflow.tools.builtins.tool_search import assemble_deferred_tools
|
||||||
|
|
||||||
cfg = _get_runtime_config(config)
|
cfg = _get_runtime_config(config)
|
||||||
resolved_app_config = app_config
|
resolved_app_config = app_config
|
||||||
@@ -460,16 +485,27 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
|
|
||||||
if is_bootstrap:
|
if is_bootstrap:
|
||||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||||
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
# Keep the bootstrap skill set intentionally narrow so agent creation
|
||||||
|
# remains deterministic before the custom agent's own config exists.
|
||||||
|
raw_tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
||||||
|
filtered = filter_tools_by_skill_allowed_tools(raw_tools, skills_for_tool_policy)
|
||||||
|
final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled)
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False),
|
||||||
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
|
tools=final_tools,
|
||||||
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
|
middleware=build_middlewares(
|
||||||
|
config,
|
||||||
|
model_name=model_name,
|
||||||
|
available_skills=set(_BOOTSTRAP_SKILL_NAMES),
|
||||||
|
app_config=resolved_app_config,
|
||||||
|
deferred_setup=setup,
|
||||||
|
),
|
||||||
system_prompt=apply_prompt_template(
|
system_prompt=apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
available_skills=set(["bootstrap"]),
|
available_skills=set(_BOOTSTRAP_SKILL_NAMES),
|
||||||
app_config=resolved_app_config,
|
app_config=resolved_app_config,
|
||||||
|
deferred_names=setup.deferred_names,
|
||||||
),
|
),
|
||||||
state_schema=ThreadState,
|
state_schema=ThreadState,
|
||||||
)
|
)
|
||||||
@@ -478,17 +514,27 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
# The default agent (no agent_name) does not see this tool.
|
# The default agent (no agent_name) does not see this tool.
|
||||||
extra_tools = [update_agent] if agent_name else []
|
extra_tools = [update_agent] if agent_name else []
|
||||||
# Default lead agent (unchanged behavior)
|
# Default lead agent (unchanged behavior)
|
||||||
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
raw_tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
||||||
|
filtered = filter_tools_by_skill_allowed_tools(raw_tools + extra_tools, skills_for_tool_policy)
|
||||||
|
final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled)
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False),
|
||||||
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
|
tools=final_tools,
|
||||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
|
middleware=build_middlewares(
|
||||||
|
config,
|
||||||
|
model_name=model_name,
|
||||||
|
agent_name=agent_name,
|
||||||
|
available_skills=available_skills,
|
||||||
|
app_config=resolved_app_config,
|
||||||
|
deferred_setup=setup,
|
||||||
|
),
|
||||||
system_prompt=apply_prompt_template(
|
system_prompt=apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
|
available_skills=available_skills,
|
||||||
app_config=resolved_app_config,
|
app_config=resolved_app_config,
|
||||||
|
deferred_names=setup.deferred_names,
|
||||||
),
|
),
|
||||||
state_schema=ThreadState,
|
state_schema=ThreadState,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from deerflow.config.agents_config import load_agent_soul
|
|||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
from deerflow.skills.types import Skill, SkillCategory
|
from deerflow.skills.types import Skill, SkillCategory
|
||||||
from deerflow.subagents import get_available_subagent_names
|
from deerflow.subagents import get_available_subagent_names
|
||||||
|
from deerflow.tools.builtins.tool_search import get_deferred_tools_prompt_section
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
@@ -542,6 +543,14 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
|
|||||||
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
|
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
|
||||||
- Progressive Loading: Load resources incrementally as referenced in skills
|
- Progressive Loading: Load resources incrementally as referenced in skills
|
||||||
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
|
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
|
||||||
|
- File Editing Workflow: When revising an existing file, prefer
|
||||||
|
`str_replace` over `write_file` — it sends only the diff and avoids
|
||||||
|
re-emitting the whole file (mirrors Claude Code's Edit and Codex's
|
||||||
|
apply_patch). When writing long new content from scratch, split it
|
||||||
|
into sections: the first `write_file` call creates the file, then use
|
||||||
|
`write_file` with append=True to extend it section by section. This
|
||||||
|
keeps each tool call small and avoids mid-stream chunk-gap timeouts
|
||||||
|
on oversized single-shot writes. (See issue #3189.)
|
||||||
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
|
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
|
||||||
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `\n\n` or "```mermaid" to display images in response or Markdown files
|
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `\n\n` or "```mermaid" to display images in response or Markdown files
|
||||||
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
|
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
|
||||||
@@ -577,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
memory_content = format_memory_for_injection(
|
||||||
|
memory_data,
|
||||||
|
max_tokens=config.max_injection_tokens,
|
||||||
|
use_tiktoken=(config.token_counting == "tiktoken"),
|
||||||
|
)
|
||||||
|
|
||||||
if not memory_content.strip():
|
if not memory_content.strip():
|
||||||
return ""
|
return ""
|
||||||
@@ -616,6 +629,11 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
|||||||
4. Load referenced resources only when needed during execution
|
4. Load referenced resources only when needed during execution
|
||||||
5. Follow the skill's instructions precisely
|
5. Follow the skill's instructions precisely
|
||||||
|
|
||||||
|
**Explicit Slash Skill Activation:**
|
||||||
|
- If the user starts a request with `/<skill-name>`, that skill was explicitly requested for the current turn.
|
||||||
|
- Follow the activated skill before choosing a general workflow.
|
||||||
|
- The runtime injects the activated skill content for explicit slash activations; do not call `read_file` for that SKILL.md again unless the injected skill references supporting resources you need.
|
||||||
|
|
||||||
**Skills are located at:** {container_base_path}
|
**Skills are located at:** {container_base_path}
|
||||||
{skill_evolution_section}
|
{skill_evolution_section}
|
||||||
{skills_list}
|
{skills_list}
|
||||||
@@ -678,42 +696,13 @@ SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace a
|
|||||||
Rules:
|
Rules:
|
||||||
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
|
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
|
||||||
- Only pass the fields that should change. Omit the others to preserve them.
|
- Only pass the fields that should change. Omit the others to preserve them.
|
||||||
|
- Never pass literal strings like `"null"`, `"none"`, or `"undefined"` for unchanged fields.
|
||||||
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
|
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
|
||||||
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
|
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
|
||||||
</self_update>
|
</self_update>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
|
|
||||||
"""Generate <available-deferred-tools> block for the system prompt.
|
|
||||||
|
|
||||||
Lists only deferred tool names so the agent knows what exists
|
|
||||||
and can use tool_search to load them.
|
|
||||||
Returns empty string when tool_search is disabled or no tools are deferred.
|
|
||||||
"""
|
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
|
||||||
|
|
||||||
if app_config is None:
|
|
||||||
try:
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
config = get_app_config()
|
|
||||||
except Exception:
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
config = app_config
|
|
||||||
|
|
||||||
if not config.tool_search.enabled:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
registry = get_deferred_registry()
|
|
||||||
if not registry:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
names = "\n".join(e.name for e in registry.entries)
|
|
||||||
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
|
||||||
|
|
||||||
|
|
||||||
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
||||||
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
||||||
if app_config is None:
|
if app_config is None:
|
||||||
@@ -772,6 +761,7 @@ def apply_prompt_template(
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
available_skills: set[str] | None = None,
|
available_skills: set[str] | None = None,
|
||||||
app_config: AppConfig | None = None,
|
app_config: AppConfig | None = None,
|
||||||
|
deferred_names: frozenset[str] = frozenset(),
|
||||||
) -> str:
|
) -> str:
|
||||||
# Include subagent section only if enabled (from runtime parameter)
|
# Include subagent section only if enabled (from runtime parameter)
|
||||||
n = max_concurrent_subagents
|
n = max_concurrent_subagents
|
||||||
@@ -799,7 +789,7 @@ def apply_prompt_template(
|
|||||||
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
|
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
|
||||||
|
|
||||||
# Get deferred tools section (tool_search)
|
# Get deferred tools section (tool_search)
|
||||||
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
|
deferred_tools_section = get_deferred_tools_prompt_section(deferred_names=deferred_names)
|
||||||
|
|
||||||
# Build ACP agent section only if ACP agents are configured
|
# Build ACP agent section only if ACP agents are configured
|
||||||
acp_section = _build_acp_section(app_config=app_config)
|
acp_section = _build_acp_section(app_config=app_config)
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
"""Prompt templates for memory update and injection."""
|
"""Prompt templates for memory update and injection."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@@ -160,26 +167,137 @@ Rules:
|
|||||||
Return ONLY valid JSON."""
|
Return ONLY valid JSON."""
|
||||||
|
|
||||||
|
|
||||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
# Module-level tiktoken encoding cache. Populated lazily on first use;
|
||||||
|
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||||
|
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||||
|
# (potentially slow) first ``get_encoding`` call.
|
||||||
|
#
|
||||||
|
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
||||||
|
# a network-restricted environment does not re-attempt the blocking BPE
|
||||||
|
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
||||||
|
# failure is allowed to expire so a transient network outage can self-heal back
|
||||||
|
# to accurate tiktoken counting without a process restart. A load already in
|
||||||
|
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
||||||
|
# fall back immediately instead of spawning more blocking
|
||||||
|
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
||||||
|
# config to skip tiktoken entirely.
|
||||||
|
_TIKTOKEN_ENCODING_MISSING = object()
|
||||||
|
_TIKTOKEN_ENCODING_LOADING = object()
|
||||||
|
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
||||||
|
# tuning constant rather than a user-facing config: it only affects how quickly
|
||||||
|
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
||||||
|
# Deployments that want to avoid tiktoken's network dependency entirely should
|
||||||
|
# set ``memory.token_counting: char`` instead of tuning this value.
|
||||||
|
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
||||||
|
_tiktoken_encoding_cache: dict[str, Any] = {}
|
||||||
|
_tiktoken_encoding_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||||
|
"""Return a cached tiktoken encoding, or ``None`` on failure / unavailability.
|
||||||
|
|
||||||
|
On the very first call for a given *encoding_name*, tiktoken may need to
|
||||||
|
download the BPE data from ``openaipublic.blob.core.windows.net``. In
|
||||||
|
network-restricted environments (e.g. deployments behind the GFW) this
|
||||||
|
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||||
|
The caller must therefore be prepared for this to block and should run it
|
||||||
|
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||||
|
|
||||||
|
A failed load is remembered (with a timestamp) so subsequent calls fall
|
||||||
|
back immediately to character-based estimation instead of re-triggering the
|
||||||
|
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
||||||
|
so a transient outage can self-heal without a restart. A load already in
|
||||||
|
progress is also remembered so that a timed-out caller does not leave a
|
||||||
|
window where later requests start more blocking ``get_encoding`` calls.
|
||||||
|
"""
|
||||||
|
if not TIKTOKEN_AVAILABLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
||||||
|
if cached is _TIKTOKEN_ENCODING_LOADING:
|
||||||
|
return None
|
||||||
|
if isinstance(cached, tuple):
|
||||||
|
# Cached failure: (None, failed_at). Retry only after cooldown.
|
||||||
|
_, failed_at = cached
|
||||||
|
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
||||||
|
return None
|
||||||
|
cached = _TIKTOKEN_ENCODING_MISSING
|
||||||
|
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
||||||
|
return cast("tiktoken.Encoding", cached)
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
||||||
|
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.get_encoding(encoding_name)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
||||||
|
return None
|
||||||
|
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
|
def _char_based_token_estimate(text: str) -> int:
|
||||||
|
"""Network-free token estimate that accounts for CJK density.
|
||||||
|
|
||||||
|
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
||||||
|
(~4 chars per token) but significantly under-estimates token counts for
|
||||||
|
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
||||||
|
characters per token. Counting CJK characters separately (~2 chars per
|
||||||
|
token) avoids over-filling the injection budget for CJK-heavy memory
|
||||||
|
content.
|
||||||
|
"""
|
||||||
|
cjk = sum(
|
||||||
|
1
|
||||||
|
for ch in text
|
||||||
|
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
||||||
|
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
||||||
|
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
||||||
|
)
|
||||||
|
return (len(text) - cjk) // 4 + cjk // 2
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
||||||
"""Count tokens in text using tiktoken.
|
"""Count tokens in text using tiktoken.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The text to count tokens for.
|
text: The text to count tokens for.
|
||||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||||
|
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
||||||
|
network-free character-based estimate. This guarantees no BPE
|
||||||
|
download is attempted (see ``memory.token_counting`` config).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of tokens in the text.
|
The number of tokens in the text.
|
||||||
"""
|
"""
|
||||||
if not TIKTOKEN_AVAILABLE:
|
if not use_tiktoken:
|
||||||
# Fallback to character-based estimation if tiktoken is not available
|
return _char_based_token_estimate(text)
|
||||||
return len(text) // 4
|
|
||||||
|
encoding = _get_tiktoken_encoding(encoding_name)
|
||||||
|
if encoding is None:
|
||||||
|
# Fallback to CJK-aware character estimation if tiktoken is not
|
||||||
|
# available or the encoding failed to load.
|
||||||
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to character-based estimation on error
|
# Fallback to CJK-aware character estimation on error.
|
||||||
return len(text) // 4
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
|
|
||||||
|
def warm_tiktoken_cache() -> bool:
|
||||||
|
"""Pre-warm the tiktoken encoding cache.
|
||||||
|
|
||||||
|
Call at startup (off the event loop) so the first request never blocks
|
||||||
|
on the BPE download. Returns ``True`` if the encoding was loaded
|
||||||
|
successfully (or was already cached), ``False`` if tiktoken is
|
||||||
|
unavailable or the download failed.
|
||||||
|
"""
|
||||||
|
return _get_tiktoken_encoding("cl100k_base") is not None
|
||||||
|
|
||||||
|
|
||||||
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||||
@@ -198,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
|||||||
return max(0.0, min(1.0, confidence))
|
return max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
|
||||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
||||||
"""Format memory data for injection into system prompt.
|
"""Format memory data for injection into system prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: The memory data dictionary.
|
memory_data: The memory data dictionary.
|
||||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||||
|
use_tiktoken: When ``False``, all token counting uses the network-free
|
||||||
|
character-based estimate instead of tiktoken (see
|
||||||
|
``memory.token_counting`` config). Defaults to ``True``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted memory string for system prompt injection.
|
Formatted memory string for system prompt injection.
|
||||||
@@ -265,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
# Compute token count for existing sections once, then account
|
# Compute token count for existing sections once, then account
|
||||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||||
base_text = "\n\n".join(sections)
|
base_text = "\n\n".join(sections)
|
||||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
||||||
# Account for the separator between existing sections and the facts section.
|
# Account for the separator between existing sections and the facts section.
|
||||||
facts_header = "Facts:\n"
|
facts_header = "Facts:\n"
|
||||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
||||||
running_tokens = base_tokens + separator_tokens
|
running_tokens = base_tokens + separator_tokens
|
||||||
|
|
||||||
fact_lines: list[str] = []
|
fact_lines: list[str] = []
|
||||||
@@ -289,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
# Each additional line is preceded by a newline (except the first).
|
# Each additional line is preceded by a newline (except the first).
|
||||||
line_text = ("\n" + line) if fact_lines else line
|
line_text = ("\n" + line) if fact_lines else line
|
||||||
line_tokens = _count_tokens(line_text)
|
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
||||||
|
|
||||||
if running_tokens + line_tokens <= max_tokens:
|
if running_tokens + line_tokens <= max_tokens:
|
||||||
fact_lines.append(line)
|
fact_lines.append(line)
|
||||||
@@ -305,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
result = "\n\n".join(sections)
|
result = "\n\n".join(sections)
|
||||||
|
|
||||||
# Use accurate token counting with tiktoken
|
# Use accurate token counting with tiktoken (or the char-based estimate
|
||||||
token_count = _count_tokens(result)
|
# when use_tiktoken is False).
|
||||||
|
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
||||||
if token_count > max_tokens:
|
if token_count > max_tokens:
|
||||||
# Truncate to fit within token limit
|
# Truncate to fit within token limit
|
||||||
# Estimate characters to remove based on token ratio
|
# Estimate characters to remove based on token ratio
|
||||||
|
|||||||
+39
-34
@@ -1,12 +1,15 @@
|
|||||||
"""Middleware to filter deferred tool schemas from model binding.
|
"""Middleware to filter deferred tool schemas from model binding.
|
||||||
|
|
||||||
When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
|
When tool_search is enabled, MCP tools are still passed to ToolNode for
|
||||||
and passed to ToolNode for execution, but their schemas should NOT be sent to the
|
execution, but their schemas must NOT be sent to the LLM via bind_tools until
|
||||||
LLM via bind_tools (that's the whole point of deferral — saving context tokens).
|
the model has discovered them via tool_search. This middleware removes the
|
||||||
|
still-deferred tools from request.tools before model binding, and blocks tool
|
||||||
|
calls to tools that have not been promoted yet.
|
||||||
|
|
||||||
This middleware intercepts wrap_model_call and removes deferred tools from
|
The deferred name set and the catalog hash are injected at construction time
|
||||||
request.tools so that model.bind_tools only receives active tool schemas.
|
(no ContextVar). Promotion state is read from graph state (``state["promoted"]``),
|
||||||
The agent discovers deferred tools at runtime via the tool_search tool.
|
scoped by catalog hash so a stale persisted promotion cannot expose a renamed
|
||||||
|
or drifted tool.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -24,47 +27,49 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
||||||
"""Remove deferred tools from request.tools before model binding.
|
"""Hide deferred tool schemas from the bound model until promoted.
|
||||||
|
|
||||||
ToolNode still holds all tools (including deferred) for execution routing,
|
ToolNode still holds all tools (including deferred) for execution routing,
|
||||||
but the LLM only sees active tool schemas — deferred tools are discoverable
|
but the LLM only sees active tool schemas plus tools that have already been
|
||||||
via tool_search at runtime.
|
promoted (recorded in ``state["promoted"]`` under the current catalog hash).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, deferred_names: frozenset[str], catalog_hash: str | None):
|
||||||
|
super().__init__()
|
||||||
|
self._deferred = deferred_names
|
||||||
|
self._catalog_hash = catalog_hash
|
||||||
|
|
||||||
|
def _promoted(self, state) -> set[str]:
|
||||||
|
promoted = (state or {}).get("promoted")
|
||||||
|
if promoted and promoted.get("catalog_hash") == self._catalog_hash:
|
||||||
|
return set(promoted.get("names") or [])
|
||||||
|
return set()
|
||||||
|
|
||||||
|
def _hidden(self, state) -> set[str]:
|
||||||
|
return set(self._deferred) - self._promoted(state)
|
||||||
|
|
||||||
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
|
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
if not self._deferred:
|
||||||
|
|
||||||
registry = get_deferred_registry()
|
|
||||||
if not registry:
|
|
||||||
return request
|
return request
|
||||||
|
hide = self._hidden(request.state)
|
||||||
deferred_names = registry.deferred_names
|
if not hide:
|
||||||
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
return request
|
||||||
|
active = [t for t in request.tools if getattr(t, "name", None) not in hide]
|
||||||
if len(active_tools) < len(request.tools):
|
if len(active) < len(request.tools):
|
||||||
logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
|
logger.debug("Filtered %d deferred tool schema(s) from model binding", len(request.tools) - len(active))
|
||||||
|
return request.override(tools=active)
|
||||||
return request.override(tools=active_tools)
|
|
||||||
|
|
||||||
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
|
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
if not self._deferred:
|
||||||
|
|
||||||
registry = get_deferred_registry()
|
|
||||||
if not registry:
|
|
||||||
return None
|
return None
|
||||||
|
name = str(request.tool_call.get("name") or "")
|
||||||
tool_name = str(request.tool_call.get("name") or "")
|
if not name or name not in self._hidden(request.state):
|
||||||
if not tool_name:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not registry.contains(tool_name):
|
|
||||||
return None
|
|
||||||
|
|
||||||
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
|
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
|
||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
|
content=(f"Error: Tool '{name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
name=tool_name,
|
name=name,
|
||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ Date-update format:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@@ -43,6 +44,12 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Upper bound (seconds) for a single _inject() offload. If the warm-up at
|
||||||
|
# gateway startup failed silently, the first request may still hit a cold
|
||||||
|
# tiktoken BPE download that blocks until the OS TCP timeout (~26 min).
|
||||||
|
# This cap ensures the request degrades gracefully instead of hanging.
|
||||||
|
_INJECT_TIMEOUT_SECONDS = 5.0
|
||||||
|
|
||||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
||||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
||||||
_SUMMARY_MESSAGE_NAME = "summary"
|
_SUMMARY_MESSAGE_NAME = "summary"
|
||||||
@@ -201,4 +208,25 @@ class DynamicContextMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
||||||
return self._inject(state)
|
# _inject() performs synchronous file I/O (memory JSON loading) and
|
||||||
|
# potentially blocking network calls (tiktoken encoding download on
|
||||||
|
# first use). Offload to a thread so the event loop is never blocked
|
||||||
|
# — a blocking call here starves all concurrent HTTP handlers (auth,
|
||||||
|
# SSE heartbeats, etc.). See issue #3402.
|
||||||
|
#
|
||||||
|
# Bounded timeout: if startup warm-up failed silently (e.g. network
|
||||||
|
# blip during deploy), the first request's cold tiktoken download can
|
||||||
|
# block for tens of minutes (OS TCP timeout). Time-box injection so
|
||||||
|
# the request degrades gracefully (no memory context) rather than
|
||||||
|
# hanging.
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(self._inject, state),
|
||||||
|
timeout=_INJECT_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"DynamicContextMiddleware: injection timed out (%.1fs); skipping memory/date injection for this turn",
|
||||||
|
_INJECT_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|||||||
+106
-6
@@ -62,6 +62,41 @@ _AUTH_PATTERNS = (
|
|||||||
"未授权",
|
"未授权",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-exception retry budget overrides.
|
||||||
|
#
|
||||||
|
# Some transient errors are retriable in principle but expensive to retry at
|
||||||
|
# the default budget. StreamChunkTimeoutError in particular fires after the
|
||||||
|
# upstream provider has already stalled for `stream_chunk_timeout` seconds
|
||||||
|
# (typically 120-240s); a full 3-attempt loop can therefore stack 6-12 minutes
|
||||||
|
# of dead air before surfacing the failure to the user. We keep exactly one
|
||||||
|
# retry (cheap reconnect that catches genuine transient TCP blips) and then
|
||||||
|
# fail fast — the same buffered payload is overwhelmingly likely to fail
|
||||||
|
# again at the upstream provider for the same reason.
|
||||||
|
#
|
||||||
|
# Keys are exception class *names* (not classes) so we don't introduce
|
||||||
|
# import-time coupling on optional dependencies like langchain-openai. The
|
||||||
|
# value is the absolute max attempt count, NOT additional retries — so a
|
||||||
|
# value of 2 means "1 first attempt + 1 retry" (the CR-requested
|
||||||
|
# "keep one retry" behavior).
|
||||||
|
_RETRY_BUDGET_OVERRIDES: dict[str, int] = {
|
||||||
|
"StreamChunkTimeoutError": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Exception class names that indicate the upstream stream-chunk watchdog
|
||||||
|
# fired because the model stalled mid-flight. These deserve a more specific
|
||||||
|
# user-facing message than the generic "temporarily unavailable" copy,
|
||||||
|
# because the typical root cause is a long tool-call serialization stalling
|
||||||
|
# the upstream stream — and the most actionable advice we can give the user
|
||||||
|
# is "ask for a shorter / split output" rather than "wait and retry".
|
||||||
|
# Generic connection drops (httpx RemoteProtocolError / ReadError) are
|
||||||
|
# intentionally excluded: they routinely fire on transient network blips
|
||||||
|
# with normal payloads, where the "split the work" guidance is misleading.
|
||||||
|
_STREAM_DROP_EXCEPTIONS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"StreamChunkTimeoutError",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||||
@@ -83,6 +118,18 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
self._circuit_state = "closed"
|
self._circuit_state = "closed"
|
||||||
self._circuit_probe_in_flight = False
|
self._circuit_probe_in_flight = False
|
||||||
|
|
||||||
|
def _max_attempts_for(self, exc: BaseException) -> int:
|
||||||
|
"""Return the effective max attempt count for this exception.
|
||||||
|
|
||||||
|
Falls back to `self.retry_max_attempts` unless the exception class name
|
||||||
|
appears in the per-exception override table.
|
||||||
|
"""
|
||||||
|
override = _RETRY_BUDGET_OVERRIDES.get(type(exc).__name__)
|
||||||
|
if override is None:
|
||||||
|
return self.retry_max_attempts
|
||||||
|
|
||||||
|
return min(override, self.retry_max_attempts)
|
||||||
|
|
||||||
def _check_circuit(self) -> bool:
|
def _check_circuit(self) -> bool:
|
||||||
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
|
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
|
||||||
with self._circuit_lock:
|
with self._circuit_lock:
|
||||||
@@ -153,6 +200,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
"InternalServerError",
|
"InternalServerError",
|
||||||
"ReadError", # httpx.ReadError: connection dropped mid-stream
|
"ReadError", # httpx.ReadError: connection dropped mid-stream
|
||||||
"RemoteProtocolError", # httpx: server closed connection unexpectedly
|
"RemoteProtocolError", # httpx: server closed connection unexpectedly
|
||||||
|
"StreamChunkTimeoutError", # langchain-openai: chunk gap exceeded stream_chunk_timeout
|
||||||
}:
|
}:
|
||||||
return True, "transient"
|
return True, "transient"
|
||||||
if status_code in _RETRIABLE_STATUS_CODES:
|
if status_code in _RETRIABLE_STATUS_CODES:
|
||||||
@@ -177,6 +225,24 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
def _build_circuit_breaker_message(self) -> str:
|
def _build_circuit_breaker_message(self) -> str:
|
||||||
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
|
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
|
||||||
|
|
||||||
|
def _build_error_fallback_message(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
error_type: str,
|
||||||
|
reason: str,
|
||||||
|
detail: str,
|
||||||
|
) -> AIMessage:
|
||||||
|
return AIMessage(
|
||||||
|
content=content,
|
||||||
|
additional_kwargs={
|
||||||
|
"deerflow_error_fallback": True,
|
||||||
|
"error_type": error_type,
|
||||||
|
"error_reason": reason,
|
||||||
|
"error_detail": detail,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||||
detail = _extract_error_detail(exc)
|
detail = _extract_error_detail(exc)
|
||||||
if reason == "quota":
|
if reason == "quota":
|
||||||
@@ -184,9 +250,31 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if reason == "auth":
|
if reason == "auth":
|
||||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||||
if reason in {"busy", "transient"}:
|
if reason in {"busy", "transient"}:
|
||||||
|
# Stream-drop failures (chunk-gap timeout, peer-closed connection,
|
||||||
|
# raw read error) almost always point at a single oversized
|
||||||
|
# tool-call payload — the model spent so long serializing JSON
|
||||||
|
# arguments that the upstream provider buffered and the stream
|
||||||
|
# gap exceeded `stream_chunk_timeout`. Surfacing this distinct
|
||||||
|
# cause lets the user split or shorten their next request
|
||||||
|
# instead of helplessly retrying the same prompt.
|
||||||
|
if type(exc).__name__ in _STREAM_DROP_EXCEPTIONS:
|
||||||
|
return (
|
||||||
|
"The model's streaming response was interrupted before it could "
|
||||||
|
"finish. This usually happens when a single response or tool call "
|
||||||
|
"is very large — please ask the assistant to split the work into "
|
||||||
|
"smaller steps, or shorten the requested output, and try again."
|
||||||
|
)
|
||||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||||
return f"LLM request failed: {detail}"
|
return f"LLM request failed: {detail}"
|
||||||
|
|
||||||
|
def _build_user_fallback_message(self, exc: BaseException, reason: str) -> AIMessage:
|
||||||
|
return self._build_error_fallback_message(
|
||||||
|
self._build_user_message(exc, reason),
|
||||||
|
error_type=type(exc).__name__,
|
||||||
|
reason=reason,
|
||||||
|
detail=_extract_error_detail(exc),
|
||||||
|
)
|
||||||
|
|
||||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||||
try:
|
try:
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
@@ -212,7 +300,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
if self._check_circuit():
|
if self._check_circuit():
|
||||||
return AIMessage(content=self._build_circuit_breaker_message())
|
return self._build_error_fallback_message(
|
||||||
|
self._build_circuit_breaker_message(),
|
||||||
|
error_type="CircuitBreakerOpen",
|
||||||
|
reason="circuit_open",
|
||||||
|
detail="LLM circuit breaker is open",
|
||||||
|
)
|
||||||
|
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
@@ -228,7 +321,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
retriable, reason = self._classify_error(exc)
|
retriable, reason = self._classify_error(exc)
|
||||||
if retriable and attempt < self.retry_max_attempts:
|
max_attempts = self._max_attempts_for(exc)
|
||||||
|
if retriable and attempt < max_attempts:
|
||||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||||
@@ -249,7 +343,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
)
|
)
|
||||||
if retriable:
|
if retriable:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
return AIMessage(content=self._build_user_message(exc, reason))
|
return self._build_user_fallback_message(exc, reason)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
@@ -258,7 +352,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
if self._check_circuit():
|
if self._check_circuit():
|
||||||
return AIMessage(content=self._build_circuit_breaker_message())
|
return self._build_error_fallback_message(
|
||||||
|
self._build_circuit_breaker_message(),
|
||||||
|
error_type="CircuitBreakerOpen",
|
||||||
|
reason="circuit_open",
|
||||||
|
detail="LLM circuit breaker is open",
|
||||||
|
)
|
||||||
|
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
@@ -274,7 +373,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
retriable, reason = self._classify_error(exc)
|
retriable, reason = self._classify_error(exc)
|
||||||
if retriable and attempt < self.retry_max_attempts:
|
max_attempts = self._max_attempts_for(exc)
|
||||||
|
if retriable and attempt < max_attempts:
|
||||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||||
@@ -295,7 +395,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
)
|
)
|
||||||
if retriable:
|
if retriable:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
return AIMessage(content=self._build_user_message(exc, reason))
|
return self._build_user_fallback_message(exc, reason)
|
||||||
|
|
||||||
|
|
||||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||||
|
|||||||
@@ -0,0 +1,289 @@
|
|||||||
|
"""Middleware for explicit slash skill activation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, override
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from deerflow.skills.slash import parse_slash_skill_reference, resolve_slash_skill
|
||||||
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
|
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||||
|
from deerflow.skills.types import SKILL_MD_FILE
|
||||||
|
from deerflow.utils.messages import get_original_user_content_text
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deerflow.config.app_config import AppConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SLASH_SKILL_ACTIVATION_KEY = "slash_skill_activation"
|
||||||
|
_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY = "slash_skill_activation_target_id"
|
||||||
|
_SUMMARY_MESSAGE_NAME = "summary"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _Activation:
|
||||||
|
skill_name: str
|
||||||
|
category: str
|
||||||
|
container_file_path: str
|
||||||
|
skill_content: str
|
||||||
|
content_hash: str
|
||||||
|
remaining_text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _ActivationResolution:
|
||||||
|
activation: _Activation | None = None
|
||||||
|
failure_message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_slash_skill_activation_reminder(message: object) -> bool:
|
||||||
|
"""Return whether a message is hidden slash-skill activation context."""
|
||||||
|
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_KEY))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_activation_target(message: object) -> bool:
|
||||||
|
if not isinstance(message, HumanMessage):
|
||||||
|
return False
|
||||||
|
if message.name == _SUMMARY_MESSAGE_NAME:
|
||||||
|
return False
|
||||||
|
if message.additional_kwargs.get("hide_from_ui"):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class SkillActivationMiddleware(AgentMiddleware):
|
||||||
|
"""Inject full SKILL.md content when the user explicitly types /skill-name."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
available_skills: set[str] | None = None,
|
||||||
|
app_config: AppConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._available_skills = set(available_skills) if available_skills is not None else None
|
||||||
|
self._app_config = app_config
|
||||||
|
|
||||||
|
def _storage(self) -> SkillStorage:
|
||||||
|
if self._app_config is not None:
|
||||||
|
return get_or_new_skill_storage(app_config=self._app_config)
|
||||||
|
return get_or_new_skill_storage()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _read_skill_content(skill_file: Path, skills_root: Path) -> str:
|
||||||
|
if skill_file.name != SKILL_MD_FILE:
|
||||||
|
raise ValueError(f"Expected {SKILL_MD_FILE}, got {skill_file.name}")
|
||||||
|
resolved_root = skills_root.resolve()
|
||||||
|
resolved_file = skill_file.resolve()
|
||||||
|
try:
|
||||||
|
resolved_file.relative_to(resolved_root)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError("Resolved skill file must stay within the configured skills root.") from exc
|
||||||
|
if not resolved_file.is_file():
|
||||||
|
raise FileNotFoundError(resolved_file)
|
||||||
|
return resolved_file.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
def _resolve_activation(self, text: str) -> _ActivationResolution | None:
|
||||||
|
reference = parse_slash_skill_reference(text)
|
||||||
|
if reference is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
storage = self._storage()
|
||||||
|
skills = storage.load_skills(enabled_only=False)
|
||||||
|
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
|
||||||
|
if skill is None:
|
||||||
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not installed.")
|
||||||
|
if not skill.enabled:
|
||||||
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
|
||||||
|
if self._available_skills is not None and reference.name not in self._available_skills:
|
||||||
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
|
||||||
|
|
||||||
|
resolved = resolve_slash_skill(
|
||||||
|
text,
|
||||||
|
skills,
|
||||||
|
available_skills=self._available_skills,
|
||||||
|
container_base_path=storage.get_container_root(),
|
||||||
|
)
|
||||||
|
if resolved is None:
|
||||||
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be resolved.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
skill_content = self._read_skill_content(resolved.skill.skill_file, storage.get_skills_root_path())
|
||||||
|
except (OSError, ValueError):
|
||||||
|
logger.exception("Failed to read slash-activated skill %s", resolved.skill.name)
|
||||||
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be loaded safely. Please check the skill installation.")
|
||||||
|
|
||||||
|
content_hash = hashlib.sha256(skill_content.encode("utf-8")).hexdigest()
|
||||||
|
return _ActivationResolution(
|
||||||
|
activation=_Activation(
|
||||||
|
skill_name=resolved.skill.name,
|
||||||
|
category=str(resolved.skill.category),
|
||||||
|
container_file_path=resolved.container_file_path,
|
||||||
|
skill_content=skill_content,
|
||||||
|
content_hash=content_hash,
|
||||||
|
remaining_text=resolved.remaining_text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_activation_reminder(activation: _Activation) -> str:
|
||||||
|
user_request = activation.remaining_text or ("No additional task text was provided after the slash skill command. Ask the user what they want to do with this skill if the next step is unclear.")
|
||||||
|
escaped_user_request = html.escape(user_request, quote=False)
|
||||||
|
escaped_skill_content = html.escape(activation.skill_content, quote=False)
|
||||||
|
escaped_skill_name = html.escape(activation.skill_name, quote=True)
|
||||||
|
escaped_category = html.escape(activation.category, quote=True)
|
||||||
|
escaped_path = html.escape(activation.container_file_path, quote=True)
|
||||||
|
escaped_content_hash = html.escape(activation.content_hash, quote=True)
|
||||||
|
return f"""<slash_skill_activation>
|
||||||
|
The user explicitly activated the `{activation.skill_name}` skill for this turn.
|
||||||
|
Treat the task text as:
|
||||||
|
<user_request>
|
||||||
|
{escaped_user_request}
|
||||||
|
</user_request>
|
||||||
|
|
||||||
|
Follow this skill before choosing a general workflow. Load supporting resources from the same skill directory only when needed.
|
||||||
|
|
||||||
|
<skill name="{escaped_skill_name}" category="{escaped_category}" path="{escaped_path}" sha256="{escaped_content_hash}">
|
||||||
|
<skill_content encoding="xml-escaped">
|
||||||
|
{escaped_skill_content}
|
||||||
|
</skill_content>
|
||||||
|
</skill>
|
||||||
|
</slash_skill_activation>"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_existing_activation_for_target(messages: list, target_index: int, target: HumanMessage) -> bool:
|
||||||
|
if target_index <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if target.id:
|
||||||
|
for previous in messages[:target_index]:
|
||||||
|
if not is_slash_skill_activation_reminder(previous):
|
||||||
|
continue
|
||||||
|
target_id = previous.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY)
|
||||||
|
if target_id == target.id or previous.id == f"{target.id}__slash_activation":
|
||||||
|
return True
|
||||||
|
|
||||||
|
previous = messages[target_index - 1]
|
||||||
|
return is_slash_skill_activation_reminder(previous)
|
||||||
|
|
||||||
|
def _find_activation_target(self, messages: list) -> tuple[int, HumanMessage, _ActivationResolution] | None:
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_index = next((idx for idx in range(len(messages) - 1, -1, -1) if _is_user_activation_target(messages[idx])), None)
|
||||||
|
if target_index is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target = messages[target_index]
|
||||||
|
if target is None:
|
||||||
|
return None
|
||||||
|
if self._has_existing_activation_for_target(messages, target_index, target):
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = get_original_user_content_text(target.content, target.additional_kwargs)
|
||||||
|
resolution = self._resolve_activation(content)
|
||||||
|
if resolution is None:
|
||||||
|
return None
|
||||||
|
return target_index, target, resolution
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _record_activation(request: ModelRequest, activation: _Activation, *, hook: str) -> None:
|
||||||
|
runtime = getattr(request, "runtime", None)
|
||||||
|
context = getattr(runtime, "context", None)
|
||||||
|
journal = context.get("__run_journal") if isinstance(context, dict) else None
|
||||||
|
if journal is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
journal.record_middleware(
|
||||||
|
"skill_activation",
|
||||||
|
name="SkillActivationMiddleware",
|
||||||
|
hook=hook,
|
||||||
|
action="activate",
|
||||||
|
changes={
|
||||||
|
"skill_name": activation.skill_name,
|
||||||
|
"category": activation.category,
|
||||||
|
"path": activation.container_file_path,
|
||||||
|
"content_hash": activation.content_hash,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to record slash skill activation audit event", exc_info=True)
|
||||||
|
|
||||||
|
def _prepare_model_request(self, request: ModelRequest, *, hook: str) -> ModelRequest | AIMessage | None:
|
||||||
|
target_and_resolution = self._find_activation_target(list(request.messages))
|
||||||
|
if target_and_resolution is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_index, target, resolution = target_and_resolution
|
||||||
|
if resolution.failure_message:
|
||||||
|
return AIMessage(content=resolution.failure_message)
|
||||||
|
|
||||||
|
activation = resolution.activation
|
||||||
|
if activation is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"SkillActivationMiddleware: activating slash skill %s category=%s path=%s hash=%s",
|
||||||
|
activation.skill_name,
|
||||||
|
activation.category,
|
||||||
|
activation.container_file_path,
|
||||||
|
activation.content_hash,
|
||||||
|
)
|
||||||
|
self._record_activation(request, activation, hook=hook)
|
||||||
|
activation_msg = self._make_activation_message(target, self._build_activation_reminder(activation))
|
||||||
|
messages = list(request.messages)
|
||||||
|
messages.insert(target_index, activation_msg)
|
||||||
|
return request.override(messages=messages)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_activation_message(target: HumanMessage, activation_content: str) -> HumanMessage:
|
||||||
|
stable_id = target.id or str(uuid.uuid4())
|
||||||
|
additional_kwargs = {
|
||||||
|
"hide_from_ui": True,
|
||||||
|
_SLASH_SKILL_ACTIVATION_KEY: True,
|
||||||
|
}
|
||||||
|
if target.id:
|
||||||
|
additional_kwargs[_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY] = target.id
|
||||||
|
return HumanMessage(
|
||||||
|
content=activation_content,
|
||||||
|
id=f"{stable_id}__slash_activation",
|
||||||
|
additional_kwargs=additional_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def wrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
|
) -> ModelResponse | AIMessage:
|
||||||
|
prepared = self._prepare_model_request(request, hook="wrap_model_call")
|
||||||
|
if prepared is None:
|
||||||
|
return handler(request)
|
||||||
|
if isinstance(prepared, AIMessage):
|
||||||
|
return prepared
|
||||||
|
return handler(prepared)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def awrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
|
) -> ModelResponse | AIMessage:
|
||||||
|
prepared = await asyncio.to_thread(self._prepare_model_request, request, hook="awrap_model_call")
|
||||||
|
if prepared is None:
|
||||||
|
return await handler(request)
|
||||||
|
if isinstance(prepared, AIMessage):
|
||||||
|
return prepared
|
||||||
|
return await handler(prepared)
|
||||||
@@ -9,8 +9,9 @@ from typing import Any, Protocol, override, runtime_checkable
|
|||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import SummarizationMiddleware
|
from langchain.agents.middleware import SummarizationMiddleware
|
||||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage, get_buffer_string
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
|
from langgraph.constants import TAG_NOSTREAM
|
||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
@@ -116,6 +117,74 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
|
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
|
||||||
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
|
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
|
||||||
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
|
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
|
||||||
|
# The summary LLM call runs inside a LangGraph middleware hook, so its token
|
||||||
|
# stream would otherwise be captured by the messages-tuple stream callback and
|
||||||
|
# broadcast to the frontend as a phantom AI message. Tag a dedicated model copy
|
||||||
|
# with TAG_NOSTREAM so the streaming handler skips it.
|
||||||
|
# Keep self.model untagged so the parent's profile / ls_params inspection still works.
|
||||||
|
#
|
||||||
|
# Preserve any tags already bound on the model (e.g. "middleware:summarize" set in
|
||||||
|
# lead_agent/agent.py for RunJournal attribution): RunnableBinding.with_config does a
|
||||||
|
# shallow merge that would otherwise overwrite the existing tags list entirely.
|
||||||
|
existing_tags = list((getattr(self.model, "config", None) or {}).get("tags") or [])
|
||||||
|
merged_tags = [*existing_tags, TAG_NOSTREAM] if TAG_NOSTREAM not in existing_tags else existing_tags
|
||||||
|
self._summary_model = self.model.with_config(tags=merged_tags)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||||
|
return self._summarize_with(messages_to_summarize)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||||
|
return await self._asummarize_with(messages_to_summarize)
|
||||||
|
|
||||||
|
def _summarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||||
|
"""Mirror the parent ``_create_summary`` but invoke the nostream-tagged model.
|
||||||
|
|
||||||
|
We do not swap ``self.model`` at the instance level: the agent/middleware is
|
||||||
|
cached and reused across concurrent runs, so a temporary swap would leak the
|
||||||
|
``RunnableBinding`` to other coroutines during ``await`` and break parent logic
|
||||||
|
that inspects the raw model (``profile`` / ``_get_ls_params``).
|
||||||
|
"""
|
||||||
|
if not messages_to_summarize:
|
||||||
|
return "No previous conversation history."
|
||||||
|
prompt = self._build_summary_prompt(messages_to_summarize)
|
||||||
|
if prompt is None:
|
||||||
|
return "Previous conversation was too long to summarize."
|
||||||
|
try:
|
||||||
|
response = self._summary_model.invoke(
|
||||||
|
prompt,
|
||||||
|
config={"metadata": {"lc_source": "summarization"}},
|
||||||
|
)
|
||||||
|
return response.text.strip()
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error generating summary: {e!s}"
|
||||||
|
|
||||||
|
async def _asummarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||||
|
"""Async counterpart of :meth:`_summarize_with` using the nostream model."""
|
||||||
|
if not messages_to_summarize:
|
||||||
|
return "No previous conversation history."
|
||||||
|
prompt = self._build_summary_prompt(messages_to_summarize)
|
||||||
|
if prompt is None:
|
||||||
|
return "Previous conversation was too long to summarize."
|
||||||
|
try:
|
||||||
|
response = await self._summary_model.ainvoke(
|
||||||
|
prompt,
|
||||||
|
config={"metadata": {"lc_source": "summarization"}},
|
||||||
|
)
|
||||||
|
return response.text.strip()
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error generating summary: {e!s}"
|
||||||
|
|
||||||
|
def _build_summary_prompt(self, messages_to_summarize: list[AnyMessage]) -> str | None:
|
||||||
|
"""Build the summary prompt, returning ``None`` when trimming leaves nothing."""
|
||||||
|
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
|
||||||
|
if not trimmed_messages:
|
||||||
|
return None
|
||||||
|
# Format messages to avoid token inflation from metadata when str() is called on
|
||||||
|
# message objects.
|
||||||
|
formatted_messages = get_buffer_string(trimmed_messages)
|
||||||
|
return self.summary_prompt.format(messages=formatted_messages).rstrip()
|
||||||
|
|
||||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
return self._maybe_summarize(state, runtime)
|
return self._maybe_summarize(state, runtime)
|
||||||
|
|||||||
@@ -46,11 +46,6 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _completion_reminder_count(messages: list[Any]) -> int:
|
|
||||||
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
|
|
||||||
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
|
|
||||||
|
|
||||||
|
|
||||||
def _format_todos(todos: list[Todo]) -> str:
|
def _format_todos(todos: list[Todo]) -> str:
|
||||||
"""Format a list of Todo items into a human-readable string."""
|
"""Format a list of Todo items into a human-readable string."""
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
|
|||||||
+74
-4
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import override
|
from typing import TYPE_CHECKING, override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
@@ -12,10 +12,48 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
|
|||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
|
from deerflow.subagents.status_contract import (
|
||||||
|
extract_subagent_status,
|
||||||
|
make_subagent_additional_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deerflow.tools.builtins.tool_search import DeferredToolSetup
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||||
|
_TASK_TOOL_NAME = "task"
|
||||||
|
|
||||||
|
|
||||||
|
def _stamp_task_subagent_status(message: ToolMessage, *, tool_name: str, error: str | None = None) -> ToolMessage:
|
||||||
|
"""Centralised stamping of ``additional_kwargs.subagent_status``.
|
||||||
|
|
||||||
|
Bytedance/deer-flow issue #3146: the frontend now reads the subagent
|
||||||
|
status from a structured field instead of parsing the leading text of
|
||||||
|
the task tool's return string. That contract is enforced here, in the
|
||||||
|
one place every task tool result flows through, rather than at the 5
|
||||||
|
normal-return + 3 ``Error:`` pre-execution branches inside
|
||||||
|
``task_tool.py``. Centralisation prevents the "added a new return
|
||||||
|
path, forgot the stamp" drift mode.
|
||||||
|
|
||||||
|
For non-``task`` tools this is a no-op so other tools' additional_kwargs
|
||||||
|
conventions are untouched.
|
||||||
|
"""
|
||||||
|
if tool_name != _TASK_TOOL_NAME:
|
||||||
|
return message
|
||||||
|
content = message.content if isinstance(message.content, str) else ""
|
||||||
|
status = extract_subagent_status(content)
|
||||||
|
if status is None:
|
||||||
|
# Non-terminal streaming chunks or unrecognised shapes leave the
|
||||||
|
# field unset so the frontend can keep the card on its in-progress
|
||||||
|
# placeholder until a real terminal frame arrives.
|
||||||
|
return message
|
||||||
|
stamp = make_subagent_additional_kwargs(status, error=error)
|
||||||
|
existing = dict(message.additional_kwargs or {})
|
||||||
|
existing.update(stamp)
|
||||||
|
message.additional_kwargs = existing
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||||
@@ -29,12 +67,31 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
detail = detail[:497] + "..."
|
detail = detail[:497] + "..."
|
||||||
|
|
||||||
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
|
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
|
||||||
return ToolMessage(
|
message = ToolMessage(
|
||||||
content=content,
|
content=content,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
|
# Stamp the structured subagent status on the wrapper too: the
|
||||||
|
# frontend would otherwise have to fall back to prefix-matching
|
||||||
|
# ``Error: Tool 'task' failed ...`` on the wire. The ``subagent_error``
|
||||||
|
# carries the same ``ExcClass: detail`` shape the wrapper string
|
||||||
|
# uses so debugging artifacts stay aligned.
|
||||||
|
structured_error = f"{exc.__class__.__name__}: {detail}"
|
||||||
|
return _stamp_task_subagent_status(message, tool_name=tool_name, error=structured_error)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _maybe_stamp(result: ToolMessage | Command, request: ToolCallRequest) -> ToolMessage | Command:
|
||||||
|
"""Apply the subagent stamp to successful task tool returns.
|
||||||
|
|
||||||
|
``Command`` results bypass the stamp — they encode LangGraph
|
||||||
|
control flow rather than user-facing tool output.
|
||||||
|
"""
|
||||||
|
if not isinstance(result, ToolMessage):
|
||||||
|
return result
|
||||||
|
tool_name = str(request.tool_call.get("name") or "")
|
||||||
|
return _stamp_task_subagent_status(result, tool_name=tool_name)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def wrap_tool_call(
|
def wrap_tool_call(
|
||||||
@@ -43,13 +100,14 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||||
) -> ToolMessage | Command:
|
) -> ToolMessage | Command:
|
||||||
try:
|
try:
|
||||||
return handler(request)
|
result = handler(request)
|
||||||
except GraphBubbleUp:
|
except GraphBubbleUp:
|
||||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||||
return self._build_error_message(request, exc)
|
return self._build_error_message(request, exc)
|
||||||
|
return self._maybe_stamp(result, request)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
@@ -58,13 +116,14 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||||
) -> ToolMessage | Command:
|
) -> ToolMessage | Command:
|
||||||
try:
|
try:
|
||||||
return await handler(request)
|
result = await handler(request)
|
||||||
except GraphBubbleUp:
|
except GraphBubbleUp:
|
||||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||||
return self._build_error_message(request, exc)
|
return self._build_error_message(request, exc)
|
||||||
|
return self._maybe_stamp(result, request)
|
||||||
|
|
||||||
|
|
||||||
def _build_runtime_middlewares(
|
def _build_runtime_middlewares(
|
||||||
@@ -143,6 +202,7 @@ def build_subagent_runtime_middlewares(
|
|||||||
app_config: AppConfig | None = None,
|
app_config: AppConfig | None = None,
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
lazy_init: bool = True,
|
lazy_init: bool = True,
|
||||||
|
deferred_setup: "DeferredToolSetup | None" = None,
|
||||||
) -> list[AgentMiddleware]:
|
) -> list[AgentMiddleware]:
|
||||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||||
if app_config is None:
|
if app_config is None:
|
||||||
@@ -166,6 +226,16 @@ def build_subagent_runtime_middlewares(
|
|||||||
|
|
||||||
middlewares.append(ViewImageMiddleware())
|
middlewares.append(ViewImageMiddleware())
|
||||||
|
|
||||||
|
# Hide deferred (MCP) tool schemas from the subagent's model binding until
|
||||||
|
# tool_search promotes them. This is the same wiring the lead agent gets. The deferred
|
||||||
|
# set + catalog hash come from the build-time setup (assembled after
|
||||||
|
# tool-policy filtering); promotion is read from graph state. Empty/None
|
||||||
|
# setup (deferral disabled or no MCP tool survived) is a pure no-op.
|
||||||
|
if deferred_setup is not None and deferred_setup.deferred_names:
|
||||||
|
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||||
|
|
||||||
|
middlewares.append(DeferredToolFilterMiddleware(deferred_setup.deferred_names, deferred_setup.catalog_hash))
|
||||||
|
|
||||||
# Same provider safety-termination guard the lead agent uses — subagents
|
# Same provider safety-termination guard the lead agent uses — subagents
|
||||||
# are equally exposed to truncated tool_calls returned with
|
# are equally exposed to truncated tool_calls returned with
|
||||||
# finish_reason=content_filter (and friends), and the bad call would then
|
# finish_reason=content_filter (and friends), and the bad call would then
|
||||||
|
|||||||
+175
-21
@@ -11,10 +11,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import shlex
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import replace as dc_replace
|
from dataclasses import replace as dc_replace
|
||||||
from typing import Any, override
|
from typing import TYPE_CHECKING, Any, override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
@@ -24,9 +25,19 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
|
|||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from deerflow.config.tool_output_config import ToolOutputConfig
|
from deerflow.config.tool_output_config import ToolOutputConfig
|
||||||
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deerflow.sandbox.sandbox import Sandbox
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Virtual outputs root inside the sandbox. Host-mounted sandboxes map this to
|
||||||
|
# the thread outputs dir on the host; for non-mounted (remote) sandboxes the
|
||||||
|
# same path is written directly into the sandbox filesystem so the model's
|
||||||
|
# ``read_file`` tool can read it back (issue #3416).
|
||||||
|
_VIRTUAL_OUTPUTS_BASE = "/mnt/user-data/outputs"
|
||||||
|
|
||||||
|
|
||||||
def _default_config() -> ToolOutputConfig:
|
def _default_config() -> ToolOutputConfig:
|
||||||
return ToolOutputConfig()
|
return ToolOutputConfig()
|
||||||
@@ -94,6 +105,18 @@ def _sanitize_tool_name(name: str) -> str:
|
|||||||
return safe or "unknown"
|
return safe or "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_externalized_filename(*, tool_name: str, tool_call_id: str) -> str:
|
||||||
|
"""Build the on-disk filename for an externalized tool output.
|
||||||
|
|
||||||
|
Shared by the host-disk and sandbox externalization paths so both
|
||||||
|
produce the identical naming scheme.
|
||||||
|
"""
|
||||||
|
safe_name = _sanitize_tool_name(tool_name)
|
||||||
|
ext = _EXT_MAP.get(tool_name, "txt")
|
||||||
|
short_id = uuid.uuid4().hex[:12]
|
||||||
|
return f"{safe_name}-{short_id}.{ext}"
|
||||||
|
|
||||||
|
|
||||||
def _externalize(
|
def _externalize(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
@@ -111,10 +134,7 @@ def _externalize(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
safe_name = _sanitize_tool_name(tool_name)
|
filename = _build_externalized_filename(tool_name=tool_name, tool_call_id=tool_call_id)
|
||||||
ext = _EXT_MAP.get(tool_name, "txt")
|
|
||||||
short_id = uuid.uuid4().hex[:12]
|
|
||||||
filename = f"{safe_name}-{short_id}.{ext}"
|
|
||||||
filepath = os.path.join(storage_dir, filename)
|
filepath = os.path.join(storage_dir, filename)
|
||||||
|
|
||||||
if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)):
|
if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)):
|
||||||
@@ -126,8 +146,56 @@ def _externalize(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
virtual_base = "/mnt/user-data/outputs"
|
return f"{_VIRTUAL_OUTPUTS_BASE}/{storage_subdir}/{filename}"
|
||||||
return f"{virtual_base}/{storage_subdir}/{filename}"
|
|
||||||
|
|
||||||
|
def _externalize_to_sandbox(
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
tool_name: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
storage_subdir: str,
|
||||||
|
sandbox: Sandbox,
|
||||||
|
) -> str | None:
|
||||||
|
"""Write *content* into the sandbox filesystem and return the virtual path.
|
||||||
|
|
||||||
|
Used when the sandbox does not use thread-data mounts (e.g. a remote AIO
|
||||||
|
sandbox): the host-side :func:`_externalize` virtual path would not exist
|
||||||
|
inside the sandbox, so the model's ``read_file`` tool could not read it
|
||||||
|
back (issue #3416). Returns the same virtual-path contract on success, or
|
||||||
|
``None`` to signal the caller to fall back to inline truncation.
|
||||||
|
"""
|
||||||
|
if os.path.isabs(storage_subdir) or ".." in storage_subdir:
|
||||||
|
return None
|
||||||
|
filename = _build_externalized_filename(tool_name=tool_name, tool_call_id=tool_call_id)
|
||||||
|
virtual_dir = f"{_VIRTUAL_OUTPUTS_BASE}/{storage_subdir}"
|
||||||
|
virtual_path = f"{virtual_dir}/{filename}"
|
||||||
|
try:
|
||||||
|
# AIO sandbox write_file does NOT create parent directories, so create
|
||||||
|
# them explicitly before writing. execute_command returns its stdout
|
||||||
|
# verbatim (including an "Error: ..." string on failure) rather than
|
||||||
|
# raising, so we cannot rely on exception propagation here.
|
||||||
|
sandbox.execute_command(f"mkdir -p {shlex.quote(virtual_dir)}")
|
||||||
|
sandbox.write_file(virtual_path, content)
|
||||||
|
# Validate the file landed: execute_command may have silently failed
|
||||||
|
# to create the directory, and write_file backends differ. Refuse to
|
||||||
|
# hand the model an unreadable read_file path.
|
||||||
|
check = sandbox.execute_command(f"test -s {shlex.quote(virtual_path)} && echo OK || echo MISSING")
|
||||||
|
if not isinstance(check, str) or check.strip() != "OK":
|
||||||
|
logger.warning(
|
||||||
|
"Sandbox externalize validation failed: path=%s, check=%r",
|
||||||
|
virtual_path,
|
||||||
|
check,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to externalize %s output to sandbox (call_id=%s)",
|
||||||
|
tool_name,
|
||||||
|
tool_call_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return virtual_path
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -227,6 +295,33 @@ def _resolve_outputs_path(request: ToolCallRequest) -> str | None:
|
|||||||
return outputs_path if isinstance(outputs_path, str) else None
|
return outputs_path if isinstance(outputs_path, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_sandbox(request: ToolCallRequest) -> Sandbox | None:
|
||||||
|
"""Resolve the active sandbox for the current tool call, or ``None``.
|
||||||
|
|
||||||
|
Reads the sandbox_id that ``SandboxMiddleware`` (and the sandbox tools
|
||||||
|
themselves) write into ``runtime.state["sandbox"]``. We intentionally do
|
||||||
|
NOT call ``provider.acquire`` here: acquiring a sandbox can trigger
|
||||||
|
blocking remote I/O, and this resolver runs on every tool call. Tools
|
||||||
|
that do not use a sandbox (``web_search``, MCP, ...) will return ``None``
|
||||||
|
here, which is fine -- the caller falls back to inline truncation.
|
||||||
|
"""
|
||||||
|
runtime = getattr(request, "runtime", None)
|
||||||
|
state = getattr(runtime, "state", None)
|
||||||
|
if not isinstance(state, dict):
|
||||||
|
return None
|
||||||
|
sandbox_state = state.get("sandbox")
|
||||||
|
if not isinstance(sandbox_state, dict):
|
||||||
|
return None
|
||||||
|
sandbox_id = sandbox_state.get("sandbox_id")
|
||||||
|
if not sandbox_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return get_sandbox_provider().get(sandbox_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to look up sandbox %s for tool-output externalization", sandbox_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _budget_content(
|
def _budget_content(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
@@ -234,6 +329,7 @@ def _budget_content(
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
outputs_path: str | None,
|
outputs_path: str | None,
|
||||||
config: ToolOutputConfig,
|
config: ToolOutputConfig,
|
||||||
|
sandbox: Sandbox | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Apply budget to *content*. Returns ``None`` if no change needed."""
|
"""Apply budget to *content*. Returns ``None`` if no change needed."""
|
||||||
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
|
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
|
||||||
@@ -242,14 +338,50 @@ def _budget_content(
|
|||||||
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
|
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if threshold > 0 and len(content) > threshold and outputs_path:
|
if threshold > 0 and len(content) > threshold:
|
||||||
virtual_path = _externalize(
|
virtual_path: str | None = None
|
||||||
content,
|
# Decide persistence target based on what's available, without touching
|
||||||
tool_name=tool_name,
|
# the sandbox provider unless a sandbox was actually resolved for this
|
||||||
tool_call_id=tool_call_id,
|
# call. This keeps the legacy host-disk path provider-free, so callers
|
||||||
outputs_path=outputs_path,
|
# without a configured sandbox (and CI environments without a
|
||||||
storage_subdir=config.storage_subdir,
|
# config.yaml) continue to externalize to the host as before.
|
||||||
)
|
if sandbox is not None:
|
||||||
|
provider = None
|
||||||
|
try:
|
||||||
|
provider = get_sandbox_provider()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get sandbox provider for tool-output externalization; falling back to inline truncation")
|
||||||
|
if provider is not None and getattr(provider, "uses_thread_data_mounts", False):
|
||||||
|
# Host-mounted sandbox: host outputs path is bind-mounted into
|
||||||
|
# the sandbox at the same virtual path, so writing host-side is
|
||||||
|
# equivalent. Preserve the original behavior to avoid extra
|
||||||
|
# sandbox round-trips.
|
||||||
|
if outputs_path:
|
||||||
|
virtual_path = _externalize(
|
||||||
|
content,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
outputs_path=outputs_path,
|
||||||
|
storage_subdir=config.storage_subdir,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
virtual_path = _externalize_to_sandbox(
|
||||||
|
content,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
storage_subdir=config.storage_subdir,
|
||||||
|
sandbox=sandbox,
|
||||||
|
)
|
||||||
|
elif outputs_path:
|
||||||
|
# No sandbox in this call (legacy / non-sandbox tools): write to
|
||||||
|
# host outputs path directly, no provider needed.
|
||||||
|
virtual_path = _externalize(
|
||||||
|
content,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
outputs_path=outputs_path,
|
||||||
|
storage_subdir=config.storage_subdir,
|
||||||
|
)
|
||||||
if virtual_path is not None:
|
if virtual_path is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Externalized %s output (%d chars) to %s",
|
"Externalized %s output (%d chars) to %s",
|
||||||
@@ -288,7 +420,12 @@ def _budget_content(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage:
|
def _patch_tool_message(
|
||||||
|
msg: ToolMessage,
|
||||||
|
config: ToolOutputConfig,
|
||||||
|
outputs_path: str | None,
|
||||||
|
sandbox: Sandbox | None = None,
|
||||||
|
) -> ToolMessage:
|
||||||
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
|
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
|
||||||
tool_name = msg.name or "unknown"
|
tool_name = msg.name or "unknown"
|
||||||
if tool_name in config.exempt_tools:
|
if tool_name in config.exempt_tools:
|
||||||
@@ -304,6 +441,7 @@ def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path
|
|||||||
tool_call_id=msg.tool_call_id or "",
|
tool_call_id=msg.tool_call_id or "",
|
||||||
outputs_path=outputs_path,
|
outputs_path=outputs_path,
|
||||||
config=config,
|
config=config,
|
||||||
|
sandbox=sandbox,
|
||||||
)
|
)
|
||||||
if replacement is None:
|
if replacement is None:
|
||||||
return msg
|
return msg
|
||||||
@@ -355,10 +493,15 @@ def _needs_budget(result: ToolMessage | Command, config: ToolOutputConfig) -> bo
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage | Command:
|
def _patch_result(
|
||||||
|
result: ToolMessage | Command,
|
||||||
|
config: ToolOutputConfig,
|
||||||
|
outputs_path: str | None,
|
||||||
|
sandbox: Sandbox | None = None,
|
||||||
|
) -> ToolMessage | Command:
|
||||||
"""Apply budget to a tool call result (ToolMessage or Command)."""
|
"""Apply budget to a tool call result (ToolMessage or Command)."""
|
||||||
if isinstance(result, ToolMessage):
|
if isinstance(result, ToolMessage):
|
||||||
return _patch_tool_message(result, config, outputs_path)
|
return _patch_tool_message(result, config, outputs_path, sandbox)
|
||||||
|
|
||||||
update = getattr(result, "update", None)
|
update = getattr(result, "update", None)
|
||||||
if not isinstance(update, dict):
|
if not isinstance(update, dict):
|
||||||
@@ -372,7 +515,7 @@ def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outpu
|
|||||||
changed = False
|
changed = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
patched = _patch_tool_message(msg, config, outputs_path)
|
patched = _patch_tool_message(msg, config, outputs_path, sandbox)
|
||||||
if patched is not msg:
|
if patched is not msg:
|
||||||
changed = True
|
changed = True
|
||||||
new_messages.append(patched)
|
new_messages.append(patched)
|
||||||
@@ -392,6 +535,11 @@ def _patch_model_messages(messages: list[Any], config: ToolOutputConfig) -> list
|
|||||||
ToolMessage exceeds the budget — the common case once every result has
|
ToolMessage exceeds the budget — the common case once every result has
|
||||||
already been budgeted at tool-call time, so a long history is not rebuilt
|
already been budgeted at tool-call time, so a long history is not rebuilt
|
||||||
on every model call.
|
on every model call.
|
||||||
|
|
||||||
|
Historical messages do not get a ``sandbox`` argument: any oversized tool
|
||||||
|
message in history was already budgeted (and possibly externalized) at
|
||||||
|
tool-call time, so the only thing left for the history path to do is
|
||||||
|
inline fallback truncation, which needs no sandbox.
|
||||||
"""
|
"""
|
||||||
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
|
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
|
||||||
return None
|
return None
|
||||||
@@ -442,7 +590,8 @@ class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if not _needs_budget(result, self._config):
|
if not _needs_budget(result, self._config):
|
||||||
return result
|
return result
|
||||||
outputs_path = _resolve_outputs_path(request)
|
outputs_path = _resolve_outputs_path(request)
|
||||||
return _patch_result(result, self._config, outputs_path)
|
sandbox = _resolve_sandbox(request)
|
||||||
|
return _patch_result(result, self._config, outputs_path, sandbox)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
@@ -456,7 +605,12 @@ class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if not _needs_budget(result, self._config):
|
if not _needs_budget(result, self._config):
|
||||||
return result
|
return result
|
||||||
outputs_path = _resolve_outputs_path(request)
|
outputs_path = _resolve_outputs_path(request)
|
||||||
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path)
|
# _resolve_sandbox only touches runtime.state and the provider's
|
||||||
|
# in-memory sandbox registry, so it is safe to call on the event
|
||||||
|
# loop. The actual sandbox I/O (mkdir/write/test) happens inside
|
||||||
|
# _patch_result, which is offloaded to a worker thread below.
|
||||||
|
sandbox = _resolve_sandbox(request)
|
||||||
|
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path, sandbox)
|
||||||
|
|
||||||
# -- model call hooks (historical message truncation) ------------------
|
# -- model call hooks (historical message truncation) ------------------
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from langgraph.runtime import Runtime
|
|||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.utils.file_conversion import extract_outline
|
from deerflow.utils.file_conversion import extract_outline
|
||||||
|
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY, message_content_to_text
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -265,6 +266,8 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
|
|
||||||
# Extract original content - handle both string and list formats
|
# Extract original content - handle both string and list formats
|
||||||
original_content = last_message.content
|
original_content = last_message.content
|
||||||
|
additional_kwargs = dict(last_message.additional_kwargs or {})
|
||||||
|
additional_kwargs.setdefault(ORIGINAL_USER_CONTENT_KEY, message_content_to_text(original_content))
|
||||||
if isinstance(original_content, str):
|
if isinstance(original_content, str):
|
||||||
# Simple case: string content, just prepend files message
|
# Simple case: string content, just prepend files message
|
||||||
updated_content = f"{files_message}\n\n{original_content}"
|
updated_content = f"{files_message}\n\n{original_content}"
|
||||||
@@ -285,7 +288,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
content=updated_content,
|
content=updated_content,
|
||||||
id=last_message.id,
|
id=last_message.id,
|
||||||
name=last_message.name,
|
name=last_message.name,
|
||||||
additional_kwargs=last_message.additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages[last_message_index] = updated_message
|
messages[last_message_index] = updated_message
|
||||||
|
|||||||
@@ -179,8 +179,10 @@ class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
|
|||||||
# Create the image details message with text and image content
|
# Create the image details message with text and image content
|
||||||
image_content = self._create_image_details_message(state)
|
image_content = self._create_image_details_message(state)
|
||||||
|
|
||||||
# Create a new human message with mixed content (text + images)
|
# Create a new human message with mixed content (text + images). This is
|
||||||
human_msg = HumanMessage(content=image_content)
|
# internal context for the model only, so hide it from the chat UI and IM
|
||||||
|
# channels (matches the other middleware-injected context messages).
|
||||||
|
human_msg = HumanMessage(content=image_content, additional_kwargs={"hide_from_ui": True})
|
||||||
|
|
||||||
logger.debug("Injecting image details message with images before LLM call")
|
logger.debug("Injecting image details message with images before LLM call")
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,27 @@ class ViewedImageData(TypedDict):
|
|||||||
mime_type: str
|
mime_type: str
|
||||||
|
|
||||||
|
|
||||||
|
def merge_sandbox(existing: SandboxState | None, new: SandboxState | None) -> SandboxState | None:
|
||||||
|
"""Reducer for sandbox state - accepts idempotent writes only.
|
||||||
|
|
||||||
|
Multiple sandbox tools can initialize lazily in the same graph step and
|
||||||
|
emit the same sandbox_id via Command(update=...). LangGraph needs an
|
||||||
|
explicit reducer for that shared state key. Different sandbox ids in the
|
||||||
|
same thread indicate a lifecycle/isolation bug, so fail closed instead of
|
||||||
|
choosing one silently.
|
||||||
|
"""
|
||||||
|
if new is None:
|
||||||
|
return existing
|
||||||
|
if existing is None:
|
||||||
|
return new
|
||||||
|
|
||||||
|
existing_id = existing.get("sandbox_id")
|
||||||
|
new_id = new.get("sandbox_id")
|
||||||
|
if existing_id == new_id:
|
||||||
|
return existing
|
||||||
|
raise ValueError(f"Conflicting sandbox state updates: {existing_id!r} != {new_id!r}")
|
||||||
|
|
||||||
|
|
||||||
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
||||||
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
||||||
if existing is None:
|
if existing is None:
|
||||||
@@ -58,11 +79,38 @@ def merge_todos(existing: list | None, new: list | None) -> list | None:
|
|||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
class PromotedTools(TypedDict):
|
||||||
|
catalog_hash: str
|
||||||
|
names: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_promoted(existing: PromotedTools | None, new: PromotedTools | None) -> PromotedTools | None:
|
||||||
|
"""Reducer for deferred-tool promotions, scoped by catalog hash.
|
||||||
|
|
||||||
|
- new None/empty -> preserve existing (node didn't touch promotions).
|
||||||
|
- catalog_hash changed -> replace wholesale, dropping stale names (prevents a
|
||||||
|
persisted bare name from exposing a different tool after catalog drift).
|
||||||
|
- same catalog_hash -> union names, dedupe, preserve order.
|
||||||
|
"""
|
||||||
|
if not new:
|
||||||
|
return existing
|
||||||
|
if existing is None or existing.get("catalog_hash") != new["catalog_hash"]:
|
||||||
|
return {
|
||||||
|
"catalog_hash": new["catalog_hash"],
|
||||||
|
"names": list(dict.fromkeys(new["names"])),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"catalog_hash": existing["catalog_hash"],
|
||||||
|
"names": list(dict.fromkeys(existing["names"] + new["names"])),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ThreadState(AgentState):
|
class ThreadState(AgentState):
|
||||||
sandbox: NotRequired[SandboxState | None]
|
sandbox: Annotated[NotRequired[SandboxState | None], merge_sandbox]
|
||||||
thread_data: NotRequired[ThreadDataState | None]
|
thread_data: NotRequired[ThreadDataState | None]
|
||||||
title: NotRequired[str | None]
|
title: NotRequired[str | None]
|
||||||
artifacts: Annotated[list[str], merge_artifacts]
|
artifacts: Annotated[list[str], merge_artifacts]
|
||||||
todos: Annotated[list | None, merge_todos]
|
todos: Annotated[list | None, merge_todos]
|
||||||
uploaded_files: NotRequired[list[dict] | None]
|
uploaded_files: NotRequired[list[dict] | None]
|
||||||
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
|
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
|
||||||
|
promoted: Annotated[PromotedTools | None, merge_promoted]
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.agent import _build_middlewares
|
from deerflow.agents.lead_agent.agent import build_middlewares
|
||||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.agents_config import AGENT_NAME_PATTERN
|
from deerflow.config.agents_config import AGENT_NAME_PATTERN
|
||||||
@@ -43,6 +43,7 @@ from deerflow.config.paths import get_paths
|
|||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
|
from deerflow.tools.builtins.tool_search import assemble_deferred_tools
|
||||||
from deerflow.tracing import build_tracing_callbacks, inject_langfuse_metadata
|
from deerflow.tracing import build_tracing_callbacks, inject_langfuse_metadata
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
claim_unique_filename,
|
claim_unique_filename,
|
||||||
@@ -237,19 +238,30 @@ class DeerFlowClient:
|
|||||||
subagent_enabled = cfg.get("subagent_enabled", False)
|
subagent_enabled = cfg.get("subagent_enabled", False)
|
||||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
||||||
|
|
||||||
|
tools = self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||||
|
final_tools, deferred_setup = assemble_deferred_tools(tools, enabled=self._app_config.tool_search.enabled)
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
# attach_tracing=False because ``stream()`` injects tracing
|
# attach_tracing=False because ``stream()`` injects tracing
|
||||||
# callbacks at the graph invocation root so a single embedded run
|
# callbacks at the graph invocation root so a single embedded run
|
||||||
# produces one trace with correct session_id / user_id propagation.
|
# produces one trace with correct session_id / user_id propagation.
|
||||||
# Attaching them again on the model would emit duplicate spans.
|
# Attaching them again on the model would emit duplicate spans.
|
||||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False),
|
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False),
|
||||||
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
"tools": final_tools,
|
||||||
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
|
"middleware": build_middlewares(
|
||||||
|
config,
|
||||||
|
model_name=model_name,
|
||||||
|
agent_name=self._agent_name,
|
||||||
|
available_skills=self._available_skills,
|
||||||
|
custom_middlewares=self._middlewares,
|
||||||
|
app_config=self._app_config,
|
||||||
|
deferred_setup=deferred_setup,
|
||||||
|
),
|
||||||
"system_prompt": apply_prompt_template(
|
"system_prompt": apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
agent_name=self._agent_name,
|
agent_name=self._agent_name,
|
||||||
available_skills=self._available_skills,
|
available_skills=self._available_skills,
|
||||||
|
deferred_names=deferred_setup.deferred_names,
|
||||||
),
|
),
|
||||||
"state_schema": ThreadState,
|
"state_schema": ThreadState,
|
||||||
}
|
}
|
||||||
@@ -1129,6 +1141,7 @@ class DeerFlowClient:
|
|||||||
"fact_confidence_threshold": config.fact_confidence_threshold,
|
"fact_confidence_threshold": config.fact_confidence_threshold,
|
||||||
"injection_enabled": config.injection_enabled,
|
"injection_enabled": config.injection_enabled,
|
||||||
"max_injection_tokens": config.max_injection_tokens,
|
"max_injection_tokens": config.max_injection_tokens,
|
||||||
|
"token_counting": config.token_counting,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_memory_status(self) -> dict:
|
def get_memory_status(self) -> dict:
|
||||||
@@ -1206,7 +1219,7 @@ class DeerFlowClient:
|
|||||||
|
|
||||||
info: dict[str, Any] = {
|
info: dict[str, Any] = {
|
||||||
"filename": dest_name,
|
"filename": dest_name,
|
||||||
"size": str(dest.stat().st_size),
|
"size": dest.stat().st_size,
|
||||||
"path": str(dest),
|
"path": str(dest),
|
||||||
"virtual_path": upload_virtual_path(dest_name),
|
"virtual_path": upload_virtual_path(dest_name),
|
||||||
"artifact_url": upload_artifact_url(thread_id, dest_name),
|
"artifact_url": upload_artifact_url(thread_id, dest_name),
|
||||||
|
|||||||
@@ -39,11 +39,63 @@ class AioSandbox(Sandbox):
|
|||||||
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
||||||
self._home_dir = home_dir
|
self._home_dir = home_dir
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_url(self) -> str:
|
def base_url(self) -> str:
|
||||||
return self._base_url
|
return self._base_url
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Best-effort close of the host-side HTTP client owned by this sandbox.
|
||||||
|
|
||||||
|
The agent_sandbox SDK is Fern-generated and exposes no ``close()`` /
|
||||||
|
``__exit__``, so we reach the socket-owning ``httpx.Client`` explicitly
|
||||||
|
through its attribute chain::
|
||||||
|
|
||||||
|
Sandbox._client_wrapper -> SyncClientWrapper
|
||||||
|
.httpx_client -> Fern HttpClient (a wrapper, NOT httpx.Client)
|
||||||
|
.httpx_client -> httpx.Client <- the real socket owner
|
||||||
|
|
||||||
|
Closing it releases pooled sockets so long-running provider lifecycles
|
||||||
|
do not accumulate unreclaimed host-side resources (#2872).
|
||||||
|
|
||||||
|
Resolution is most-specific-first with graceful degradation: if a future
|
||||||
|
SDK adds a top-level ``Sandbox.close()`` it is picked up automatically
|
||||||
|
without changing this code. Idempotent, thread-safe, and non-fatal:
|
||||||
|
failures during teardown are logged and swallowed so provider/backend
|
||||||
|
cleanup is never blocked.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self._closed = True
|
||||||
|
client = self._client
|
||||||
|
# Drop the reference under the lock for use-after-close safety: any
|
||||||
|
# later command on this instance fails loudly instead of reusing a
|
||||||
|
# half-closed client.
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
if client is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Walk from the real httpx.Client up to the top-level client, picking the
|
||||||
|
# first object that actually exposes close().
|
||||||
|
wrapper = getattr(client, "_client_wrapper", None)
|
||||||
|
fern_http = getattr(wrapper, "httpx_client", None)
|
||||||
|
real_httpx = getattr(fern_http, "httpx_client", None)
|
||||||
|
target = next(
|
||||||
|
(c for c in (real_httpx, fern_http, client) if c is not None and hasattr(c, "close")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if target is None:
|
||||||
|
logger.debug("AioSandbox %s: no closable client found, nothing to release", self.id)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
target.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing AioSandbox client for {self.id}: {e}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def home_dir(self) -> str:
|
def home_dir(self) -> str:
|
||||||
"""Get the home directory inside the sandbox."""
|
"""Get the home directory inside the sandbox."""
|
||||||
|
|||||||
@@ -470,14 +470,32 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
|
|
||||||
existing_id = self._thread_sandboxes[thread_id]
|
existing_id = self._thread_sandboxes[thread_id]
|
||||||
if existing_id in self._sandboxes:
|
if existing_id in self._sandboxes:
|
||||||
suffix = " (post-lock check)" if post_lock else ""
|
info = self._sandbox_infos.get(existing_id)
|
||||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
else:
|
||||||
self._last_activity[existing_id] = time.time()
|
del self._thread_sandboxes[thread_id]
|
||||||
return existing_id
|
return None
|
||||||
|
|
||||||
del self._thread_sandboxes[thread_id]
|
alive = self._check_tracked_sandbox_alive(existing_id, info) if info is not None else True
|
||||||
|
if alive is False:
|
||||||
|
self._drop_unhealthy_sandbox(
|
||||||
|
existing_id,
|
||||||
|
"in-process cache failed health check",
|
||||||
|
expected_info=info,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self._thread_sandboxes.get(thread_id) != existing_id:
|
||||||
|
return None
|
||||||
|
if existing_id not in self._sandboxes:
|
||||||
|
self._thread_sandboxes.pop(thread_id, None)
|
||||||
|
return None
|
||||||
|
|
||||||
|
suffix = " (post-lock check)" if post_lock else ""
|
||||||
|
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
||||||
|
self._last_activity[existing_id] = time.time()
|
||||||
|
return existing_id
|
||||||
|
|
||||||
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
|
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
|
||||||
"""Promote a warm-pool sandbox back to active tracking if available."""
|
"""Promote a warm-pool sandbox back to active tracking if available."""
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
@@ -487,7 +505,22 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
if sandbox_id not in self._warm_pool:
|
if sandbox_id not in self._warm_pool:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
info, _ = self._warm_pool.pop(sandbox_id)
|
info, _ = self._warm_pool[sandbox_id]
|
||||||
|
|
||||||
|
alive = self._check_tracked_sandbox_alive(sandbox_id, info)
|
||||||
|
if alive is False:
|
||||||
|
self._drop_unhealthy_sandbox(
|
||||||
|
sandbox_id,
|
||||||
|
"warm-pool cache failed health check",
|
||||||
|
expected_info=info,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
warm_item = self._warm_pool.pop(sandbox_id, None)
|
||||||
|
if warm_item is None:
|
||||||
|
return None
|
||||||
|
info, _ = warm_item
|
||||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||||
self._sandboxes[sandbox_id] = sandbox
|
self._sandboxes[sandbox_id] = sandbox
|
||||||
self._sandbox_infos[sandbox_id] = info
|
self._sandbox_infos[sandbox_id] = info
|
||||||
@@ -527,6 +560,70 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||||
return sandbox_id
|
return sandbox_id
|
||||||
|
|
||||||
|
def _check_tracked_sandbox_alive(self, sandbox_id: str, info: SandboxInfo) -> bool | None:
|
||||||
|
"""Return whether a tracked sandbox appears alive, or None if unknown."""
|
||||||
|
try:
|
||||||
|
return self._backend.is_alive(info)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to check sandbox {sandbox_id} health: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _remove_tracked_sandbox(
|
||||||
|
self,
|
||||||
|
sandbox_id: str,
|
||||||
|
*,
|
||||||
|
expected_info: SandboxInfo | None = None,
|
||||||
|
) -> tuple[Sandbox | None, SandboxInfo | None, bool]:
|
||||||
|
"""Remove a sandbox from in-process tracking maps.
|
||||||
|
|
||||||
|
When expected_info is provided, removal only happens if the currently
|
||||||
|
tracked active or warm-pool entry is the exact info object that was
|
||||||
|
checked. This prevents a stale health-check result from deleting a
|
||||||
|
freshly recreated sandbox with the same deterministic id.
|
||||||
|
"""
|
||||||
|
thread_ids_to_remove: list[str] = []
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
active_info = self._sandbox_infos.get(sandbox_id)
|
||||||
|
warm_item = self._warm_pool.get(sandbox_id)
|
||||||
|
warm_info = warm_item[0] if warm_item is not None else None
|
||||||
|
if expected_info is not None and active_info is not expected_info and warm_info is not expected_info:
|
||||||
|
return None, None, False
|
||||||
|
|
||||||
|
sandbox = self._sandboxes.pop(sandbox_id, None)
|
||||||
|
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||||
|
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||||
|
for tid in thread_ids_to_remove:
|
||||||
|
del self._thread_sandboxes[tid]
|
||||||
|
self._last_activity.pop(sandbox_id, None)
|
||||||
|
if info is None and sandbox_id in self._warm_pool:
|
||||||
|
info, _ = self._warm_pool.pop(sandbox_id)
|
||||||
|
else:
|
||||||
|
self._warm_pool.pop(sandbox_id, None)
|
||||||
|
|
||||||
|
return sandbox, info, True
|
||||||
|
|
||||||
|
def _drop_unhealthy_sandbox(self, sandbox_id: str, reason: str, *, expected_info: SandboxInfo | None = None) -> None:
|
||||||
|
"""Remove and destroy a sandbox after a definitive failed health check."""
|
||||||
|
sandbox, info, removed = self._remove_tracked_sandbox(sandbox_id, expected_info=expected_info)
|
||||||
|
if not removed:
|
||||||
|
logger.info(f"Skipped dropping sandbox {sandbox_id}: tracked info changed after health check")
|
||||||
|
return
|
||||||
|
|
||||||
|
if sandbox is not None:
|
||||||
|
try:
|
||||||
|
sandbox.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing unhealthy sandbox {sandbox_id}: {e}")
|
||||||
|
|
||||||
|
if info is not None:
|
||||||
|
try:
|
||||||
|
self._backend.destroy(info)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error destroying unhealthy sandbox {sandbox_id}: {e}")
|
||||||
|
|
||||||
|
logger.warning(f"Dropped unhealthy sandbox {sandbox_id}: {reason}")
|
||||||
|
|
||||||
def _replica_count(self) -> tuple[int, int]:
|
def _replica_count(self) -> tuple[int, int]:
|
||||||
"""Return configured replicas and currently tracked sandbox count."""
|
"""Return configured replicas and currently tracked sandbox count."""
|
||||||
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
||||||
@@ -617,7 +714,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
|
|
||||||
async def _acquire_internal_async(self, thread_id: str | None) -> str:
|
async def _acquire_internal_async(self, thread_id: str | None) -> str:
|
||||||
"""Async counterpart to ``_acquire_internal``."""
|
"""Async counterpart to ``_acquire_internal``."""
|
||||||
cached_id = self._reuse_in_process_sandbox(thread_id)
|
cached_id = await asyncio.to_thread(self._reuse_in_process_sandbox, thread_id)
|
||||||
if cached_id is not None:
|
if cached_id is not None:
|
||||||
return cached_id
|
return cached_id
|
||||||
|
|
||||||
@@ -625,7 +722,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
sandbox_id = self._sandbox_id_for_thread(thread_id)
|
sandbox_id = self._sandbox_id_for_thread(thread_id)
|
||||||
|
|
||||||
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
|
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
|
||||||
reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
|
reclaimed_id = await asyncio.to_thread(self._reclaim_warm_pool_sandbox, thread_id, sandbox_id)
|
||||||
if reclaimed_id is not None:
|
if reclaimed_id is not None:
|
||||||
return reclaimed_id
|
return reclaimed_id
|
||||||
|
|
||||||
@@ -681,7 +778,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
locked = True
|
locked = True
|
||||||
# Re-check in-process caches under the file lock in case another
|
# Re-check in-process caches under the file lock in case another
|
||||||
# thread in this process won the race while we were waiting.
|
# thread in this process won the race while we were waiting.
|
||||||
cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
|
cached_id = await asyncio.to_thread(self._recheck_cached_sandbox, thread_id, sandbox_id)
|
||||||
if cached_id is not None:
|
if cached_id is not None:
|
||||||
return cached_id
|
return cached_id
|
||||||
|
|
||||||
@@ -790,14 +887,20 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
thread on its next turn without a cold-start. The container will only be
|
thread on its next turn without a cold-start. The container will only be
|
||||||
stopped when the replicas limit forces eviction or during shutdown.
|
stopped when the replicas limit forces eviction or during shutdown.
|
||||||
|
|
||||||
|
The host-side HTTP client owned by the cached ``AioSandbox`` instance is
|
||||||
|
closed before the instance is dropped (#2872). The warm-pool entry only
|
||||||
|
stores ``SandboxInfo``, so a fresh ``AioSandbox`` (and a fresh client)
|
||||||
|
is constructed if the container is later reclaimed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sandbox_id: The ID of the sandbox to release.
|
sandbox_id: The ID of the sandbox to release.
|
||||||
"""
|
"""
|
||||||
info = None
|
info = None
|
||||||
|
sandbox = None
|
||||||
thread_ids_to_remove: list[str] = []
|
thread_ids_to_remove: list[str] = []
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._sandboxes.pop(sandbox_id, None)
|
sandbox = self._sandboxes.pop(sandbox_id, None)
|
||||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||||
for tid in thread_ids_to_remove:
|
for tid in thread_ids_to_remove:
|
||||||
@@ -807,6 +910,15 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
if info and sandbox_id not in self._warm_pool:
|
if info and sandbox_id not in self._warm_pool:
|
||||||
self._warm_pool[sandbox_id] = (info, time.time())
|
self._warm_pool[sandbox_id] = (info, time.time())
|
||||||
|
|
||||||
|
if sandbox is not None:
|
||||||
|
# Defense-in-depth: close() already swallows its own errors; this
|
||||||
|
# guard only protects against a future close() that misbehaves, so
|
||||||
|
# host-side client cleanup can never block parking in the warm pool.
|
||||||
|
try:
|
||||||
|
sandbox.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing sandbox {sandbox_id} during release: {e}")
|
||||||
|
|
||||||
logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")
|
logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")
|
||||||
|
|
||||||
def destroy(self, sandbox_id: str) -> None:
|
def destroy(self, sandbox_id: str) -> None:
|
||||||
@@ -815,24 +927,23 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
Unlike release(), this actually stops the container. Use this for
|
Unlike release(), this actually stops the container. Use this for
|
||||||
explicit cleanup, capacity-driven eviction, or shutdown.
|
explicit cleanup, capacity-driven eviction, or shutdown.
|
||||||
|
|
||||||
|
The host-side HTTP client owned by the cached ``AioSandbox`` instance is
|
||||||
|
closed alongside backend/container destruction so no client/socket
|
||||||
|
resources leak (#2872).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sandbox_id: The ID of the sandbox to destroy.
|
sandbox_id: The ID of the sandbox to destroy.
|
||||||
"""
|
"""
|
||||||
info = None
|
sandbox, info, _ = self._remove_tracked_sandbox(sandbox_id)
|
||||||
thread_ids_to_remove: list[str] = []
|
|
||||||
|
|
||||||
with self._lock:
|
if sandbox is not None:
|
||||||
self._sandboxes.pop(sandbox_id, None)
|
# Defense-in-depth: close() already swallows its own errors; this
|
||||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
# guard only protects against a future close() that misbehaves, so
|
||||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
# host-side client cleanup can never block container destruction.
|
||||||
for tid in thread_ids_to_remove:
|
try:
|
||||||
del self._thread_sandboxes[tid]
|
sandbox.close()
|
||||||
self._last_activity.pop(sandbox_id, None)
|
except Exception as e:
|
||||||
# Also pull from warm pool if it was parked there
|
logger.warning(f"Error closing sandbox {sandbox_id} during destroy: {e}")
|
||||||
if info is None and sandbox_id in self._warm_pool:
|
|
||||||
info, _ = self._warm_pool.pop(sandbox_id)
|
|
||||||
else:
|
|
||||||
self._warm_pool.pop(sandbox_id, None)
|
|
||||||
|
|
||||||
if info:
|
if info:
|
||||||
self._backend.destroy(info)
|
self._backend.destroy(info)
|
||||||
|
|||||||
@@ -169,6 +169,24 @@ def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str |
|
|||||||
return "0.0.0.0"
|
return "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_no_such_container_error(stderr: str, container_name: str) -> bool:
|
||||||
|
"""Return True only when stderr definitively says the container does not exist.
|
||||||
|
|
||||||
|
Docker reports "No such object" / "No such container". Apple Container
|
||||||
|
reports a generic "not found", so that phrase is only trusted when the
|
||||||
|
message also names the inspected container (or refers to a
|
||||||
|
container/object); transient failures whose text happens to contain
|
||||||
|
"not found" (e.g. "command not found", "context not found") must stay on
|
||||||
|
the raise path instead of being misread as a dead container.
|
||||||
|
"""
|
||||||
|
message = stderr.lower()
|
||||||
|
if "no such object" in message or "no such container" in message:
|
||||||
|
return True
|
||||||
|
if "not found" not in message:
|
||||||
|
return False
|
||||||
|
return container_name.lower() in message or "container" in message or "object" in message
|
||||||
|
|
||||||
|
|
||||||
class LocalContainerBackend(SandboxBackend):
|
class LocalContainerBackend(SandboxBackend):
|
||||||
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
||||||
|
|
||||||
@@ -335,11 +353,21 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
sandbox_id: The deterministic sandbox ID (determines container name).
|
sandbox_id: The deterministic sandbox ID (determines container name).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SandboxInfo if container found and healthy, None otherwise.
|
SandboxInfo if container found and healthy, None otherwise. A
|
||||||
|
failed runtime check (e.g. transient daemon error) also returns
|
||||||
|
None — discovery must not adopt a container it cannot verify, and
|
||||||
|
falling through to create keeps acquire recoverable instead of
|
||||||
|
hard-failing on a hiccup.
|
||||||
"""
|
"""
|
||||||
container_name = f"{self._container_prefix}-{sandbox_id}"
|
container_name = f"{self._container_prefix}-{sandbox_id}"
|
||||||
|
|
||||||
if not self._is_container_running(container_name):
|
try:
|
||||||
|
running = self._is_container_running(container_name)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Could not verify container {container_name} during discovery; not adopting it: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not running:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
port = self._get_container_port(container_name)
|
port = self._get_container_port(container_name)
|
||||||
@@ -582,6 +610,13 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
|
|
||||||
This enables cross-process container discovery — any process can detect
|
This enables cross-process container discovery — any process can detect
|
||||||
containers started by another process via the deterministic container name.
|
containers started by another process via the deterministic container name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the container runtime cannot answer the inspect
|
||||||
|
query. A failed check is intentionally distinct from a
|
||||||
|
definitive "container does not exist" result so callers do not
|
||||||
|
destroy healthy containers during transient Docker/Container
|
||||||
|
daemon failures.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
@@ -590,9 +625,14 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
text=True,
|
text=True,
|
||||||
timeout=5,
|
timeout=5,
|
||||||
)
|
)
|
||||||
return result.returncode == 0 and result.stdout.strip().lower() == "true"
|
except subprocess.TimeoutExpired as exc:
|
||||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
|
raise RuntimeError(f"Timed out checking container {container_name}") from exc
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
return result.stdout.strip().lower() == "true"
|
||||||
|
if _is_no_such_container_error(result.stderr, container_name):
|
||||||
return False
|
return False
|
||||||
|
raise RuntimeError(f"Failed to inspect container {container_name}: {result.stderr.strip()}")
|
||||||
|
|
||||||
def _get_container_port(self, container_name: str) -> int | None:
|
def _get_container_port(self, container_name: str) -> int | None:
|
||||||
"""Get the host port of a running container.
|
"""Get the host port of a running container.
|
||||||
|
|||||||
@@ -176,12 +176,16 @@ class RemoteSandboxBackend(SandboxBackend):
|
|||||||
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
if resp.ok:
|
except requests.RequestException as exc:
|
||||||
data = resp.json()
|
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: {exc}") from exc
|
||||||
return data.get("status") == "Running"
|
|
||||||
return False
|
if resp.status_code == 404:
|
||||||
except requests.RequestException:
|
|
||||||
return False
|
return False
|
||||||
|
if not resp.ok:
|
||||||
|
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: HTTP {resp.status_code} {resp.text}")
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
return data.get("status") == "Running"
|
||||||
|
|
||||||
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
|
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||||
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
|
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .tools import web_search_tool
|
||||||
|
|
||||||
|
__all__ = ["web_search_tool"]
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
Web Search Tool - Search the web using the Brave Search API.
|
||||||
|
|
||||||
|
Brave Search provides web results from an independent search index via a
|
||||||
|
REST API. An API key is required. Sign up at https://brave.com/search/api/
|
||||||
|
to get one.
|
||||||
|
|
||||||
|
Unlike the DuckDuckGo ``backend: brave`` option (which scrapes results via the
|
||||||
|
DDGS aggregator), this provider calls the official Brave Search API directly,
|
||||||
|
giving structured results, authenticated quota, and a documented SLA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
|
||||||
|
_DEFAULT_MAX_RESULTS = 5
|
||||||
|
# Brave Search API caps the `count` parameter at 20 results per request.
|
||||||
|
_BRAVE_MAX_COUNT = 20
|
||||||
|
_api_key_warned = False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_key() -> str | None:
|
||||||
|
config = get_app_config().get_tool_config("web_search")
|
||||||
|
if config is not None:
|
||||||
|
api_key = (config.model_extra or {}).get("api_key")
|
||||||
|
if isinstance(api_key, str) and api_key.strip():
|
||||||
|
return api_key
|
||||||
|
return os.getenv("BRAVE_SEARCH_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_max_results(value: object, *, default: int = _DEFAULT_MAX_RESULTS) -> int:
|
||||||
|
try:
|
||||||
|
coerced = int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.warning(
|
||||||
|
"Invalid Brave Search max_results=%r; using default %s",
|
||||||
|
value,
|
||||||
|
default,
|
||||||
|
)
|
||||||
|
coerced = default
|
||||||
|
|
||||||
|
return max(1, min(coerced, _BRAVE_MAX_COUNT))
|
||||||
|
|
||||||
|
|
||||||
|
@tool("web_search", parse_docstring=True)
|
||||||
|
def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||||
|
"""Search the web for information using Brave Search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search keywords describing what you want to find. Be specific for better results.
|
||||||
|
max_results: Maximum number of search results to return. Default is 5.
|
||||||
|
"""
|
||||||
|
global _api_key_warned
|
||||||
|
|
||||||
|
config = get_app_config().get_tool_config("web_search")
|
||||||
|
if config is not None and "max_results" in (config.model_extra or {}):
|
||||||
|
max_results = config.model_extra["max_results"]
|
||||||
|
|
||||||
|
count = _coerce_max_results(max_results)
|
||||||
|
|
||||||
|
api_key = _get_api_key()
|
||||||
|
if not api_key:
|
||||||
|
if not _api_key_warned:
|
||||||
|
_api_key_warned = True
|
||||||
|
logger.warning("Brave Search API key is not set. Set BRAVE_SEARCH_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://brave.com/search/api/")
|
||||||
|
return json.dumps(
|
||||||
|
{"error": "BRAVE_SEARCH_API_KEY is not configured", "query": query},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"X-Subscription-Token": api_key,
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
params = {"q": query, "count": count, "text_decorations": False}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=30) as client:
|
||||||
|
response = client.get(_BRAVE_ENDPOINT, headers=headers, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"Brave Search API returned HTTP {e.response.status_code}: {e.response.text}")
|
||||||
|
return json.dumps(
|
||||||
|
{"error": f"Brave Search API error: HTTP {e.response.status_code}", "query": query},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Brave search failed: {type(e).__name__}: {e}")
|
||||||
|
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||||
|
|
||||||
|
web_results = (data.get("web") or {}).get("results", [])
|
||||||
|
if not web_results:
|
||||||
|
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
|
||||||
|
|
||||||
|
normalized_results = [
|
||||||
|
{
|
||||||
|
"title": r.get("title", ""),
|
||||||
|
"url": r.get("url", ""),
|
||||||
|
"content": r.get("description", ""),
|
||||||
|
}
|
||||||
|
for r in web_results
|
||||||
|
]
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"query": query,
|
||||||
|
"total_results": len(normalized_results),
|
||||||
|
"results": normalized_results,
|
||||||
|
}
|
||||||
|
return json.dumps(output, indent=2, ensure_ascii=False)
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from .browserless_client import BrowserlessClient
|
||||||
|
from .tools import web_fetch_tool
|
||||||
|
|
||||||
|
__all__ = ["BrowserlessClient", "web_fetch_tool"]
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserlessClient:
|
||||||
|
"""Client for Browserless headless Chrome API."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, token: str = "", timeout_s: float = 30) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.token = token
|
||||||
|
self.timeout_s = timeout_s
|
||||||
|
|
||||||
|
async def fetch_html(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
wait_for_event: str = "",
|
||||||
|
wait_for_timeout_ms: int = 0,
|
||||||
|
wait_for_selector: str = "",
|
||||||
|
wait_for_selector_timeout_ms: int = 5000,
|
||||||
|
reject_resource_types: list[str] | None = None,
|
||||||
|
reject_request_pattern: list[str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Fetch the rendered HTML of a page using Browserless.
|
||||||
|
|
||||||
|
Only sends accepted parameters for the current Browserless API version.
|
||||||
|
Sets a default navigation timeout (30s) via query param.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to fetch.
|
||||||
|
wait_for_event: Wait for a page event (e.g. "networkidle", "load").
|
||||||
|
wait_for_timeout_ms: Extra wait after page load.
|
||||||
|
wait_for_selector: CSS selector to wait for.
|
||||||
|
wait_for_selector_timeout_ms: Timeout for selector wait.
|
||||||
|
reject_resource_types: Resource types to block (e.g. ["image"]).
|
||||||
|
reject_request_pattern: URL patterns to block.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered HTML content.
|
||||||
|
"""
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"url": url,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.token:
|
||||||
|
payload["token"] = self.token
|
||||||
|
if wait_for_event:
|
||||||
|
payload["waitForEvent"] = wait_for_event
|
||||||
|
if wait_for_timeout_ms > 0:
|
||||||
|
payload["waitForTimeout"] = wait_for_timeout_ms
|
||||||
|
if wait_for_selector:
|
||||||
|
payload["waitForSelector"] = {
|
||||||
|
"selector": wait_for_selector,
|
||||||
|
"timeout": wait_for_selector_timeout_ms,
|
||||||
|
}
|
||||||
|
if reject_resource_types:
|
||||||
|
payload["rejectResourceTypes"] = reject_resource_types
|
||||||
|
if reject_request_pattern:
|
||||||
|
payload["rejectRequestPattern"] = reject_request_pattern
|
||||||
|
|
||||||
|
logger.debug(f"Fetching URL via Browserless: {url}")
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout_s) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{self.base_url}/content",
|
||||||
|
json=payload,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
code = resp.status_code
|
||||||
|
target_code = resp.headers.get("X-Response-Code", "")
|
||||||
|
target_status = resp.headers.get("X-Response-Status", "")
|
||||||
|
|
||||||
|
logger.debug(f"Browserless response: code={code}, target_code={target_code}, target_status={target_status}")
|
||||||
|
|
||||||
|
if code != 200:
|
||||||
|
return f"Error: Browserless HTTP {code}: {resp.text[:200]}"
|
||||||
|
|
||||||
|
html = resp.text
|
||||||
|
if not html or not html.strip():
|
||||||
|
return "Error: Browserless returned empty response"
|
||||||
|
|
||||||
|
return html
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return f"Error: Browserless request timed out after {self.timeout_s}s"
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Browserless request failed: {e}")
|
||||||
|
return f"Error: Browserless request failed: {e!s}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Browserless fetch failed: {e}")
|
||||||
|
return f"Error: Browserless fetch failed: {e!s}"
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
from deerflow.utils.readability import ReadabilityExtractor
|
||||||
|
|
||||||
|
from .browserless_client import BrowserlessClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# readability_extractor runs CPU-bound parsing; always call via asyncio.to_thread
|
||||||
|
_readability_extractor = ReadabilityExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tool_config(tool_name: str) -> dict | None:
|
||||||
|
"""Get tool config extras safely, returning None if not configured."""
|
||||||
|
config = get_app_config().get_tool_config(tool_name)
|
||||||
|
if config is None:
|
||||||
|
return None
|
||||||
|
extras = config.model_extra
|
||||||
|
return extras if extras is not None else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_browserless_client() -> BrowserlessClient:
|
||||||
|
cfg = _get_tool_config("web_fetch")
|
||||||
|
base_url = "http://localhost:3032"
|
||||||
|
token = ""
|
||||||
|
timeout_s = 30.0
|
||||||
|
if cfg is not None:
|
||||||
|
base_url = cfg.get("base_url", base_url)
|
||||||
|
token = cfg.get("token", token)
|
||||||
|
raw = cfg.get("timeout_s", timeout_s)
|
||||||
|
timeout_s = float(raw) if not isinstance(raw, float) else raw
|
||||||
|
return BrowserlessClient(base_url=base_url, token=token, timeout_s=timeout_s)
|
||||||
|
|
||||||
|
|
||||||
|
@tool("web_fetch", parse_docstring=True)
|
||||||
|
async def web_fetch_tool(url: str) -> str:
|
||||||
|
"""Fetch the contents of a web page at a given URL using Browserless (headless Chrome).
|
||||||
|
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||||
|
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||||
|
Do NOT add www. to URLs that do NOT have them.
|
||||||
|
URLs must include the schema: https://example.com is a valid URL while example.com is an invalid URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to fetch the contents of.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cfg = _get_tool_config("web_fetch")
|
||||||
|
|
||||||
|
wait_for_event = ""
|
||||||
|
wait_for_timeout_ms = 0
|
||||||
|
wait_for_selector = ""
|
||||||
|
wait_for_selector_timeout_ms = 5000
|
||||||
|
reject_resource_types: list[str] | None = None
|
||||||
|
reject_request_pattern: list[str] | None = None
|
||||||
|
|
||||||
|
if cfg is not None:
|
||||||
|
wait_for_event = cfg.get("wait_for_event", wait_for_event)
|
||||||
|
raw_wait = cfg.get("wait_for_timeout_ms", wait_for_timeout_ms)
|
||||||
|
wait_for_timeout_ms = int(raw_wait) if not isinstance(raw_wait, int) else raw_wait
|
||||||
|
wait_for_selector = cfg.get("wait_for_selector", wait_for_selector)
|
||||||
|
|
||||||
|
client = _get_browserless_client()
|
||||||
|
html = await client.fetch_html(
|
||||||
|
url=url,
|
||||||
|
wait_for_event=wait_for_event,
|
||||||
|
wait_for_timeout_ms=wait_for_timeout_ms,
|
||||||
|
wait_for_selector=wait_for_selector,
|
||||||
|
wait_for_selector_timeout_ms=wait_for_selector_timeout_ms,
|
||||||
|
reject_resource_types=reject_resource_types,
|
||||||
|
reject_request_pattern=reject_request_pattern,
|
||||||
|
)
|
||||||
|
|
||||||
|
if html.startswith("Error:"):
|
||||||
|
return html
|
||||||
|
|
||||||
|
article = await asyncio.to_thread(_readability_extractor.extract_article, html)
|
||||||
|
return article.to_markdown()[:4096]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in web_fetch_tool: {e}")
|
||||||
|
return f"Error: {str(e)}"
|
||||||
@@ -11,12 +11,85 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_BACKEND = "auto"
|
||||||
|
DEFAULT_REGION = "wt-wt"
|
||||||
|
DEFAULT_SAFESEARCH = "moderate"
|
||||||
|
DEFAULT_WIKIPEDIA_REGION = "us-en"
|
||||||
|
|
||||||
|
WIKIPEDIA_BACKENDS = {"auto", "all", "wikipedia"}
|
||||||
|
WIKIPEDIA_LANGUAGE_ALIASES = {
|
||||||
|
"jp": "ja",
|
||||||
|
"kr": "ko",
|
||||||
|
"tzh": "zh",
|
||||||
|
"wt": "en",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_backend(backend: str | list[str] | tuple[str, ...] | None) -> str:
|
||||||
|
if backend is None:
|
||||||
|
return DEFAULT_BACKEND
|
||||||
|
if isinstance(backend, (list, tuple)):
|
||||||
|
return ",".join(str(part).strip() for part in backend if str(part).strip()) or DEFAULT_BACKEND
|
||||||
|
return str(backend).strip() or DEFAULT_BACKEND
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_setting(value: str | None, default: str) -> str:
|
||||||
|
return str(value).strip() if value else default
|
||||||
|
|
||||||
|
|
||||||
|
def _backend_includes_wikipedia(backend: str | list[str] | tuple[str, ...] | None) -> bool:
|
||||||
|
backend = _normalize_backend(backend)
|
||||||
|
return any(part.strip().lower() in WIKIPEDIA_BACKENDS for part in backend.split(","))
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_codepoint(query: str, ranges: tuple[tuple[int, int], ...]) -> bool:
|
||||||
|
return any(start <= ord(char) <= end for char in query for start, end in ranges)
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_wikipedia_region(query: str) -> str:
|
||||||
|
"""Pick a valid Wikipedia language region when DDGS' worldwide region is used."""
|
||||||
|
if _contains_codepoint(query, ((0x3040, 0x30FF), (0x31F0, 0x31FF))):
|
||||||
|
return "jp-ja"
|
||||||
|
if _contains_codepoint(query, ((0xAC00, 0xD7AF), (0x1100, 0x11FF), (0x3130, 0x318F))):
|
||||||
|
return "kr-ko"
|
||||||
|
if _contains_codepoint(query, ((0x3400, 0x9FFF),)):
|
||||||
|
return "cn-zh"
|
||||||
|
if _contains_codepoint(query, ((0x0400, 0x04FF),)):
|
||||||
|
return "ru-ru"
|
||||||
|
if _contains_codepoint(query, ((0x0370, 0x03FF),)):
|
||||||
|
return "gr-el"
|
||||||
|
if _contains_codepoint(query, ((0x0590, 0x05FF),)):
|
||||||
|
return "il-he"
|
||||||
|
if _contains_codepoint(query, ((0x0600, 0x06FF),)):
|
||||||
|
return "xa-ar"
|
||||||
|
return DEFAULT_WIKIPEDIA_REGION
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_ddgs_region(query: str, region: str | None, backend: str | list[str] | tuple[str, ...] | None) -> str:
|
||||||
|
"""
|
||||||
|
DDGS' wikipedia engine treats the second part of region as a Wikipedia
|
||||||
|
subdomain. Its default worldwide region, wt-wt, becomes wt.wikipedia.org.
|
||||||
|
"""
|
||||||
|
normalized_region = _normalize_setting(region, DEFAULT_REGION).lower()
|
||||||
|
if not _backend_includes_wikipedia(backend):
|
||||||
|
return normalized_region
|
||||||
|
|
||||||
|
if normalized_region == DEFAULT_REGION:
|
||||||
|
return _infer_wikipedia_region(query)
|
||||||
|
|
||||||
|
if "-" not in normalized_region:
|
||||||
|
return DEFAULT_WIKIPEDIA_REGION
|
||||||
|
|
||||||
|
country, language = normalized_region.split("-", 1)
|
||||||
|
return f"{country}-{WIKIPEDIA_LANGUAGE_ALIASES.get(language, language)}"
|
||||||
|
|
||||||
|
|
||||||
def _search_text(
|
def _search_text(
|
||||||
query: str,
|
query: str,
|
||||||
max_results: int = 5,
|
max_results: int = 5,
|
||||||
region: str = "wt-wt",
|
region: str | None = DEFAULT_REGION,
|
||||||
safesearch: str = "moderate",
|
safesearch: str | None = DEFAULT_SAFESEARCH,
|
||||||
|
backend: str | list[str] | tuple[str, ...] | None = DEFAULT_BACKEND,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Execute text search using DuckDuckGo.
|
Execute text search using DuckDuckGo.
|
||||||
@@ -26,6 +99,7 @@ def _search_text(
|
|||||||
max_results: Maximum number of results
|
max_results: Maximum number of results
|
||||||
region: Search region
|
region: Search region
|
||||||
safesearch: Safe search level
|
safesearch: Safe search level
|
||||||
|
backend: DDGS backend(s), e.g. "auto", "duckduckgo", or "duckduckgo,brave"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of search results
|
List of search results
|
||||||
@@ -39,11 +113,15 @@ def _search_text(
|
|||||||
ddgs = DDGS(timeout=30)
|
ddgs = DDGS(timeout=30)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
backend = _normalize_backend(backend)
|
||||||
|
safesearch = _normalize_setting(safesearch, DEFAULT_SAFESEARCH)
|
||||||
|
effective_region = _resolve_ddgs_region(query, region, backend)
|
||||||
results = ddgs.text(
|
results = ddgs.text(
|
||||||
query,
|
query,
|
||||||
region=region,
|
region=effective_region,
|
||||||
safesearch=safesearch,
|
safesearch=safesearch,
|
||||||
max_results=max_results,
|
max_results=max_results,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
return list(results) if results else []
|
return list(results) if results else []
|
||||||
|
|
||||||
@@ -64,14 +142,23 @@ def web_search_tool(
|
|||||||
max_results: Maximum number of results to return. Default is 5.
|
max_results: Maximum number of results to return. Default is 5.
|
||||||
"""
|
"""
|
||||||
config = get_app_config().get_tool_config("web_search")
|
config = get_app_config().get_tool_config("web_search")
|
||||||
|
region = DEFAULT_REGION
|
||||||
|
safesearch = DEFAULT_SAFESEARCH
|
||||||
|
backend = DEFAULT_BACKEND
|
||||||
|
|
||||||
# Override max_results from config if set
|
if config is not None:
|
||||||
if config is not None and "max_results" in config.model_extra:
|
# Override tool call defaults from config if set.
|
||||||
max_results = config.model_extra.get("max_results", max_results)
|
max_results = config.model_extra.get("max_results", max_results)
|
||||||
|
region = config.model_extra.get("region", region)
|
||||||
|
safesearch = config.model_extra.get("safesearch", safesearch)
|
||||||
|
backend = config.model_extra.get("backend", backend)
|
||||||
|
|
||||||
results = _search_text(
|
results = _search_text(
|
||||||
query=query,
|
query=query,
|
||||||
max_results=max_results,
|
max_results=max_results,
|
||||||
|
region=region,
|
||||||
|
safesearch=safesearch,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ _api_key_warned = False
|
|||||||
|
|
||||||
|
|
||||||
class JinaClient:
|
class JinaClient:
|
||||||
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10, proxy: str | None = None, trust_env: bool = True) -> str:
|
||||||
global _api_key_warned
|
global _api_key_warned
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
@@ -23,7 +23,10 @@ class JinaClient:
|
|||||||
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
||||||
data = {"url": url}
|
data = {"url": url}
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
client_kwargs: dict[str, object] = {"trust_env": trust_env}
|
||||||
|
if proxy:
|
||||||
|
client_kwargs["proxy"] = proxy
|
||||||
|
async with httpx.AsyncClient(**client_kwargs) as client:
|
||||||
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
|||||||
@@ -9,6 +9,38 @@ from deerflow.utils.readability import ReadabilityExtractor
|
|||||||
readability_extractor = ReadabilityExtractor()
|
readability_extractor = ReadabilityExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(value: object, default: bool) -> bool:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if normalized in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_timeout(value: object, default: int) -> int:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return default
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_proxy(value: object) -> str | None:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return None
|
||||||
|
proxy = value.strip()
|
||||||
|
return proxy or None
|
||||||
|
|
||||||
|
|
||||||
@tool("web_fetch", parse_docstring=True)
|
@tool("web_fetch", parse_docstring=True)
|
||||||
async def web_fetch_tool(url: str) -> str:
|
async def web_fetch_tool(url: str) -> str:
|
||||||
"""Fetch the contents of a web page at a given URL.
|
"""Fetch the contents of a web page at a given URL.
|
||||||
@@ -22,10 +54,14 @@ async def web_fetch_tool(url: str) -> str:
|
|||||||
"""
|
"""
|
||||||
jina_client = JinaClient()
|
jina_client = JinaClient()
|
||||||
timeout = 10
|
timeout = 10
|
||||||
|
proxy = None
|
||||||
|
trust_env = True
|
||||||
config = get_app_config().get_tool_config("web_fetch")
|
config = get_app_config().get_tool_config("web_fetch")
|
||||||
if config is not None and "timeout" in config.model_extra:
|
if config is not None:
|
||||||
timeout = config.model_extra.get("timeout")
|
timeout = _coerce_timeout(config.model_extra.get("timeout"), timeout)
|
||||||
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
proxy = _coerce_proxy(config.model_extra.get("proxy"))
|
||||||
|
trust_env = _coerce_bool(config.model_extra.get("trust_env"), trust_env)
|
||||||
|
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout, proxy=proxy, trust_env=trust_env)
|
||||||
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||||
return html_content
|
return html_content
|
||||||
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
|
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .tools import web_search_tool
|
||||||
|
|
||||||
|
__all__ = ["web_search_tool"]
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SearxngClient:
|
||||||
|
"""Client for SearXNG meta search engine API."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
max_results: int = 5,
|
||||||
|
categories: list[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search the web using SearXNG.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
max_results: Maximum number of results to return.
|
||||||
|
categories: Search categories to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search result dictionaries.
|
||||||
|
"""
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"q": query,
|
||||||
|
"format": "json",
|
||||||
|
"language": "auto",
|
||||||
|
"pageno": 1,
|
||||||
|
}
|
||||||
|
if max_results:
|
||||||
|
params["limit"] = max_results
|
||||||
|
if categories:
|
||||||
|
params["categories"] = ",".join(categories)
|
||||||
|
|
||||||
|
logger.debug(f"Searching SearXNG at {self.base_url} with query: {query}")
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{self.base_url}/search",
|
||||||
|
params=params,
|
||||||
|
headers={
|
||||||
|
"User-Agent": "Mozilla/5.0 (compatible; DeerFlow/1.0)",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
results = data.get("results", [])
|
||||||
|
return results[:max_results] if max_results else results
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"SearXNG search returned error status: {e}")
|
||||||
|
raise
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"SearXNG search request failed: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"An unexpected error occurred during SearXNG search: {e}")
|
||||||
|
raise
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
from .searxng_client import SearxngClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tool_config(tool_name: str) -> dict | None:
|
||||||
|
"""Get tool config extras safely, returning None if not configured."""
|
||||||
|
config = get_app_config().get_tool_config(tool_name)
|
||||||
|
if config is None:
|
||||||
|
return None
|
||||||
|
extras = config.model_extra
|
||||||
|
return extras if extras is not None else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_searxng_client() -> SearxngClient:
|
||||||
|
cfg = _get_tool_config("web_search")
|
||||||
|
base_url = "http://localhost:8088"
|
||||||
|
if cfg is not None:
|
||||||
|
base_url = cfg.get("base_url", base_url)
|
||||||
|
return SearxngClient(base_url=base_url)
|
||||||
|
|
||||||
|
|
||||||
|
@tool("web_search", parse_docstring=True)
|
||||||
|
async def web_search_tool(query: str) -> str:
|
||||||
|
"""Search the web using SearXNG.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query to search for.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cfg = _get_tool_config("web_search")
|
||||||
|
max_results = 5
|
||||||
|
if cfg is not None:
|
||||||
|
raw = cfg.get("max_results", max_results)
|
||||||
|
max_results = int(raw) if not isinstance(raw, int) else raw
|
||||||
|
|
||||||
|
client = _get_searxng_client()
|
||||||
|
results = await client.search(query, max_results=max_results)
|
||||||
|
|
||||||
|
normalized = [
|
||||||
|
{
|
||||||
|
"title": r.get("title", ""),
|
||||||
|
"url": r.get("url", ""),
|
||||||
|
"snippet": r.get("content", ""),
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
return json.dumps(normalized, indent=2, ensure_ascii=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in web_search_tool: {e}")
|
||||||
|
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||||
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
|||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
effective_user = user_id or get_effective_user_id()
|
effective_user = user_id or get_effective_user_id()
|
||||||
user_path = paths.user_agent_dir(effective_user, name)
|
user_path = paths.user_agent_dir(effective_user, name)
|
||||||
if user_path.exists():
|
# Require config.yaml to confirm this is a genuine agent directory,
|
||||||
|
# not a leftover from memory/storage writes (see #3390).
|
||||||
|
if user_path.exists() and (user_path / "config.yaml").exists():
|
||||||
return user_path
|
return user_path
|
||||||
|
|
||||||
legacy_path = paths.agent_dir(name)
|
legacy_path = paths.agent_dir(name)
|
||||||
if legacy_path.exists():
|
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
|
||||||
return legacy_path
|
return legacy_path
|
||||||
|
|
||||||
return user_path
|
return user_path
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ from typing import Any, Self
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
||||||
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||||
from deerflow.config.database_config import DatabaseConfig
|
from deerflow.config.database_config import DatabaseConfig
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig
|
from deerflow.config.extensions_config import ExtensionsConfig
|
||||||
@@ -18,6 +19,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_
|
|||||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||||
from deerflow.config.model_config import ModelConfig
|
from deerflow.config.model_config import ModelConfig
|
||||||
|
from deerflow.config.reload_boundary import format_field_description
|
||||||
from deerflow.config.run_events_config import RunEventsConfig
|
from deerflow.config.run_events_config import RunEventsConfig
|
||||||
from deerflow.config.runtime_paths import existing_project_file
|
from deerflow.config.runtime_paths import existing_project_file
|
||||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||||
@@ -85,10 +87,21 @@ def apply_logging_level(name: str | None) -> None:
|
|||||||
class AppConfig(BaseModel):
|
class AppConfig(BaseModel):
|
||||||
"""Config for the DeerFlow application"""
|
"""Config for the DeerFlow application"""
|
||||||
|
|
||||||
log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected")
|
log_level: str = Field(
|
||||||
|
default="info",
|
||||||
|
description=format_field_description(
|
||||||
|
"log_level",
|
||||||
|
field_doc="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected.",
|
||||||
|
),
|
||||||
|
)
|
||||||
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
||||||
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
||||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
sandbox: SandboxConfig = Field(
|
||||||
|
description=format_field_description(
|
||||||
|
"sandbox",
|
||||||
|
field_doc="Sandbox provider configuration (local filesystem or Docker-based aio sandbox).",
|
||||||
|
),
|
||||||
|
)
|
||||||
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
|
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
|
||||||
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
|
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
|
||||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||||
@@ -104,13 +117,59 @@ class AppConfig(BaseModel):
|
|||||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||||
|
channel_connections: ChannelConnectionsConfig = Field(
|
||||||
|
default_factory=ChannelConnectionsConfig,
|
||||||
|
description=format_field_description(
|
||||||
|
"channel_connections",
|
||||||
|
field_doc="User-facing IM channel connection configuration.",
|
||||||
|
),
|
||||||
|
)
|
||||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
database: DatabaseConfig = Field(
|
||||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
default_factory=DatabaseConfig,
|
||||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
description=format_field_description(
|
||||||
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
|
"database",
|
||||||
|
field_doc="Unified database backend for run/feedback metadata (memory, sqlite, or postgres).",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
run_events: RunEventsConfig = Field(
|
||||||
|
default_factory=RunEventsConfig,
|
||||||
|
description=format_field_description(
|
||||||
|
"run_events",
|
||||||
|
field_doc="Run-event store backend (memory for dev, db for production queries, jsonl for lightweight single-node persistence).",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
checkpointer: CheckpointerConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=format_field_description(
|
||||||
|
"checkpointer",
|
||||||
|
field_doc="LangGraph state-persistence checkpointer configuration.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
stream_bridge: StreamBridgeConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=format_field_description(
|
||||||
|
"stream_bridge",
|
||||||
|
field_doc="Stream bridge connecting agent workers to SSE endpoints.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("models", "tools", "tool_groups", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _coerce_null_list_sections(cls, value: Any) -> Any:
|
||||||
|
"""Treat a present-but-empty config section as an empty list.
|
||||||
|
|
||||||
|
Commenting out every entry under a top-level YAML key — e.g. ``models:``
|
||||||
|
with only comments beneath it, exactly as shipped in
|
||||||
|
``config.example.yaml`` — makes PyYAML parse the value as ``None``.
|
||||||
|
Without this, the documented ``cp config.example.yaml config.yaml``
|
||||||
|
first-run flow crashes with an opaque ``Input should be a valid list``
|
||||||
|
pydantic error. Coercing ``None`` to ``[]`` keeps that flow working and
|
||||||
|
matches the field's own ``default_factory=list``.
|
||||||
|
"""
|
||||||
|
return [] if value is None else value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||||
@@ -173,6 +232,11 @@ class AppConfig(BaseModel):
|
|||||||
config_data["extensions"] = extensions_config.model_dump()
|
config_data["extensions"] = extensions_config.model_dump()
|
||||||
|
|
||||||
result = cls.model_validate(config_data)
|
result = cls.model_validate(config_data)
|
||||||
|
if not result.models:
|
||||||
|
logger.warning(
|
||||||
|
"No models are configured in %s. Add at least one entry under `models:` (see the commented examples in config.example.yaml) or run `make setup`.",
|
||||||
|
resolved_path,
|
||||||
|
)
|
||||||
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
|
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
|
||||||
cls._apply_singleton_configs(result, acp_agents)
|
cls._apply_singleton_configs(result, acp_agents)
|
||||||
return result
|
return result
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user